ip: reassembly: send packet out on correct worker

Record which worker received the fragment with offset zero and use that
worker to send out the reassembled packet.

Type: fix
Change-Id: I1d3cee16788db3b230682525239c0100d51dc380
Signed-off-by: Klement Sekera <ksekera@cisco.com>
diff --git a/test/test_reassembly.py b/test/test_reassembly.py
index e95d533..4c8712f 100644
--- a/test/test_reassembly.py
+++ b/test/test_reassembly.py
@@ -2,7 +2,7 @@
 
 import six
 import unittest
-from random import shuffle
+from random import shuffle, choice, randrange
 
 from framework import VppTestCase, VppTestRunner
 
@@ -10,11 +10,10 @@
 from scapy.packet import Raw
 from scapy.layers.l2 import Ether, GRE
 from scapy.layers.inet import IP, UDP, ICMP
-from util import ppp, fragment_rfc791, fragment_rfc8200
 from scapy.layers.inet6 import IPv6, IPv6ExtHdrFragment, ICMPv6ParamProblem,\
     ICMPv6TimeExceeded
 from framework import VppTestCase, VppTestRunner
-from util import ppp, fragment_rfc791, fragment_rfc8200
+from util import ppp, ppc, fragment_rfc791, fragment_rfc8200
 from vpp_gre_interface import VppGreInterface
 from vpp_ip import DpoProto
 from vpp_ip_route import VppIpRoute, VppRoutePath, FibPathProto
@@ -22,6 +21,9 @@
 # 35 is enough to have >257 400-byte fragments
 test_packet_count = 35
 
+# number of workers for multi-worker test cases (must be at least 2)
+worker_count = 3
+
 
 class TestIPv4Reassembly(VppTestCase):
     """ IPv4 Reassembly """
@@ -499,6 +501,179 @@
         self.src_if.assert_nothing_captured()
 
 
+class TestIPv4MWReassembly(VppTestCase):
+    """ IPv4 Reassembly (multiple workers) """
+    worker_config = "workers %d" % worker_count
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestIPv4MWReassembly, cls).setUpClass()
+
+        cls.create_pg_interfaces(range(worker_count+1))
+        cls.src_if = cls.pg0
+        cls.send_ifs = cls.pg_interfaces[:-1]
+        cls.dst_if = cls.pg_interfaces[-1]
+
+        # setup all interfaces
+        for i in cls.pg_interfaces:
+            i.admin_up()
+            i.config_ip4()
+            i.resolve_arp()
+
+        # packet sizes reduced here because we are generating packets without
+        # Ethernet headers, which are added later (diff fragments go via
+        # different interfaces)
+        cls.packet_sizes = [64-len(Ether()), 512-len(Ether()),
+                            1518-len(Ether()), 9018-len(Ether())]
+        cls.padding = " abcdefghijklmn"
+        cls.create_stream(cls.packet_sizes)
+        cls.create_fragments()
+
+    @classmethod
+    def tearDownClass(cls):
+        super(TestIPv4MWReassembly, cls).tearDownClass()
+
+    def setUp(self):
+        """ Test setup - force timeout on existing reassemblies """
+        super(TestIPv4MWReassembly, self).setUp()
+        for intf in self.send_ifs:
+            self.vapi.ip_reassembly_enable_disable(
+                sw_if_index=intf.sw_if_index, enable_ip4=True)
+        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
+                                    max_reassembly_length=1000,
+                                    expire_walk_interval_ms=10)
+        self.sleep(.25)
+        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
+                                    max_reassembly_length=1000,
+                                    expire_walk_interval_ms=10000)
+
+    def tearDown(self):
+        super(TestIPv4MWReassembly, self).tearDown()
+
+    def show_commands_at_teardown(self):
+        self.logger.debug(self.vapi.ppcli("show ip4-reassembly details"))
+        self.logger.debug(self.vapi.ppcli("show buffers"))
+
+    @classmethod
+    def create_stream(cls, packet_sizes, packet_count=test_packet_count):
+        """Create input packet stream
+
+        :param list packet_sizes: Required packet sizes.
+        """
+        for i in range(0, packet_count):
+            info = cls.create_packet_info(cls.src_if, cls.src_if)
+            payload = cls.info_to_payload(info)
+            p = (IP(id=info.index, src=cls.src_if.remote_ip4,
+                    dst=cls.dst_if.remote_ip4) /
+                 UDP(sport=1234, dport=5678) /
+                 Raw(payload))
+            size = packet_sizes[(i // 2) % len(packet_sizes)]
+            cls.extend_packet(p, size, cls.padding)
+            info.data = p
+
+    @classmethod
+    def create_fragments(cls):
+        """Fragment each generated packet with fragment_rfc791(p, 400)
+        and record (index, fragments) pairs in cls.pkt_infos."""
+        infos = cls._packet_infos
+        cls.pkt_infos = []
+        for index, info in six.iteritems(infos):
+            p = info.data
+            fragments_400 = fragment_rfc791(p, 400)
+            cls.pkt_infos.append((index, fragments_400))
+        cls.fragments_400 = [
+            x for (_, frags) in cls.pkt_infos for x in frags]
+        cls.logger.debug("Fragmented %s packets into %s 400-byte fragments" %
+                         (len(infos), len(cls.fragments_400)))
+
+    def verify_capture(self, capture, dropped_packet_indexes=()):
+        """Verify captured packet stream.
+
+        :param list capture: Captured packet stream.
+        :param dropped_packet_indexes: Packet indexes expected to be dropped.
+        """
+        seen = set()
+        for packet in capture:
+            try:
+                self.logger.debug(ppp("Got packet:", packet))
+                ip = packet[IP]
+                udp = packet[UDP]
+                payload_info = self.payload_to_info(packet[Raw])
+                packet_index = payload_info.index
+                self.assertTrue(
+                    packet_index not in dropped_packet_indexes,
+                    ppp("Packet received, but should be dropped:", packet))
+                if packet_index in seen:
+                    raise Exception(ppp("Duplicate packet received", packet))
+                seen.add(packet_index)
+                self.assertEqual(payload_info.dst, self.src_if.sw_if_index)
+                info = self._packet_infos[packet_index]
+                self.assertTrue(info is not None)
+                self.assertEqual(packet_index, info.index)
+                saved_packet = info.data
+                self.assertEqual(ip.src, saved_packet[IP].src)
+                self.assertEqual(ip.dst, saved_packet[IP].dst)
+                self.assertEqual(udp.payload, saved_packet[UDP].payload)
+            except Exception:
+                self.logger.error(ppp("Unexpected or invalid packet:", packet))
+                raise
+        for index in self._packet_infos:
+            self.assertTrue(index in seen or index in dropped_packet_indexes,
+                            "Packet with packet_index %d not received" % index)
+
+    def send_packets(self, packets):
+        for counter in range(worker_count):
+            if 0 == len(packets[counter]):
+                continue
+            send_if = self.send_ifs[counter]
+            send_if.add_stream(
+                (Ether(dst=send_if.local_mac, src=send_if.remote_mac) / x
+                 for x in packets[counter]),
+                worker=counter)
+        self.pg_start()
+
+    def test_worker_conflict(self):
+        """ 1st and FO=0 fragments on different workers """
+
+        # first wave: fragment p[1] (non-zero offset) of each packet
+        # second wave: fragment p[0] (offset 0) on a different worker than p[1]
+        # third wave: all remaining fragments on one random worker per packet
+        first_packets = [[] for n in range(worker_count)]
+        second_packets = [[] for n in range(worker_count)]
+        rest_of_packets = [[] for n in range(worker_count)]
+        for (_, p) in self.pkt_infos:
+            wi = randrange(worker_count)
+            second_packets[wi].append(p[0])
+            if len(p) <= 1:
+                continue
+            wi2 = wi
+            while wi2 == wi:
+                wi2 = randrange(worker_count)
+            first_packets[wi2].append(p[1])
+            wi3 = randrange(worker_count)
+            rest_of_packets[wi3].extend(p[2:])
+
+        self.pg_enable_capture()
+        self.send_packets(first_packets)
+        self.send_packets(second_packets)
+        self.send_packets(rest_of_packets)
+
+        packets = self.dst_if.get_capture(len(self.pkt_infos))
+        self.verify_capture(packets)
+        for send_if in self.send_ifs:
+            send_if.assert_nothing_captured()
+
+        self.pg_enable_capture()
+        self.send_packets(first_packets)
+        self.send_packets(second_packets)
+        self.send_packets(rest_of_packets)
+
+        packets = self.dst_if.get_capture(len(self.pkt_infos))
+        self.verify_capture(packets)
+        for send_if in self.send_ifs:
+            send_if.assert_nothing_captured()
+
+
 class TestIPv6Reassembly(VppTestCase):
     """ IPv6 Reassembly """
 
@@ -937,6 +1112,179 @@
         self.assert_equal(icmp[ICMPv6ParamProblem].code, 0, "ICMP code")
 
 
+class TestIPv6MWReassembly(VppTestCase):
+    """ IPv6 Reassembly (multiple workers) """
+    worker_config = "workers %d" % worker_count
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestIPv6MWReassembly, cls).setUpClass()
+
+        cls.create_pg_interfaces(range(worker_count+1))
+        cls.src_if = cls.pg0
+        cls.send_ifs = cls.pg_interfaces[:-1]
+        cls.dst_if = cls.pg_interfaces[-1]
+
+        # setup all interfaces
+        for i in cls.pg_interfaces:
+            i.admin_up()
+            i.config_ip6()
+            i.resolve_ndp()
+
+        # packet sizes reduced here because we are generating packets without
+        # Ethernet headers, which are added later (diff fragments go via
+        # different interfaces)
+        cls.packet_sizes = [64-len(Ether()), 512-len(Ether()),
+                            1518-len(Ether()), 9018-len(Ether())]
+        cls.padding = " abcdefghijklmn"
+        cls.create_stream(cls.packet_sizes)
+        cls.create_fragments()
+
+    @classmethod
+    def tearDownClass(cls):
+        super(TestIPv6MWReassembly, cls).tearDownClass()
+
+    def setUp(self):
+        """ Test setup - force timeout on existing reassemblies """
+        super(TestIPv6MWReassembly, self).setUp()
+        for intf in self.send_ifs:
+            self.vapi.ip_reassembly_enable_disable(
+                sw_if_index=intf.sw_if_index, enable_ip6=True)
+        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
+                                    max_reassembly_length=1000,
+                                    expire_walk_interval_ms=10, is_ip6=1)
+        self.sleep(.25)
+        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
+                                    max_reassembly_length=1000,
+                                    expire_walk_interval_ms=1000, is_ip6=1)
+
+    def tearDown(self):
+        super(TestIPv6MWReassembly, self).tearDown()
+
+    def show_commands_at_teardown(self):
+        self.logger.debug(self.vapi.ppcli("show ip6-reassembly details"))
+        self.logger.debug(self.vapi.ppcli("show buffers"))
+
+    @classmethod
+    def create_stream(cls, packet_sizes, packet_count=test_packet_count):
+        """Create input packet stream
+
+        :param list packet_sizes: Required packet sizes.
+        """
+        for i in range(0, packet_count):
+            info = cls.create_packet_info(cls.src_if, cls.src_if)
+            payload = cls.info_to_payload(info)
+            p = (IPv6(src=cls.src_if.remote_ip6,
+                      dst=cls.dst_if.remote_ip6) /
+                 UDP(sport=1234, dport=5678) /
+                 Raw(payload))
+            size = packet_sizes[(i // 2) % len(packet_sizes)]
+            cls.extend_packet(p, size, cls.padding)
+            info.data = p
+
+    @classmethod
+    def create_fragments(cls):
+        """Fragment each generated packet with fragment_rfc8200(p, index, 400)
+        and record (index, fragments) pairs in cls.pkt_infos."""
+        infos = cls._packet_infos
+        cls.pkt_infos = []
+        for index, info in six.iteritems(infos):
+            p = info.data
+            fragments_400 = fragment_rfc8200(p, index, 400)
+            cls.pkt_infos.append((index, fragments_400))
+        cls.fragments_400 = [
+            x for (_, frags) in cls.pkt_infos for x in frags]
+        cls.logger.debug("Fragmented %s packets into %s 400-byte fragments" %
+                         (len(infos), len(cls.fragments_400)))
+
+    def verify_capture(self, capture, dropped_packet_indexes=()):
+        """Verify captured packet stream.
+
+        :param list capture: Captured packet stream.
+        :param dropped_packet_indexes: Packet indexes expected to be dropped.
+        """
+        seen = set()
+        for packet in capture:
+            try:
+                self.logger.debug(ppp("Got packet:", packet))
+                ip = packet[IPv6]
+                udp = packet[UDP]
+                payload_info = self.payload_to_info(packet[Raw])
+                packet_index = payload_info.index
+                self.assertTrue(
+                    packet_index not in dropped_packet_indexes,
+                    ppp("Packet received, but should be dropped:", packet))
+                if packet_index in seen:
+                    raise Exception(ppp("Duplicate packet received", packet))
+                seen.add(packet_index)
+                self.assertEqual(payload_info.dst, self.src_if.sw_if_index)
+                info = self._packet_infos[packet_index]
+                self.assertTrue(info is not None)
+                self.assertEqual(packet_index, info.index)
+                saved_packet = info.data
+                self.assertEqual(ip.src, saved_packet[IPv6].src)
+                self.assertEqual(ip.dst, saved_packet[IPv6].dst)
+                self.assertEqual(udp.payload, saved_packet[UDP].payload)
+            except Exception:
+                self.logger.error(ppp("Unexpected or invalid packet:", packet))
+                raise
+        for index in self._packet_infos:
+            self.assertTrue(index in seen or index in dropped_packet_indexes,
+                            "Packet with packet_index %d not received" % index)
+
+    def send_packets(self, packets):
+        for counter in range(worker_count):
+            if 0 == len(packets[counter]):
+                continue
+            send_if = self.send_ifs[counter]
+            send_if.add_stream(
+                (Ether(dst=send_if.local_mac, src=send_if.remote_mac) / x
+                 for x in packets[counter]),
+                worker=counter)
+        self.pg_start()
+
+    def test_worker_conflict(self):
+        """ 1st and FO=0 fragments on different workers """
+
+        # first wave: fragment p[1] (non-zero offset) of each packet
+        # second wave: fragment p[0] (offset 0) on a different worker than p[1]
+        # third wave: all remaining fragments on one random worker per packet
+        first_packets = [[] for n in range(worker_count)]
+        second_packets = [[] for n in range(worker_count)]
+        rest_of_packets = [[] for n in range(worker_count)]
+        for (_, p) in self.pkt_infos:
+            wi = randrange(worker_count)
+            second_packets[wi].append(p[0])
+            if len(p) <= 1:
+                continue
+            wi2 = wi
+            while wi2 == wi:
+                wi2 = randrange(worker_count)
+            first_packets[wi2].append(p[1])
+            wi3 = randrange(worker_count)
+            rest_of_packets[wi3].extend(p[2:])
+
+        self.pg_enable_capture()
+        self.send_packets(first_packets)
+        self.send_packets(second_packets)
+        self.send_packets(rest_of_packets)
+
+        packets = self.dst_if.get_capture(len(self.pkt_infos))
+        self.verify_capture(packets)
+        for send_if in self.send_ifs:
+            send_if.assert_nothing_captured()
+
+        self.pg_enable_capture()
+        self.send_packets(first_packets)
+        self.send_packets(second_packets)
+        self.send_packets(rest_of_packets)
+
+        packets = self.dst_if.get_capture(len(self.pkt_infos))
+        self.verify_capture(packets)
+        for send_if in self.send_ifs:
+            send_if.assert_nothing_captured()
+
+
 class TestIPv4ReassemblyLocalNode(VppTestCase):
     """ IPv4 Reassembly for packets coming to ip4-local node """