/*
 *------------------------------------------------------------------
 * Copyright (c) 2020 Cisco and/or its affiliates.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *------------------------------------------------------------------
 */
17
18package memif
19
20import (
21 "bytes"
22 "container/list"
23 "encoding/binary"
24 "fmt"
25 "os"
26 "sync"
27 "syscall"
28)
29
// Sizing limits and shared error text used by the socket event handling.
const (
	// maxEpollEvents is the length of the event array handed to EpollWait.
	maxEpollEvents = 1
	// maxControlLen is the size in bytes of a control channel's ancillary
	// data buffer.
	maxControlLen = 256

	// errorFdNotFound is the error text returned when an epoll event's
	// descriptor matches no listener, interrupt or control channel.
	errorFdNotFound = "fd not found"
)
34
35// controlMsg represents a message used in communication between memif peers
36type controlMsg struct {
37 Buffer *bytes.Buffer
38 Fd int
39}
40
41// listener represents a listener functionality of UNIX domain socket
42type listener struct {
43 socket *Socket
44 event syscall.EpollEvent
45}
46
47// controlChannel represents a communication channel between memif peers
48// backed by UNIX domain socket
49type controlChannel struct {
50 listRef *list.Element
51 socket *Socket
52 i *Interface
53 event syscall.EpollEvent
54 data [msgSize]byte
55 control [maxControlLen]byte
56 controlLen int
57 msgQueue []controlMsg
58 isConnected bool
59}
60
61// Socket represents a UNIX domain socket used for communication
62// between memif peers
63type Socket struct {
64 appName string
65 filename string
66 listener *listener
67 interfaceList *list.List
68 ccList *list.List
69 epfd int
Daniel Béreš82ec9082022-07-27 12:22:39 +000070 interruptfd int
Jakub Grajciar07363a42020-04-02 10:02:17 +020071 wakeEvent syscall.EpollEvent
72 stopPollChan chan struct{}
73 wg sync.WaitGroup
74}
75
Daniel Béreš82ec9082022-07-27 12:22:39 +000076type interrupt struct {
77 socket *Socket
78 event syscall.EpollEvent
79}
80
81type memifInterrupt struct {
82 connection *Socket
83 qid uint16
84}
85
Jakub Grajciar07363a42020-04-02 10:02:17 +020086// StopPolling stops polling events on the socket
87func (socket *Socket) StopPolling() error {
88 if socket.stopPollChan != nil {
89 // stop polling msg
90 close(socket.stopPollChan)
91 // wake epoll
92 buf := make([]byte, 8)
93 binary.PutUvarint(buf, 1)
94 n, err := syscall.Write(int(socket.wakeEvent.Fd), buf[:])
95 if err != nil {
96 return err
97 }
98 if n != 8 {
99 return fmt.Errorf("Faild to write to eventfd")
100 }
101 // wait until polling is stopped
102 socket.wg.Wait()
103 }
104
105 return nil
106}
107
108// StartPolling starts polling and handling events on the socket,
109// enabling communication between memif peers
110func (socket *Socket) StartPolling(errChan chan<- error) {
111 socket.stopPollChan = make(chan struct{})
112 socket.wg.Add(1)
113 go func() {
114 var events [maxEpollEvents]syscall.EpollEvent
115 defer socket.wg.Done()
116
117 for {
118 select {
119 case <-socket.stopPollChan:
120 return
121 default:
122 num, err := syscall.EpollWait(socket.epfd, events[:], -1)
123 if err != nil {
124 errChan <- fmt.Errorf("EpollWait: ", err)
125 return
126 }
127
128 for ev := 0; ev < num; ev++ {
129 if events[0].Fd == socket.wakeEvent.Fd {
130 continue
131 }
132 err = socket.handleEvent(&events[0])
133 if err != nil {
134 errChan <- fmt.Errorf("handleEvent: ", err)
135 }
136 }
137 }
138 }
139 }()
140}
141
142// addEvent adds event to epoll instance associated with the socket
143func (socket *Socket) addEvent(event *syscall.EpollEvent) error {
144 err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event)
145 if err != nil {
146 return fmt.Errorf("EpollCtl: %s", err)
147 }
148 return nil
149}
150
151// addEvent deletes event to epoll instance associated with the socket
152func (socket *Socket) delEvent(event *syscall.EpollEvent) error {
153 err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event)
154 if err != nil {
155 return fmt.Errorf("EpollCtl: %s", err)
156 }
157 return nil
158}
159
160// Delete deletes the socket
161func (socket *Socket) Delete() (err error) {
162 for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
163 cc, ok := elt.Value.(*controlChannel)
164 if ok {
165 err = cc.close(true, "Socket deleted")
166 if err != nil {
167 return nil
168 }
169 }
170 }
171 for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
172 i, ok := elt.Value.(*Interface)
173 if ok {
174 err = i.Delete()
175 if err != nil {
176 return err
177 }
178 }
179 }
180
181 if socket.listener != nil {
182 err = socket.listener.close()
183 if err != nil {
184 return err
185 }
186 err = os.Remove(socket.filename)
187 if err != nil {
188 return nil
189 }
190 }
191
192 err = socket.delEvent(&socket.wakeEvent)
193 if err != nil {
194 return fmt.Errorf("Failed to delete event: ", err)
195 }
196
197 syscall.Close(socket.epfd)
198
199 return nil
200}
201
202// NewSocket returns a new Socket
203func NewSocket(appName string, filename string) (socket *Socket, err error) {
204 socket = &Socket{
205 appName: appName,
206 filename: filename,
207 interfaceList: list.New(),
208 ccList: list.New(),
209 }
210 if socket.filename == "" {
211 socket.filename = DefaultSocketFilename
212 }
213
214 socket.epfd, _ = syscall.EpollCreate1(0)
215
216 efd, err := eventFd()
217 socket.wakeEvent = syscall.EpollEvent{
218 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
219 Fd: int32(efd),
220 }
221 err = socket.addEvent(&socket.wakeEvent)
222 if err != nil {
223 return nil, fmt.Errorf("Failed to add event: ", err)
224 }
225
226 return socket, nil
227}
228
229// handleEvent handles epoll event
230func (socket *Socket) handleEvent(event *syscall.EpollEvent) error {
231 if socket.listener != nil && socket.listener.event.Fd == event.Fd {
232 return socket.listener.handleEvent(event)
233 }
Daniel Béreš82ec9082022-07-27 12:22:39 +0000234 intf := socket.interfaceList.Back().Value.(*Interface)
235 if intf.args.InterruptFunc != nil {
236 if int(event.Fd) == int(intf.args.InterruptFd) {
237 b := make([]byte, 8)
238 syscall.Read(int(event.Fd), b)
239 intf.onInterrupt(intf)
240 return nil
241 }
242 }
Jakub Grajciar07363a42020-04-02 10:02:17 +0200243
244 for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
245 cc, ok := elt.Value.(*controlChannel)
246 if ok {
247 if cc.event.Fd == event.Fd {
248 return cc.handleEvent(event)
249 }
250 }
251 }
252
253 return fmt.Errorf(errorFdNotFound)
254}
255
Daniel Béreš82ec9082022-07-27 12:22:39 +0000256func (socket *Socket) addInterrupt(fd int) (err error) {
257 l := &interrupt{
258 // we will need this to look up master interface by id
259 socket: socket,
260 }
261
262 l.event = syscall.EpollEvent{
263 Events: syscall.EPOLLIN,
264 Fd: int32(fd),
265 }
266 err = socket.addEvent(&l.event)
267 if err != nil {
268 return fmt.Errorf("Failed to add event: ", err)
269 }
270
271 return nil
272
273}
274
Jakub Grajciar07363a42020-04-02 10:02:17 +0200275// handleEvent handles epoll event for listener
276func (l *listener) handleEvent(event *syscall.EpollEvent) error {
277 // hang up
278 if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
279 err := l.close()
280 if err != nil {
281 return fmt.Errorf("Failed to close listener after hang up event: ", err)
282 }
283 return fmt.Errorf("Hang up: ", l.socket.filename)
284 }
285
286 // error
287 if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
288 err := l.close()
289 if err != nil {
290 return fmt.Errorf("Failed to close listener after receiving an error event: ", err)
291 }
292 return fmt.Errorf("Received error event on listener ", l.socket.filename)
293 }
294
295 // read message
296 if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
297 newFd, _, err := syscall.Accept(int(l.event.Fd))
298 if err != nil {
299 return fmt.Errorf("Accept: %s", err)
300 }
301
302 cc, err := l.socket.addControlChannel(newFd, nil)
303 if err != nil {
304 return fmt.Errorf("Failed to add control channel: %s", err)
305 }
306
307 err = cc.msgEnqHello()
308 if err != nil {
309 return fmt.Errorf("msgEnqHello: %s", err)
310 }
311
312 err = cc.sendMsg()
313 if err != nil {
314 return err
315 }
316
317 return nil
318 }
319
320 return fmt.Errorf("Unexpected event: ", event.Events)
321}
322
323// handleEvent handles epoll event for control channel
324func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error {
325 var size int
326 var err error
327
328 // hang up
329 if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
330 // close cc, don't send msg
331 err := cc.close(false, "")
332 if err != nil {
333 return fmt.Errorf("Failed to close control channel after hang up event: ", err)
334 }
335 return fmt.Errorf("Hang up: ", cc.i.GetName())
336 }
337
338 if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
339 // close cc, don't send msg
340 err := cc.close(false, "")
341 if err != nil {
342 return fmt.Errorf("Failed to close control channel after receiving an error event: ", err)
343 }
344 return fmt.Errorf("Received error event on control channel ", cc.i.GetName())
345 }
346
347 if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
348 size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0)
349 if err != nil {
350 return fmt.Errorf("recvmsg: %s", err)
351 }
352 if size != msgSize {
353 return fmt.Errorf("invalid message size %d", size)
354 }
355
356 err = cc.parseMsg()
357 if err != nil {
358 return err
359 }
360
361 err = cc.sendMsg()
362 if err != nil {
363 return err
364 }
365
366 return nil
367 }
368
369 return fmt.Errorf("Unexpected event: ", event.Events)
370}
371
372// close closes the listener
373func (l *listener) close() error {
374 err := l.socket.delEvent(&l.event)
375 if err != nil {
376 return fmt.Errorf("Failed to del event: ", err)
377 }
378 err = syscall.Close(int(l.event.Fd))
379 if err != nil {
380 return fmt.Errorf("Failed to close socket: ", err)
381 }
382 return nil
383}
384
385// AddListener adds a lisntener to the socket. The fd must describe a
386// UNIX domain socket already bound to a UNIX domain filename and
387// marked as listener
388func (socket *Socket) AddListener(fd int) (err error) {
389 l := &listener{
390 // we will need this to look up master interface by id
391 socket: socket,
392 }
393
394 l.event = syscall.EpollEvent{
395 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
396 Fd: int32(fd),
397 }
398 err = socket.addEvent(&l.event)
399 if err != nil {
400 return fmt.Errorf("Failed to add event: ", err)
401 }
402
403 socket.listener = l
404
405 return nil
406}
407
408// addListener creates new UNIX domain socket, binds it to the address
409// and marks it as listener
410func (socket *Socket) addListener() (err error) {
411 // create socket
412 fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
413 if err != nil {
414 return fmt.Errorf("Failed to create UNIX domain socket")
415 }
416 usa := &syscall.SockaddrUnix{Name: socket.filename}
Jakub Grajciar07363a42020-04-02 10:02:17 +0200417 // Bind to address and start listening
418 err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
419 if err != nil {
420 return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err)
421 }
422 err = syscall.Bind(fd, usa)
423 if err != nil {
424 return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err)
425 }
426 err = syscall.Listen(fd, syscall.SOMAXCONN)
427 if err != nil {
428 return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err)
429 }
430
431 return socket.AddListener(fd)
432}
433
434// close closes a control channel, if the control channel is assigned an
435// interface, the interface is disconnected
436func (cc *controlChannel) close(sendMsg bool, str string) (err error) {
437 if sendMsg == true {
438 // first clear message queue so that the disconnect
439 // message is the only message in queue
440 cc.msgQueue = []controlMsg{}
441 cc.msgEnqDisconnect(str)
442
443 err = cc.sendMsg()
444 if err != nil {
445 return err
446 }
447 }
448
449 err = cc.socket.delEvent(&cc.event)
450 if err != nil {
451 return fmt.Errorf("Failed to del event: ", err)
452 }
453
454 // remove referance form socket
455 cc.socket.ccList.Remove(cc.listRef)
456
457 if cc.i != nil {
458 err = cc.i.disconnect()
459 if err != nil {
460 return fmt.Errorf("Interface Disconnect: ", err)
461 }
462 }
463
464 return nil
465}
466
467//addControlChannel returns a new controlChannel and adds it to the socket
468func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) {
469 cc := &controlChannel{
470 socket: socket,
471 i: i,
472 isConnected: false,
473 }
474
475 var err error
476
477 cc.event = syscall.EpollEvent{
478 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
479 Fd: int32(fd),
480 }
481 err = socket.addEvent(&cc.event)
482 if err != nil {
483 return nil, fmt.Errorf("Failed to add event: ", err)
484 }
485
486 cc.listRef = socket.ccList.PushBack(cc)
487
488 return cc, nil
489}
490
491func (cc *controlChannel) msgEnqAck() (err error) {
492 buf := new(bytes.Buffer)
493 err = binary.Write(buf, binary.LittleEndian, msgTypeAck)
494
495 msg := controlMsg{
496 Buffer: buf,
497 Fd: -1,
498 }
499
500 cc.msgQueue = append(cc.msgQueue, msg)
501
502 return nil
503}
504
505func (cc *controlChannel) msgEnqHello() (err error) {
506 hello := MsgHello{
507 VersionMin: Version,
508 VersionMax: Version,
509 MaxRegion: 255,
510 MaxRingM2S: 255,
511 MaxRingS2M: 255,
512 MaxLog2RingSize: 14,
513 }
514
515 copy(hello.Name[:], []byte(cc.socket.appName))
516
517 buf := new(bytes.Buffer)
518 err = binary.Write(buf, binary.LittleEndian, msgTypeHello)
519 err = binary.Write(buf, binary.LittleEndian, hello)
520
521 msg := controlMsg{
522 Buffer: buf,
523 Fd: -1,
524 }
525
526 cc.msgQueue = append(cc.msgQueue, msg)
527
528 return nil
529}
530
531func (cc *controlChannel) parseHello() (err error) {
532 var hello MsgHello
533
534 buf := bytes.NewReader(cc.data[msgTypeSize:])
535 err = binary.Read(buf, binary.LittleEndian, &hello)
536 if err != nil {
537 return
538 }
539
540 if hello.VersionMin > Version || hello.VersionMax < Version {
541 return fmt.Errorf("Incompatible memif version")
542 }
543
544 cc.i.run = cc.i.args.MemoryConfig
545
546 cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M)
547 cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingM2S)
548 cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize)
549
550 cc.i.remoteName = string(hello.Name[:])
551
552 return nil
553}
554
555func (cc *controlChannel) msgEnqInit() (err error) {
556 init := MsgInit{
557 Version: Version,
558 Id: cc.i.args.Id,
Nathan Skrzypczak176373c2021-05-07 19:39:07 +0200559 Mode: cc.i.args.Mode,
Jakub Grajciar07363a42020-04-02 10:02:17 +0200560 }
561
562 copy(init.Name[:], []byte(cc.socket.appName))
563
564 buf := new(bytes.Buffer)
565 err = binary.Write(buf, binary.LittleEndian, msgTypeInit)
566 err = binary.Write(buf, binary.LittleEndian, init)
567
568 msg := controlMsg{
569 Buffer: buf,
570 Fd: -1,
571 }
572
573 cc.msgQueue = append(cc.msgQueue, msg)
574
575 return nil
576}
577
578func (cc *controlChannel) parseInit() (err error) {
579 var init MsgInit
580
581 buf := bytes.NewReader(cc.data[msgTypeSize:])
582 err = binary.Read(buf, binary.LittleEndian, &init)
583 if err != nil {
584 return
585 }
586
587 if init.Version != Version {
588 return fmt.Errorf("Incompatible memif driver version")
589 }
590
591 // find peer interface
592 for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
593 i, ok := elt.Value.(*Interface)
594 if ok {
595 if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil {
596 // verify secret
597 if i.args.Secret != init.Secret {
598 return fmt.Errorf("Invalid secret")
599 }
600 // interface is assigned to control channel
601 i.cc = cc
602 cc.i = i
603 cc.i.run = cc.i.args.MemoryConfig
604 cc.i.remoteName = string(init.Name[:])
605
606 return nil
607 }
608 }
609 }
610
611 return fmt.Errorf("Invalid interface id")
612}
613
614func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) {
615 if len(cc.i.regions) <= int(regionIndex) {
616 return fmt.Errorf("Invalid region index")
617 }
618
619 addRegion := MsgAddRegion{
620 Index: regionIndex,
621 Size: cc.i.regions[regionIndex].size,
622 }
623
624 buf := new(bytes.Buffer)
625 err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion)
626 err = binary.Write(buf, binary.LittleEndian, addRegion)
627
628 msg := controlMsg{
629 Buffer: buf,
630 Fd: cc.i.regions[regionIndex].fd,
631 }
632
633 cc.msgQueue = append(cc.msgQueue, msg)
634
635 return nil
636}
637
638func (cc *controlChannel) parseAddRegion() (err error) {
639 var addRegion MsgAddRegion
640
641 buf := bytes.NewReader(cc.data[msgTypeSize:])
642 err = binary.Read(buf, binary.LittleEndian, &addRegion)
643 if err != nil {
644 return
645 }
646
647 fd, err := cc.parseControlMsg()
648 if err != nil {
649 return fmt.Errorf("parseControlMsg: %s", err)
650 }
651
652 if addRegion.Index > 255 {
653 return fmt.Errorf("Invalid memory region index")
654 }
655
656 region := memoryRegion{
657 size: addRegion.Size,
658 fd: fd,
659 }
660
661 cc.i.regions = append(cc.i.regions, region)
662
663 return nil
664}
665
666func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) {
667 var q Queue
668 var flags uint16 = 0
669
670 if ringType == ringTypeS2M {
671 q = cc.i.txQueues[ringIndex]
672 flags = msgAddRingFlagS2M
673 } else {
674 q = cc.i.rxQueues[ringIndex]
675 }
676
677 addRing := MsgAddRing{
678 Index: ringIndex,
679 Offset: uint32(q.ring.offset),
680 Region: uint16(q.ring.region),
681 RingSizeLog2: uint8(q.ring.log2Size),
682 Flags: flags,
683 PrivateHdrSize: 0,
684 }
685
686 buf := new(bytes.Buffer)
687 err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing)
688 err = binary.Write(buf, binary.LittleEndian, addRing)
689
690 msg := controlMsg{
691 Buffer: buf,
692 Fd: q.interruptFd,
693 }
694
695 cc.msgQueue = append(cc.msgQueue, msg)
696
697 return nil
698}
699
700func (cc *controlChannel) parseAddRing() (err error) {
701 var addRing MsgAddRing
702
703 buf := bytes.NewReader(cc.data[msgTypeSize:])
704 err = binary.Read(buf, binary.LittleEndian, &addRing)
705 if err != nil {
706 return
707 }
708
709 fd, err := cc.parseControlMsg()
710 if err != nil {
711 return err
712 }
713
714 if addRing.Index >= cc.i.run.NumQueuePairs {
715 return fmt.Errorf("invalid ring index")
716 }
717
718 q := Queue{
719 i: cc.i,
720 interruptFd: fd,
721 }
722
723 if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M {
724 q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2))
725 cc.i.rxQueues = append(cc.i.rxQueues, q)
726 } else {
727 q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2))
728 cc.i.txQueues = append(cc.i.txQueues, q)
729 }
730
731 return nil
732}
733
734func (cc *controlChannel) msgEnqConnect() (err error) {
735 var connect MsgConnect
736 copy(connect.Name[:], []byte(cc.i.args.Name))
737
738 buf := new(bytes.Buffer)
739 err = binary.Write(buf, binary.LittleEndian, msgTypeConnect)
740 err = binary.Write(buf, binary.LittleEndian, connect)
741
742 msg := controlMsg{
743 Buffer: buf,
744 Fd: -1,
745 }
746
747 cc.msgQueue = append(cc.msgQueue, msg)
748
749 return nil
750}
751
752func (cc *controlChannel) parseConnect() (err error) {
753 var connect MsgConnect
754
755 buf := bytes.NewReader(cc.data[msgTypeSize:])
756 err = binary.Read(buf, binary.LittleEndian, &connect)
757 if err != nil {
758 return
759 }
760
761 cc.i.peerName = string(connect.Name[:])
762
763 err = cc.i.connect()
764 if err != nil {
765 return err
766 }
Jakub Grajciar07363a42020-04-02 10:02:17 +0200767 cc.isConnected = true
768
769 return nil
770}
771
772func (cc *controlChannel) msgEnqConnected() (err error) {
773 var connected MsgConnected
774 copy(connected.Name[:], []byte(cc.i.args.Name))
775
776 buf := new(bytes.Buffer)
777 err = binary.Write(buf, binary.LittleEndian, msgTypeConnected)
778 err = binary.Write(buf, binary.LittleEndian, connected)
779
780 msg := controlMsg{
781 Buffer: buf,
782 Fd: -1,
783 }
784
785 cc.msgQueue = append(cc.msgQueue, msg)
786
787 return nil
788}
789
790func (cc *controlChannel) parseConnected() (err error) {
791 var conn MsgConnected
792
793 buf := bytes.NewReader(cc.data[msgTypeSize:])
794 err = binary.Read(buf, binary.LittleEndian, &conn)
795 if err != nil {
796 return
797 }
798
799 cc.i.peerName = string(conn.Name[:])
800
801 err = cc.i.connect()
802 if err != nil {
803 return err
804 }
Jakub Grajciar07363a42020-04-02 10:02:17 +0200805 cc.isConnected = true
806
807 return nil
808}
809
810func (cc *controlChannel) msgEnqDisconnect(str string) (err error) {
811 dc := MsgDisconnect{
812 // not implemented
813 Code: 0,
814 }
815 copy(dc.String[:], str)
816
817 buf := new(bytes.Buffer)
818 err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect)
819 err = binary.Write(buf, binary.LittleEndian, dc)
820
821 msg := controlMsg{
822 Buffer: buf,
823 Fd: -1,
824 }
825
826 cc.msgQueue = append(cc.msgQueue, msg)
827
828 return nil
829}
830
831func (cc *controlChannel) parseDisconnect() (err error) {
832 var dc MsgDisconnect
833
834 buf := bytes.NewReader(cc.data[msgTypeSize:])
835 err = binary.Read(buf, binary.LittleEndian, &dc)
836 if err != nil {
837 return
838 }
839
840 err = cc.close(false, string(dc.String[:]))
841 if err != nil {
842 return fmt.Errorf("Failed to disconnect control channel: ", err)
843 }
844
845 return nil
846}
847
848func (cc *controlChannel) parseMsg() error {
849 var msgType msgType
850 var err error
851
852 buf := bytes.NewReader(cc.data[:])
853 err = binary.Read(buf, binary.LittleEndian, &msgType)
854
855 if msgType == msgTypeAck {
856 return nil
857 } else if msgType == msgTypeHello {
858 // Configure
859 err = cc.parseHello()
860 if err != nil {
861 goto error
862 }
863 // Initialize slave memif
864 err = cc.i.initializeRegions()
865 if err != nil {
866 goto error
867 }
868 err = cc.i.initializeQueues()
869 if err != nil {
870 goto error
871 }
872 // Enqueue messages
873 err = cc.msgEnqInit()
874 if err != nil {
875 goto error
876 }
877 for i := 0; i < len(cc.i.regions); i++ {
878 err = cc.msgEnqAddRegion(uint16(i))
879 if err != nil {
880 goto error
881 }
882 }
883 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
884 err = cc.msgEnqAddRing(ringTypeS2M, uint16(i))
885 if err != nil {
886 goto error
887 }
888 }
889 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
890 err = cc.msgEnqAddRing(ringTypeM2S, uint16(i))
891 if err != nil {
892 goto error
893 }
894 }
895 err = cc.msgEnqConnect()
896 if err != nil {
897 goto error
898 }
899 } else if msgType == msgTypeInit {
900 err = cc.parseInit()
901 if err != nil {
902 goto error
903 }
904
905 err = cc.msgEnqAck()
906 if err != nil {
907 goto error
908 }
909 } else if msgType == msgTypeAddRegion {
910 err = cc.parseAddRegion()
911 if err != nil {
912 goto error
913 }
914
915 err = cc.msgEnqAck()
916 if err != nil {
917 goto error
918 }
919 } else if msgType == msgTypeAddRing {
920 err = cc.parseAddRing()
921 if err != nil {
922 goto error
923 }
924
925 err = cc.msgEnqAck()
926 if err != nil {
927 goto error
928 }
929 } else if msgType == msgTypeConnect {
930 err = cc.parseConnect()
931 if err != nil {
932 goto error
933 }
934
935 err = cc.msgEnqConnected()
936 if err != nil {
937 goto error
938 }
939 } else if msgType == msgTypeConnected {
940 err = cc.parseConnected()
941 if err != nil {
942 goto error
943 }
944 } else if msgType == msgTypeDisconnect {
945 err = cc.parseDisconnect()
946 if err != nil {
947 goto error
948 }
949 } else {
950 err = fmt.Errorf("unknown message %d", msgType)
951 goto error
952 }
953
954 return nil
955
956error:
957 err1 := cc.close(true, err.Error())
958 if err1 != nil {
959 return fmt.Errorf(err.Error(), ": Failed to close control channel: ", err1)
960 }
961
962 return err
963}
964
965// parseControlMsg parses control message and returns file descriptor
966// if any
967func (cc *controlChannel) parseControlMsg() (fd int, err error) {
968 // Assert only called when we require FD
969 fd = -1
970
971 controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen])
972 if err != nil {
973 return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err)
974 }
975
976 if len(controlMsgs) == 0 {
977 return -1, fmt.Errorf("Missing control message")
978 }
979
980 for _, cmsg := range controlMsgs {
981 if cmsg.Header.Level == syscall.SOL_SOCKET {
982 if cmsg.Header.Type == syscall.SCM_RIGHTS {
983 FDs, err := syscall.ParseUnixRights(&cmsg)
984 if err != nil {
985 return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err)
986 }
987 if len(FDs) == 0 {
988 continue
989 }
990 // Only expect single FD
991 fd = FDs[0]
992 }
993 }
994 }
995
996 if fd == -1 {
997 return -1, fmt.Errorf("Missing file descriptor")
998 }
999
1000 return fd, nil
1001}