cnat: Add DHCP support

Type: feature

Change-Id: I4bd50fd672ac35cf14ebda2b0b10ec0b9a208628
Signed-off-by: Nathan Skrzypczak <nathan.skrzypczak@gmail.com>
diff --git a/src/plugins/cnat/test/test_cnat.py b/src/plugins/cnat/test/test_cnat.py
index 3f8d33c..d46d047 100644
--- a/src/plugins/cnat/test/test_cnat.py
+++ b/src/plugins/cnat/test/test_cnat.py
@@ -3,7 +3,8 @@
 import unittest
 
 from framework import VppTestCase, VppTestRunner
-from vpp_ip import DpoProto
+from vpp_ip import DpoProto, INVALID_INDEX
+from itertools import product
 
 from scapy.packet import Raw
 from scapy.layers.l2 import Ether
@@ -23,25 +24,34 @@
 N_PKTS = 15
 
 
-def find_cnat_translation(test, id):
-    ts = test.vapi.cnat_translation_dump()
-    for t in ts:
-        if id == t.translation.id:
-            return True
-    return False
-
-
 class Ep(object):
     """ CNat endpoint """
 
-    def __init__(self, ip, port, l4p=TCP):
+    def __init__(self, ip=None, port=0, l4p=TCP,
+                 sw_if_index=INVALID_INDEX, is_v6=False):
+        """ Without ip, the address is the any-address picked by is_v6 """
         self.ip = ip
+        if ip is None:
+            self.ip = "::" if is_v6 else "0.0.0.0"
         self.port = port
         self.l4p = l4p
+        self.sw_if_index = sw_if_index
+        # if_af always follows is_v6, even when an explicit ip is given
+        af = VppEnum.vl_api_address_family_t
+        self.if_af = af.ADDRESS_IP6 if is_v6 else af.ADDRESS_IP4
 
     def encode(self):
         return {'addr': self.ip,
-                'port': self.port}
+                'port': self.port,
+                'sw_if_index': self.sw_if_index,
+                'if_af': self.if_af}
+
+    @classmethod
+    def from_pg(cls, pg, is_v6=False):
+        """ Endpoint carrying pg's sw_if_index; all defaults if pg is None """
+        if pg is not None:
+            return cls(sw_if_index=pg.sw_if_index, is_v6=is_v6)
+        return cls(is_v6=is_v6)
 
     @property
     def isV6(self):
@@ -77,6 +87,9 @@
         for path in self.paths:
             self.encoded_paths.append(path.encode())
 
+    def __str__(self):
+        return " ".join(str(x) for x in (self.vip, self.iproto, self.paths))
+
     @property
     def vl4_proto(self):
         ip_proto = VppEnum.vl_api_ip_proto_t
@@ -85,9 +98,6 @@
             TCP: ip_proto.IP_API_PROTO_TCP,
         }[self.iproto]
 
-    def delete(self):
-        r = self._test.vapi.cnat_translation_del(id=self.id)
-
     def add_vpp_config(self):
         r = self._test.vapi.cnat_translation_update(
             {'vip': self.vip.encode(),
@@ -111,10 +121,13 @@
         self._test.registry.register(self, self._test.logger)
 
     def remove_vpp_config(self):
-        self._test.vapi.cnat_translation_del(self.id)
+        self._test.vapi.cnat_translation_del(id=self.id)
 
     def query_vpp_config(self):
-        return find_cnat_translation(self._test, self.id)
+        """ Return our translation as dumped by VPP, None if absent """
+        dump = self._test.vapi.cnat_translation_dump()
+        found = (t.translation for t in dump if self.id == t.translation.id)
+        return next(found, None)
 
     def object_id(self):
         return ("cnat-translation-%s" % (self.vip))
@@ -191,6 +204,7 @@
                 rxs = self.send_and_expect(self.pg0,
                                            p1 * N_PKTS,
                                            self.pg1)
+                self.logger.info(self.vapi.cli("show trace max 1"))
 
                 for rx in rxs:
                     self.assert_packet_checksums_valid(rx)
@@ -352,7 +366,7 @@
             n_tries += 1
             sessions = self.vapi.cnat_session_dump()
             self.sleep(2)
-            print(self.vapi.cli("show cnat session verbose"))
+            self.logger.info(self.vapi.cli("show cnat session verbose"))
 
         self.assertTrue(n_tries < 100)
         self.vapi.cli("test cnat scanner off")
@@ -374,7 +388,7 @@
                                          self.pg2)
 
         for tr in trs:
-            tr.delete()
+            tr.remove_vpp_config()
 
         self.assertTrue(self.vapi.cnat_session_dump())
         self.vapi.cnat_session_purge()
@@ -762,10 +776,13 @@
                 l4p(sport=sports[nbr], dport=dports[nbr]) /
                 Raw())
 
+            self.vapi.cli("trace add pg-input 1")
             rxs = self.send_and_expect(
                 self.pg0,
                 p1 * N_PKTS,
                 self.pg1)
+            self.logger.info(self.vapi.cli("show trace max 1"))
+
             for rx in rxs:
                 self.assert_packet_checksums_valid(rx)
                 self.assertEqual(rx[IP46].dst, remote_addr)
@@ -825,5 +842,123 @@
             self.vapi.cnat_session_purge()
 
 
+class TestCNatDHCP(VppTestCase):
+    """ CNat DHCP """
+    extra_vpp_punt_config = ["cnat", "{",
+                             "session-db-buckets", "64",
+                             "session-cleanup-timeout", "0.1",
+                             "session-max-age", "1",
+                             "tcp-max-age", "1",
+                             "scanner", "off", "}"]
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestCNatDHCP, cls).setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        super(TestCNatDHCP, cls).tearDownClass()
+
+    def tearDown(self):
+        for i in self.pg_interfaces:
+            i.admin_down()
+        super(TestCNatDHCP, self).tearDown()
+
+    def create_translation(self, vip_pg, *args, is_v6=False):
+        """ Add a TCP translation whose endpoints reference pg interfaces
+
+        Each of args is the (src_pg, dst_pg) pair of one path.
+        """
+        vip = Ep(sw_if_index=vip_pg.sw_if_index, is_v6=is_v6)
+        paths = [EpTuple(Ep.from_pg(src_pg, is_v6=is_v6),
+                         Ep.from_pg(dst_pg, is_v6=is_v6))
+                 for (src_pg, dst_pg) in args]
+        t1 = VppCNatTranslation(self, TCP, vip, paths)
+        t1.add_vpp_config()
+        return t1
+
+    def make_addr(self, sw_if_index, i, is_v6):
+        """ Address number i that the tests assign to an interface """
+        if is_v6:
+            return "fd01:%x::%u" % (sw_if_index, i + 1)
+        return "172.16.%u.%u" % (sw_if_index, i)
+
+    def make_prefix(self, sw_if_index, i, is_v6):
+        """ Host prefix (/128 or /32) holding make_addr's address """
+        plen = 128 if is_v6 else 32
+        return "%s/%u" % (self.make_addr(sw_if_index, i, is_v6), plen)
+
+    def check_resolved(self, tr, vip_pg, *args, i=0, is_v6=False):
+        """ Assert tr's endpoints resolved to address i of their pgs """
+        qt1 = tr.query_vpp_config()
+        self.assertIsNotNone(qt1, "translation %s not found in dump" % tr)
+        self.assertEqual(str(qt1.vip.addr), self.make_addr(
+            vip_pg.sw_if_index, i, is_v6))
+        for (src_pg, dst_pg), path in zip(args, qt1.paths):
+            if src_pg:
+                self.assertEqual(str(path.src_ep.addr), self.make_addr(
+                    src_pg.sw_if_index, i, is_v6))
+            if dst_pg:
+                self.assertEqual(str(path.dst_ep.addr), self.make_addr(
+                    dst_pg.sw_if_index, i, is_v6))
+
+    def config_ips(self, rng, is_add=1, is_v6=False):
+        """ Add (or delete) address i on every pg, for each i in rng """
+        for pg, i in product(self.pg_interfaces, rng):
+            self.vapi.sw_interface_add_del_address(
+                sw_if_index=pg.sw_if_index,
+                prefix=self.make_prefix(pg.sw_if_index, i, is_v6),
+                is_add=is_add)
+
+    def run_dhcp_translation(self, is_v6):
+        """ Check a translation re-resolves when pg addresses change """
+        self.create_pg_interfaces(range(5))
+        for i in self.pg_interfaces:
+            i.admin_up()
+        pglist = (self.pg0, (self.pg1, self.pg2), (self.pg1, self.pg4))
+        t1 = self.create_translation(*pglist, is_v6=is_v6)
+        self.config_ips([0], is_v6=is_v6)
+        self.check_resolved(t1, *pglist, is_v6=is_v6)
+        # swap address 0 for address 1: the translation has to follow
+        self.config_ips([1], is_v6=is_v6)
+        self.config_ips([0], is_add=0, is_v6=is_v6)
+        self.check_resolved(t1, *pglist, i=1, is_v6=is_v6)
+        self.config_ips([1], is_add=0, is_v6=is_v6)
+        t1.remove_vpp_config()
+
+    def test_dhcp_v4(self):
+        """ CNat DHCP translation IPv4 """
+        self.run_dhcp_translation(is_v6=False)
+
+    def test_dhcp_v6(self):
+        """ CNat DHCP translation IPv6 """
+        self.run_dhcp_translation(is_v6=True)
+
+    def check_snat_resolved(self, i):
+        """ Assert both SNAT addresses equal pg0's address i """
+        r = self.vapi.cnat_get_snat_addresses()
+        self.assertEqual(str(r.snat_ip4), self.make_addr(
+            self.pg0.sw_if_index, i, False))
+        self.assertEqual(str(r.snat_ip6), self.make_addr(
+            self.pg0.sw_if_index, i, True))
+
+    def test_dhcp_snat(self):
+        """ CNat DHCP SNAT addresses """
+        self.create_pg_interfaces(range(1))
+        for i in self.pg_interfaces:
+            i.admin_up()
+        self.vapi.cnat_set_snat_addresses(sw_if_index=self.pg0.sw_if_index)
+        self.config_ips([0], is_v6=False)
+        self.config_ips([0], is_v6=True)
+        self.check_snat_resolved(0)
+        self.config_ips([1], is_v6=False)
+        self.config_ips([1], is_v6=True)
+        self.config_ips([0], is_add=0, is_v6=False)
+        self.config_ips([0], is_add=0, is_v6=True)
+        self.check_snat_resolved(1)
+        self.config_ips([1], is_add=0, is_v6=False)
+        self.config_ips([1], is_add=0, is_v6=True)
+
+
 if __name__ == '__main__':
     unittest.main(testRunner=VppTestRunner)