| /* |
| *------------------------------------------------------------------ |
| * Copyright (c) 2020 Cisco and/or its affiliates. |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at: |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| *------------------------------------------------------------------ |
| */ |
| |
| package memif |
| |
| import ( |
| "bytes" |
| "container/list" |
| "encoding/binary" |
| "fmt" |
| "os" |
| "sync" |
| "syscall" |
| ) |
| |
// Limits used by the polling loop and the control-message receive path.
const (
	// maxEpollEvents is the capacity of the event buffer handed to EpollWait.
	maxEpollEvents = 1
	// maxControlLen is the size of the buffer that receives ancillary
	// (control) data together with each control message.
	maxControlLen = 256
)

// errorFdNotFound is the error text used when an epoll event matches no
// listener, interrupt or control channel owned by the socket.
const errorFdNotFound = "fd not found"
| |
| // controlMsg represents a message used in communication between memif peers |
| type controlMsg struct { |
| Buffer *bytes.Buffer |
| Fd int |
| } |
| |
| // listener represents a listener functionality of UNIX domain socket |
| type listener struct { |
| socket *Socket |
| event syscall.EpollEvent |
| } |
| |
| // controlChannel represents a communication channel between memif peers |
| // backed by UNIX domain socket |
| type controlChannel struct { |
| listRef *list.Element |
| socket *Socket |
| i *Interface |
| event syscall.EpollEvent |
| data [msgSize]byte |
| control [maxControlLen]byte |
| controlLen int |
| msgQueue []controlMsg |
| isConnected bool |
| } |
| |
| // Socket represents a UNIX domain socket used for communication |
| // between memif peers |
| type Socket struct { |
| appName string |
| filename string |
| listener *listener |
| interfaceList *list.List |
| ccList *list.List |
| epfd int |
| interruptfd int |
| wakeEvent syscall.EpollEvent |
| stopPollChan chan struct{} |
| wg sync.WaitGroup |
| } |
| |
| type interrupt struct { |
| socket *Socket |
| event syscall.EpollEvent |
| } |
| |
| type memifInterrupt struct { |
| connection *Socket |
| qid uint16 |
| } |
| |
// StopPolling stops the goroutine started by StartPolling and blocks until
// it has exited. It returns nil immediately if polling was never started or
// has already been stopped. If waking the goroutine fails, the write error
// is returned without waiting.
func (socket *Socket) StopPolling() error {
	if socket.stopPollChan == nil {
		return nil
	}
	// Receiving from a closed channel succeeds at once: polling was already
	// stopped, so return instead of panicking on a second close.
	select {
	case <-socket.stopPollChan:
		return nil
	default:
	}

	// Ask the polling goroutine to exit.
	close(socket.stopPollChan)

	// EpollWait blocks without a timeout, so write a non-zero 8-byte value to
	// the wake eventfd to make it return; the goroutine then sees the closed
	// channel. PutUvarint(1) yields the bytes 01 00 .. 00.
	buf := make([]byte, 8)
	binary.PutUvarint(buf, 1)
	n, err := syscall.Write(int(socket.wakeEvent.Fd), buf)
	if err != nil {
		return err
	}
	if n != len(buf) {
		return fmt.Errorf("Failed to write to eventfd")
	}

	// Wait until the polling goroutine has returned.
	socket.wg.Wait()
	return nil
}
| |
// StartPolling launches a goroutine that waits on the socket's epoll
// instance and handles listener, interrupt and control-channel events,
// enabling communication between memif peers. The goroutine runs until
// StopPolling is called or EpollWait fails with an error other than EINTR.
//
// Errors are reported on errChan. A failed EpollWait is sent and ends
// polling. Errors from individual events are sent and polling continues.
// A send blocks while errChan is full, so the caller must keep draining it.
func (socket *Socket) StartPolling(errChan chan<- error) {
	socket.stopPollChan = make(chan struct{})
	socket.wg.Add(1)
	go func() {
		defer socket.wg.Done()
		var events [maxEpollEvents]syscall.EpollEvent
		wakeBuf := make([]byte, 8)

		for {
			select {
			case <-socket.stopPollChan:
				return
			default:
			}

			num, err := syscall.EpollWait(socket.epfd, events[:], -1)
			if err != nil {
				// A signal (including the Go runtime's preemption signal)
				// interrupts epoll_wait; that is not a failure, so retry.
				if err == syscall.EINTR {
					continue
				}
				errChan <- fmt.Errorf("EpollWait: %w", err)
				return
			}

			for ev := 0; ev < num; ev++ {
				event := &events[ev]
				if event.Fd == socket.wakeEvent.Fd {
					// Reset the eventfd counter; otherwise it stays readable
					// and a later StartPolling would spin on it. Best effort:
					// a failed read only costs an extra wakeup.
					_, _ = syscall.Read(int(event.Fd), wakeBuf)
					continue
				}
				if err := socket.handleEvent(event); err != nil {
					errChan <- fmt.Errorf("handleEvent: %w", err)
				}
			}
		}
	}()
}
| |
// addEvent registers event with the socket's epoll instance, keyed by the
// descriptor in event.Fd.
func (socket *Socket) addEvent(event *syscall.EpollEvent) error {
	fd := int(event.Fd)
	err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, fd, event)
	if err == nil {
		return nil
	}
	return fmt.Errorf("EpollCtl: %s", err)
}
| |
// delEvent removes the registration for event.Fd from the socket's epoll
// instance.
func (socket *Socket) delEvent(event *syscall.EpollEvent) error {
	fd := int(event.Fd)
	err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, fd, event)
	if err == nil {
		return nil
	}
	return fmt.Errorf("EpollCtl: %s", err)
}
| |
// Delete tears the socket down, in this order:
//   - closes every control channel, sending a disconnect message to each peer;
//   - deletes every interface;
//   - closes the listener (if any) and removes its socket file (a file that is
//     already gone is tolerated);
//   - unregisters the wake eventfd;
//   - closes the wake eventfd and the epoll instance.
//
// Polling should be stopped with StopPolling before calling Delete, because
// the descriptors used by the polling goroutine are closed here.
func (socket *Socket) Delete() (err error) {
	// cc.close (and, possibly, Interface.Delete) remove the current element
	// from its list, and container/list clears a removed element's links,
	// so the successor must be captured before the element is processed.
	for elt := socket.ccList.Front(); elt != nil; {
		next := elt.Next()
		if cc, ok := elt.Value.(*controlChannel); ok {
			if err = cc.close(true, "Socket deleted"); err != nil {
				return err
			}
		}
		elt = next
	}
	for elt := socket.interfaceList.Front(); elt != nil; {
		next := elt.Next()
		if i, ok := elt.Value.(*Interface); ok {
			if err = i.Delete(); err != nil {
				return err
			}
		}
		elt = next
	}

	if socket.listener != nil {
		if err = socket.listener.close(); err != nil {
			return err
		}
		if err = os.Remove(socket.filename); err != nil && !os.IsNotExist(err) {
			return err
		}
	}

	if err = socket.delEvent(&socket.wakeEvent); err != nil {
		return fmt.Errorf("Failed to delete event: %w", err)
	}

	// Best effort: there is nothing useful to do if closing fails here.
	syscall.Close(int(socket.wakeEvent.Fd))
	syscall.Close(socket.epfd)

	return nil
}
| |
// NewSocket returns a new Socket for appName associated with filename, or
// with DefaultSocketFilename when filename is empty. It creates the socket's
// epoll instance and registers a wake eventfd (used by StopPolling). No
// listener is created here.
//
// On failure every descriptor opened so far is closed and a nil Socket is
// returned.
func NewSocket(appName string, filename string) (socket *Socket, err error) {
	socket = &Socket{
		appName:       appName,
		filename:      filename,
		interfaceList: list.New(),
		ccList:        list.New(),
	}
	if socket.filename == "" {
		socket.filename = DefaultSocketFilename
	}

	socket.epfd, err = syscall.EpollCreate1(0)
	if err != nil {
		return nil, fmt.Errorf("EpollCreate1: %w", err)
	}

	efd, err := eventFd()
	if err != nil {
		syscall.Close(socket.epfd)
		return nil, fmt.Errorf("eventFd: %w", err)
	}

	socket.wakeEvent = syscall.EpollEvent{
		Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
		Fd:     int32(efd),
	}
	if err = socket.addEvent(&socket.wakeEvent); err != nil {
		syscall.Close(int(efd))
		syscall.Close(socket.epfd)
		return nil, fmt.Errorf("Failed to add event: %w", err)
	}

	return socket, nil
}
| |
// handleEvent dispatches one epoll event to the owner of its descriptor.
// The owner is one of:
//   - the listener;
//   - an interface whose args.InterruptFd matches and that has an
//     InterruptFunc (8 bytes are read from the descriptor before its
//     onInterrupt is invoked);
//   - a control channel.
//
// When nothing matches, it returns an error with the text errorFdNotFound.
func (socket *Socket) handleEvent(event *syscall.EpollEvent) error {
	if socket.listener != nil && socket.listener.event.Fd == event.Fd {
		return socket.listener.handleEvent(event)
	}

	// Check every interface, not just the last one, and tolerate an empty
	// interface list.
	for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
		intf, ok := elt.Value.(*Interface)
		if !ok || intf.args.InterruptFunc == nil {
			continue
		}
		if int(event.Fd) == int(intf.args.InterruptFd) {
			// Consume the readiness (presumably an eventfd counter) so the
			// descriptor stops reporting readable; the value is unused.
			b := make([]byte, 8)
			_, _ = syscall.Read(int(event.Fd), b)
			intf.onInterrupt(intf)
			return nil
		}
	}

	for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
		if cc, ok := elt.Value.(*controlChannel); ok && cc.event.Fd == event.Fd {
			return cc.handleEvent(event)
		}
	}

	return fmt.Errorf("%s", errorFdNotFound)
}
| |
// addInterrupt registers fd for EPOLLIN on the socket's epoll instance, so
// that its readiness is delivered through handleEvent.
func (socket *Socket) addInterrupt(fd int) (err error) {
	l := &interrupt{
		// we will need this to look up master interface by id
		socket: socket,
	}

	l.event = syscall.EpollEvent{
		Events: syscall.EPOLLIN,
		Fd:     int32(fd),
	}
	if err = socket.addEvent(&l.event); err != nil {
		return fmt.Errorf("Failed to add event: %w", err)
	}

	return nil
}
| |
// handleEvent handles one epoll event for the listener.
//
// On hang-up or error the listener is closed and an error naming the socket
// file is returned. On readable input one connection is accepted, wrapped
// in a new control channel (with no interface yet), and a hello message is
// queued and sent to the peer.
func (l *listener) handleEvent(event *syscall.EpollEvent) error {
	// hang up
	if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
		if err := l.close(); err != nil {
			return fmt.Errorf("Failed to close listener after hang up event: %w", err)
		}
		return fmt.Errorf("Hang up: %s", l.socket.filename)
	}

	// error
	if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
		if err := l.close(); err != nil {
			return fmt.Errorf("Failed to close listener after receiving an error event: %w", err)
		}
		return fmt.Errorf("Received error event on listener %s", l.socket.filename)
	}

	// incoming connection
	if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
		newFd, _, err := syscall.Accept(int(l.event.Fd))
		if err != nil {
			return fmt.Errorf("Accept: %s", err)
		}

		cc, err := l.socket.addControlChannel(newFd, nil)
		if err != nil {
			// The descriptor was never handed to a control channel.
			syscall.Close(newFd)
			return fmt.Errorf("Failed to add control channel: %s", err)
		}

		if err = cc.msgEnqHello(); err != nil {
			return fmt.Errorf("msgEnqHello: %s", err)
		}

		return cc.sendMsg()
	}

	return fmt.Errorf("Unexpected event: %v", event.Events)
}
| |
// handleEvent handles one epoll event for the control channel.
//
// On hang-up or error the channel is closed without notifying the peer, and
// an error naming the interface is returned. On readable input exactly one
// message of msgSize bytes (plus any ancillary data) is received and parsed,
// and any queued replies are sent.
func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error {
	// A channel accepted by the listener has no interface until the peer's
	// init message is parsed, so cc.i may still be nil here.
	name := func() string {
		if cc.i == nil {
			return "<unassigned>"
		}
		return cc.i.GetName()
	}

	// hang up
	if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
		// close cc, don't send msg
		if err := cc.close(false, ""); err != nil {
			return fmt.Errorf("Failed to close control channel after hang up event: %w", err)
		}
		return fmt.Errorf("Hang up: %s", name())
	}

	// error
	if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
		// close cc, don't send msg
		if err := cc.close(false, ""); err != nil {
			return fmt.Errorf("Failed to close control channel after receiving an error event: %w", err)
		}
		return fmt.Errorf("Received error event on control channel %s", name())
	}

	// incoming message
	if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
		var size int
		var err error
		size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0)
		if err != nil {
			return fmt.Errorf("recvmsg: %s", err)
		}
		if size != msgSize {
			return fmt.Errorf("invalid message size %d", size)
		}

		if err = cc.parseMsg(); err != nil {
			return err
		}
		return cc.sendMsg()
	}

	return fmt.Errorf("Unexpected event: %v", event.Events)
}
| |
// close unregisters the listener from the socket's epoll instance and then
// closes the listening descriptor.
func (l *listener) close() error {
	if err := l.socket.delEvent(&l.event); err != nil {
		return fmt.Errorf("Failed to del event: %w", err)
	}
	if err := syscall.Close(int(l.event.Fd)); err != nil {
		return fmt.Errorf("Failed to close socket: %w", err)
	}
	return nil
}
| |
// AddListener adds a listener to the socket. The fd must describe a UNIX
// domain socket that is already bound to a filename and marked as
// listening. The fd is registered for EPOLLIN/EPOLLERR/EPOLLHUP on the
// socket's epoll instance. Once polling is running, incoming connections
// are accepted by the listener's handleEvent.
func (socket *Socket) AddListener(fd int) (err error) {
	l := &listener{
		// we will need this to look up master interface by id
		socket: socket,
	}

	l.event = syscall.EpollEvent{
		Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
		Fd:     int32(fd),
	}
	if err = socket.addEvent(&l.event); err != nil {
		return fmt.Errorf("Failed to add event: %w", err)
	}

	socket.listener = l
	return nil
}
| |
// addListener creates a SOCK_SEQPACKET UNIX domain socket with SO_PASSCRED
// enabled, binds it to socket.filename, marks it as listening and registers
// it via AddListener. If any step fails, the new descriptor is closed.
func (socket *Socket) addListener() (err error) {
	fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
	if err != nil {
		return fmt.Errorf("Failed to create UNIX domain socket: %v", err)
	}
	// On success ownership of fd passes to the registered listener;
	// otherwise release it here so it does not leak.
	defer func() {
		if err != nil {
			syscall.Close(fd)
		}
	}()

	if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1); err != nil {
		return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err)
	}

	// Bind to address and start listening.
	usa := &syscall.SockaddrUnix{Name: socket.filename}
	if err = syscall.Bind(fd, usa); err != nil {
		return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err)
	}
	if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil {
		return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err)
	}

	return socket.AddListener(fd)
}
| |
// close tears down the control channel. When sendMsg is true, the pending
// message queue is discarded and a disconnect message carrying str is sent
// first. The channel is then removed from the socket's epoll instance and
// channel list. If an interface is assigned, it is disconnected; otherwise
// the connection descriptor is closed here, because nothing else holds it.
//
// NOTE(review): with an interface assigned, cc.event.Fd is not closed here;
// confirm that Interface.disconnect (or another path) closes it.
func (cc *controlChannel) close(sendMsg bool, str string) (err error) {
	if sendMsg {
		// First clear the message queue so that the disconnect message is
		// the only message in it.
		cc.msgQueue = []controlMsg{}
		if err = cc.msgEnqDisconnect(str); err != nil {
			return err
		}
		if err = cc.sendMsg(); err != nil {
			return err
		}
	}

	if err = cc.socket.delEvent(&cc.event); err != nil {
		return fmt.Errorf("Failed to del event: %w", err)
	}

	// Remove the reference from the socket.
	cc.socket.ccList.Remove(cc.listRef)

	if cc.i == nil {
		// Unassigned channel (e.g. accepted, never initialized): release
		// the connection descriptor. Best effort.
		syscall.Close(int(cc.event.Fd))
		return nil
	}

	if err = cc.i.disconnect(); err != nil {
		return fmt.Errorf("Interface Disconnect: %w", err)
	}
	return nil
}
| |
| //addControlChannel returns a new controlChannel and adds it to the socket |
| func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) { |
| cc := &controlChannel{ |
| socket: socket, |
| i: i, |
| isConnected: false, |
| } |
| |
| var err error |
| |
| cc.event = syscall.EpollEvent{ |
| Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP, |
| Fd: int32(fd), |
| } |
| err = socket.addEvent(&cc.event) |
| if err != nil { |
| return nil, fmt.Errorf("Failed to add event: ", err) |
| } |
| |
| cc.listRef = socket.ccList.PushBack(cc) |
| |
| return cc, nil |
| } |
| |
// msgEnqAck enqueues an ack message (type only, no payload, no descriptor).
func (cc *controlChannel) msgEnqAck() (err error) {
	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeAck); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     -1,
	})
	return nil
}
| |
// msgEnqHello enqueues a hello message advertising this side's supported
// version range and limits, named after the application. The name is
// truncated to the size of MsgHello.Name.
func (cc *controlChannel) msgEnqHello() (err error) {
	hello := MsgHello{
		VersionMin:      Version,
		VersionMax:      Version,
		MaxRegion:       255,
		MaxRingM2S:      255,
		MaxRingS2M:      255,
		MaxLog2RingSize: 14,
	}
	copy(hello.Name[:], []byte(cc.socket.appName))

	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeHello); err != nil {
		return err
	}
	if err = binary.Write(buf, binary.LittleEndian, hello); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     -1,
	})
	return nil
}
| |
// parseHello decodes the peer's hello message and derives the interface's
// run-time configuration from the local MemoryConfig and the peer's limits.
// It fails if no interface is assigned or if the peer's version range
// excludes Version.
func (cc *controlChannel) parseHello() (err error) {
	// A hello on an accepted channel (no interface yet) would otherwise
	// dereference a nil interface.
	if cc.i == nil {
		return fmt.Errorf("Unexpected hello message: no interface assigned")
	}

	var hello MsgHello
	buf := bytes.NewReader(cc.data[msgTypeSize:])
	if err = binary.Read(buf, binary.LittleEndian, &hello); err != nil {
		return err
	}

	if hello.VersionMin > Version || hello.VersionMax < Version {
		return fmt.Errorf("Incompatible memif version")
	}

	cc.i.run = cc.i.args.MemoryConfig

	// Queue pairs need one ring in each direction, so respect both limits.
	cc.i.run.NumQueuePairs = min16(min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M), hello.MaxRingM2S)
	cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize)

	cc.i.remoteName = string(hello.Name[:])
	return nil
}
| |
// msgEnqInit enqueues an init message. It carries the protocol version, the
// interface's id, mode and secret, and the application name.
func (cc *controlChannel) msgEnqInit() (err error) {
	initMsg := MsgInit{
		Version: Version,
		Id:      cc.i.args.Id,
		Mode:    cc.i.args.Mode,
		// The master rejects an init whose secret differs from its
		// interface's secret, so it has to be sent.
		Secret: cc.i.args.Secret,
	}
	copy(initMsg.Name[:], []byte(cc.socket.appName))

	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeInit); err != nil {
		return err
	}
	if err = binary.Write(buf, binary.LittleEndian, initMsg); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     -1,
	})
	return nil
}
| |
// parseInit decodes the peer's init message and binds this control channel
// to the matching local master interface: same id, IsMaster set, and no
// channel yet. It fails on a version mismatch, on a secret mismatch, or
// when no interface matches.
func (cc *controlChannel) parseInit() (err error) {
	var msg MsgInit
	reader := bytes.NewReader(cc.data[msgTypeSize:])
	if err = binary.Read(reader, binary.LittleEndian, &msg); err != nil {
		return err
	}

	if msg.Version != Version {
		return fmt.Errorf("Incompatible memif driver version")
	}

	// Find the peer interface.
	for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
		candidate, isIntf := elt.Value.(*Interface)
		if !isIntf || candidate.args.Id != msg.Id || !candidate.args.IsMaster || candidate.cc != nil {
			continue
		}
		if candidate.args.Secret != msg.Secret {
			return fmt.Errorf("Invalid secret")
		}
		// Assign the interface to this control channel.
		candidate.cc = cc
		cc.i = candidate
		cc.i.run = cc.i.args.MemoryConfig
		cc.i.remoteName = string(msg.Name[:])
		return nil
	}

	return fmt.Errorf("Invalid interface id")
}
| |
// msgEnqAddRegion enqueues an add-region message for region regionIndex.
// The message carries the region's index and size, and the region's file
// descriptor is attached. It fails if regionIndex is out of range.
func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) {
	if len(cc.i.regions) <= int(regionIndex) {
		return fmt.Errorf("Invalid region index")
	}

	addRegion := MsgAddRegion{
		Index: regionIndex,
		Size:  cc.i.regions[regionIndex].size,
	}

	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion); err != nil {
		return err
	}
	if err = binary.Write(buf, binary.LittleEndian, addRegion); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     cc.i.regions[regionIndex].fd,
	})
	return nil
}
| |
// parseAddRegion decodes an add-region message together with the region
// descriptor passed as ancillary data, and appends the region to the
// interface. The received descriptor is closed whenever the message is
// rejected after it has been extracted.
func (cc *controlChannel) parseAddRegion() (err error) {
	var addRegion MsgAddRegion

	buf := bytes.NewReader(cc.data[msgTypeSize:])
	if err = binary.Read(buf, binary.LittleEndian, &addRegion); err != nil {
		return err
	}

	// recvmsg has already installed the descriptor in this process, so it
	// must be extracted (and closed on rejection) even for invalid input.
	fd, err := cc.parseControlMsg()
	if err != nil {
		return fmt.Errorf("parseControlMsg: %s", err)
	}

	if cc.i == nil {
		syscall.Close(fd)
		return fmt.Errorf("Unexpected add region message: no interface assigned")
	}
	if addRegion.Index > 255 {
		syscall.Close(fd)
		return fmt.Errorf("Invalid memory region index")
	}

	region := memoryRegion{
		size: addRegion.Size,
		fd:   fd,
	}
	cc.i.regions = append(cc.i.regions, region)

	return nil
}
| |
// msgEnqAddRing enqueues an add-ring message describing one local queue.
// For ringTypeS2M it describes transmit queue ringIndex and sets the S2M
// flag; otherwise it describes receive queue ringIndex. The queue's
// interruptFd is attached. It fails if ringIndex is out of range.
func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) {
	var q Queue
	var flags uint16

	if ringType == ringTypeS2M {
		if int(ringIndex) >= len(cc.i.txQueues) {
			return fmt.Errorf("Invalid ring index")
		}
		q = cc.i.txQueues[ringIndex]
		flags = msgAddRingFlagS2M
	} else {
		if int(ringIndex) >= len(cc.i.rxQueues) {
			return fmt.Errorf("Invalid ring index")
		}
		q = cc.i.rxQueues[ringIndex]
	}

	addRing := MsgAddRing{
		Index:          ringIndex,
		Offset:         uint32(q.ring.offset),
		Region:         uint16(q.ring.region),
		RingSizeLog2:   uint8(q.ring.log2Size),
		Flags:          flags,
		PrivateHdrSize: 0,
	}

	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing); err != nil {
		return err
	}
	if err = binary.Write(buf, binary.LittleEndian, addRing); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     q.interruptFd,
	})
	return nil
}
| |
// parseAddRing decodes an add-ring message together with the interrupt
// descriptor passed as ancillary data, and appends a queue to the
// interface. A ring flagged S2M becomes a receive queue; any other ring
// becomes a transmit queue. The received descriptor is closed whenever the
// message is rejected after it has been extracted.
func (cc *controlChannel) parseAddRing() (err error) {
	var addRing MsgAddRing

	buf := bytes.NewReader(cc.data[msgTypeSize:])
	if err = binary.Read(buf, binary.LittleEndian, &addRing); err != nil {
		return err
	}

	// recvmsg has already installed the descriptor in this process, so it
	// must be extracted (and closed on rejection) even for invalid input.
	fd, err := cc.parseControlMsg()
	if err != nil {
		return fmt.Errorf("parseControlMsg: %w", err)
	}

	if cc.i == nil {
		syscall.Close(fd)
		return fmt.Errorf("Unexpected add ring message: no interface assigned")
	}
	if addRing.Index >= cc.i.run.NumQueuePairs {
		syscall.Close(fd)
		return fmt.Errorf("invalid ring index")
	}

	q := Queue{
		i:           cc.i,
		interruptFd: fd,
	}

	if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M {
		q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2))
		cc.i.rxQueues = append(cc.i.rxQueues, q)
	} else {
		q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2))
		cc.i.txQueues = append(cc.i.txQueues, q)
	}

	return nil
}
| |
// msgEnqConnect enqueues a connect message carrying the local interface's
// name (truncated to the size of MsgConnect.Name).
func (cc *controlChannel) msgEnqConnect() (err error) {
	var connect MsgConnect
	copy(connect.Name[:], []byte(cc.i.args.Name))

	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeConnect); err != nil {
		return err
	}
	if err = binary.Write(buf, binary.LittleEndian, connect); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     -1,
	})
	return nil
}
| |
// parseConnect decodes a connect message, records the peer's interface
// name, connects the local interface and marks the channel connected.
func (cc *controlChannel) parseConnect() (err error) {
	var connect MsgConnect

	buf := bytes.NewReader(cc.data[msgTypeSize:])
	if err = binary.Read(buf, binary.LittleEndian, &connect); err != nil {
		return err
	}

	// A peer that skips the init handshake leaves cc.i unassigned.
	if cc.i == nil {
		return fmt.Errorf("Unexpected connect message: no interface assigned")
	}

	cc.i.peerName = string(connect.Name[:])

	if err = cc.i.connect(); err != nil {
		return err
	}
	cc.isConnected = true

	return nil
}
| |
// msgEnqConnected enqueues a connected message carrying the local
// interface's name (truncated to the size of MsgConnected.Name).
func (cc *controlChannel) msgEnqConnected() (err error) {
	var connected MsgConnected
	copy(connected.Name[:], []byte(cc.i.args.Name))

	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeConnected); err != nil {
		return err
	}
	if err = binary.Write(buf, binary.LittleEndian, connected); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     -1,
	})
	return nil
}
| |
// parseConnected decodes a connected message, records the peer's interface
// name, connects the local interface and marks the channel connected.
func (cc *controlChannel) parseConnected() (err error) {
	var conn MsgConnected

	buf := bytes.NewReader(cc.data[msgTypeSize:])
	if err = binary.Read(buf, binary.LittleEndian, &conn); err != nil {
		return err
	}

	// An out-of-order peer could send this before any interface is assigned.
	if cc.i == nil {
		return fmt.Errorf("Unexpected connected message: no interface assigned")
	}

	cc.i.peerName = string(conn.Name[:])

	if err = cc.i.connect(); err != nil {
		return err
	}
	cc.isConnected = true

	return nil
}
| |
// msgEnqDisconnect enqueues a disconnect message carrying str as the reason
// (truncated to the size of MsgDisconnect.String). The code field is always
// 0.
func (cc *controlChannel) msgEnqDisconnect(str string) (err error) {
	dc := MsgDisconnect{
		// not implemented
		Code: 0,
	}
	copy(dc.String[:], str)

	buf := new(bytes.Buffer)
	if err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect); err != nil {
		return err
	}
	if err = binary.Write(buf, binary.LittleEndian, dc); err != nil {
		return err
	}

	cc.msgQueue = append(cc.msgQueue, controlMsg{
		Buffer: buf,
		Fd:     -1,
	})
	return nil
}
| |
// parseDisconnect decodes the peer's disconnect message and closes the
// control channel without sending anything back.
func (cc *controlChannel) parseDisconnect() (err error) {
	var dc MsgDisconnect

	buf := bytes.NewReader(cc.data[msgTypeSize:])
	if err = binary.Read(buf, binary.LittleEndian, &dc); err != nil {
		return err
	}

	if err = cc.close(false, string(dc.String[:])); err != nil {
		return fmt.Errorf("Failed to disconnect control channel: %w", err)
	}

	return nil
}
| |
// parseMsg decodes the message type in cc.data, handles the message and
// enqueues whatever replies the protocol requires. If anything fails, the
// control channel is closed and a disconnect message carrying the error
// text is sent. The original error is then returned, annotated with the
// close error if closing also failed.
func (cc *controlChannel) parseMsg() error {
	var mt msgType
	err := binary.Read(bytes.NewReader(cc.data[:]), binary.LittleEndian, &mt)
	if err == nil {
		err = cc.dispatchMsg(mt)
	}
	if err == nil {
		return nil
	}

	// The error text is data, never a format string.
	if closeErr := cc.close(true, err.Error()); closeErr != nil {
		return fmt.Errorf("%w: Failed to close control channel: %v", err, closeErr)
	}
	return err
}

// dispatchMsg handles one decoded message of type mt and enqueues its reply.
func (cc *controlChannel) dispatchMsg(mt msgType) error {
	switch mt {
	case msgTypeAck:
		return nil
	case msgTypeHello:
		return cc.handleHelloMsg()
	case msgTypeInit:
		if err := cc.parseInit(); err != nil {
			return err
		}
		return cc.msgEnqAck()
	case msgTypeAddRegion:
		if err := cc.parseAddRegion(); err != nil {
			return err
		}
		return cc.msgEnqAck()
	case msgTypeAddRing:
		if err := cc.parseAddRing(); err != nil {
			return err
		}
		return cc.msgEnqAck()
	case msgTypeConnect:
		if err := cc.parseConnect(); err != nil {
			return err
		}
		return cc.msgEnqConnected()
	case msgTypeConnected:
		return cc.parseConnected()
	case msgTypeDisconnect:
		return cc.parseDisconnect()
	default:
		return fmt.Errorf("unknown message %d", mt)
	}
}

// handleHelloMsg handles the peer's hello. It configures and initializes
// the (slave) interface's regions and queues, then enqueues the following
// messages in order:
//   - init;
//   - one add-region per region;
//   - one add-ring per S2M ring, then one per M2S ring;
//   - connect.
func (cc *controlChannel) handleHelloMsg() error {
	// Configure
	if err := cc.parseHello(); err != nil {
		return err
	}
	// Initialize slave memif
	if err := cc.i.initializeRegions(); err != nil {
		return err
	}
	if err := cc.i.initializeQueues(); err != nil {
		return err
	}
	// Enqueue messages
	if err := cc.msgEnqInit(); err != nil {
		return err
	}
	for i := 0; i < len(cc.i.regions); i++ {
		if err := cc.msgEnqAddRegion(uint16(i)); err != nil {
			return err
		}
	}
	for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
		if err := cc.msgEnqAddRing(ringTypeS2M, uint16(i)); err != nil {
			return err
		}
	}
	for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
		if err := cc.msgEnqAddRing(ringTypeM2S, uint16(i)); err != nil {
			return err
		}
	}
	return cc.msgEnqConnect()
}
| |
// parseControlMsg parses the ancillary data received with the last message
// and returns the single file descriptor passed via SCM_RIGHTS. Only call
// it for messages that must carry a descriptor: it fails when there is no
// control message or no descriptor. Any surplus descriptors are closed,
// because recvmsg has already installed them in this process and they
// would otherwise leak.
func (cc *controlChannel) parseControlMsg() (fd int, err error) {
	fd = -1

	controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen])
	if err != nil {
		return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err)
	}

	if len(controlMsgs) == 0 {
		return -1, fmt.Errorf("Missing control message")
	}

	for i := range controlMsgs {
		cmsg := &controlMsgs[i]
		if cmsg.Header.Level != syscall.SOL_SOCKET || cmsg.Header.Type != syscall.SCM_RIGHTS {
			continue
		}
		fds, perr := syscall.ParseUnixRights(cmsg)
		if perr != nil {
			if fd != -1 {
				syscall.Close(fd)
			}
			return -1, fmt.Errorf("syscall.ParseUnixRights: %s", perr)
		}
		for _, f := range fds {
			if fd == -1 {
				// Only a single descriptor is expected; keep the first.
				fd = f
				continue
			}
			syscall.Close(f)
		}
	}

	if fd == -1 {
		return -1, fmt.Errorf("Missing file descriptor")
	}

	return fd, nil
}