NAT44: active-passive HA (VPP-1571)

Add session synchronization so that we can build a plain active-passive
HA NAT pair.

Change-Id: I21db200491081ca46b7af3e82afc677c1985abf4
Signed-off-by: Matus Fabian <matfabia@cisco.com>
diff --git a/test/test_nat.py b/test/test_nat.py
index 0d74cb6..fce7efe 100644
--- a/test/test_nat.py
+++ b/test/test_nat.py
@@ -24,6 +24,45 @@
 from vpp_papi_provider import SYSLOG_SEVERITY
 from io import BytesIO
 from vpp_papi import VppEnum
+from scapy.all import bind_layers, Packet, ByteEnumField, ShortField, \
+    IPField, IntField, LongField, XByteField, FlagsField, FieldLenField, \
+    PacketListField
+
+
+# NAT HA protocol event data
+class Event(Packet):
+    name = "Event"
+    fields_desc = [ByteEnumField("event_type", None,
+                                 {1: "add", 2: "del", 3: "refresh"}),
+                   ByteEnumField("protocol", None,
+                                 {0: "udp", 1: "tcp", 2: "icmp"}),
+                   ShortField("flags", 0),
+                   IPField("in_addr", None),
+                   IPField("out_addr", None),
+                   ShortField("in_port", None),
+                   ShortField("out_port", None),
+                   IPField("eh_addr", None),
+                   IPField("ehn_addr", None),
+                   ShortField("eh_port", None),
+                   ShortField("ehn_port", None),
+                   IntField("fib_index", None),
+                   IntField("total_pkts", 0),
+                   LongField("total_bytes", 0)]
+
+    def extract_padding(self, s):
+        return "", s
+
+
+# NAT HA protocol header
+class HANATStateSync(Packet):
+    name = "HA NAT state sync"
+    fields_desc = [XByteField("version", 1),
+                   FlagsField("flags", 0, 8, ['ACK']),
+                   FieldLenField("count", None, count_of="events"),
+                   IntField("sequence_number", 1),
+                   IntField("thread_index", 0),
+                   PacketListField("events", [], Event,
+                                   count_from=lambda pkt:pkt.count)]
 
 
 class MethodHolder(VppTestCase):
@@ -75,6 +114,9 @@
 
         self.vapi.syslog_set_filter(SYSLOG_SEVERITY.EMERG)
 
+        self.vapi.nat_ha_set_listener('0.0.0.0', 0)
+        self.vapi.nat_ha_set_failover('0.0.0.0', 0)
+
         interfaces = self.vapi.nat44_interface_dump()
         for intf in interfaces:
             if intf.is_inside > 1:
@@ -1460,6 +1502,7 @@
             cls.ipfix_src_port = 4739
             cls.ipfix_domain_id = 1
             cls.tcp_external_port = 80
+            cls.udp_external_port = 69
 
             cls.create_pg_interfaces(range(10))
             cls.interfaces = list(cls.pg_interfaces[0:4])
@@ -2247,8 +2290,8 @@
             self.assertTrue(session.is_static)
             self.assertEqual(session.inside_ip_address[0:4],
                              self.pg6.remote_ip4n)
-            self.assertEqual(map(ord, session.outside_ip_address[0:4]),
-                             map(int, static_nat_ip.split('.')))
+            self.assertEqual(session.outside_ip_address,
+                             socket.inet_pton(socket.AF_INET, static_nat_ip))
             self.assertTrue(session.inside_port in
                             [self.tcp_port_in, self.udp_port_in,
                              self.icmp_id_in])
@@ -3803,6 +3846,311 @@
         # Negotiated MSS value smaller than configured - unchanged
         self.verify_mss_value(capture[0], 1400)
 
+    @unittest.skipUnless(running_extended_tests, "part of extended tests")
+    def test_ha_send(self):
+        """ Send HA session synchronization events (active) """
+        self.nat44_add_address(self.nat_addr)
+        self.vapi.nat44_interface_add_del_feature(self.pg0.sw_if_index)
+        self.vapi.nat44_interface_add_del_feature(self.pg1.sw_if_index,
+                                                  is_inside=0)
+        self.vapi.nat_ha_set_listener(self.pg3.local_ip4, port=12345)
+        self.vapi.nat_ha_set_failover(self.pg3.remote_ip4, port=12346)
+        bind_layers(UDP, HANATStateSync, sport=12345)
+
+        # create sessions
+        pkts = self.create_stream_in(self.pg0, self.pg1)
+        self.pg0.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg1.get_capture(len(pkts))
+        self.verify_capture_out(capture)
+        # active sends HA events
+        self.vapi.nat_ha_flush()
+        stats = self.statistics.get_counter('/nat44/ha/add-event-send')
+        self.assertEqual(stats[0][0], 3)
+        capture = self.pg3.get_capture(1)
+        p = capture[0]
+        self.assert_packet_checksums_valid(p)
+        try:
+            ip = p[IP]
+            udp = p[UDP]
+            hanat = p[HANATStateSync]
+        except IndexError:
+            self.logger.error(ppp("Invalid packet:", p))
+            raise
+        else:
+            self.assertEqual(ip.src, self.pg3.local_ip4)
+            self.assertEqual(ip.dst, self.pg3.remote_ip4)
+            self.assertEqual(udp.sport, 12345)
+            self.assertEqual(udp.dport, 12346)
+            self.assertEqual(hanat.version, 1)
+            self.assertEqual(hanat.thread_index, 0)
+            self.assertEqual(hanat.count, 3)
+            seq = hanat.sequence_number
+            for event in hanat.events:
+                self.assertEqual(event.event_type, 1)
+                self.assertEqual(event.in_addr, self.pg0.remote_ip4)
+                self.assertEqual(event.out_addr, self.nat_addr)
+                self.assertEqual(event.fib_index, 0)
+
+        # ACK received events
+        ack = (Ether(dst=self.pg3.local_mac, src=self.pg3.remote_mac) /
+               IP(src=self.pg3.remote_ip4, dst=self.pg3.local_ip4) /
+               UDP(sport=12346, dport=12345) /
+               HANATStateSync(sequence_number=seq, flags='ACK'))
+        self.pg3.add_stream(ack)
+        self.pg_start()
+        stats = self.statistics.get_counter('/nat44/ha/ack-recv')
+        self.assertEqual(stats[0][0], 1)
+
+        # delete one session
+        self.pg_enable_capture(self.pg_interfaces)
+        self.vapi.nat44_del_session(self.pg0.remote_ip4n, self.tcp_port_in,
+                                    IP_PROTOS.tcp)
+        self.vapi.nat_ha_flush()
+        stats = self.statistics.get_counter('/nat44/ha/del-event-send')
+        self.assertEqual(stats[0][0], 1)
+        capture = self.pg3.get_capture(1)
+        p = capture[0]
+        try:
+            hanat = p[HANATStateSync]
+        except IndexError:
+            self.logger.error(ppp("Invalid packet:", p))
+            raise
+        else:
+            self.assertGreater(hanat.sequence_number, seq)
+
+        # no ACK is sent, so the active node retries sending the HA event
+        self.pg_enable_capture(self.pg_interfaces)
+        sleep(12)
+        stats = self.statistics.get_counter('/nat44/ha/retry-count')
+        self.assertEqual(stats[0][0], 3)
+        stats = self.statistics.get_counter('/nat44/ha/missed-count')
+        self.assertEqual(stats[0][0], 1)
+        capture = self.pg3.get_capture(3)
+        for packet in capture:
+            self.assertEqual(packet, p)
+
+        # session counters refresh
+        pkts = self.create_stream_out(self.pg1)
+        self.pg1.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        self.pg0.get_capture(2)
+        self.vapi.nat_ha_flush()
+        stats = self.statistics.get_counter('/nat44/ha/refresh-event-send')
+        self.assertEqual(stats[0][0], 2)
+        capture = self.pg3.get_capture(1)
+        p = capture[0]
+        self.assert_packet_checksums_valid(p)
+        try:
+            ip = p[IP]
+            udp = p[UDP]
+            hanat = p[HANATStateSync]
+        except IndexError:
+            self.logger.error(ppp("Invalid packet:", p))
+            raise
+        else:
+            self.assertEqual(ip.src, self.pg3.local_ip4)
+            self.assertEqual(ip.dst, self.pg3.remote_ip4)
+            self.assertEqual(udp.sport, 12345)
+            self.assertEqual(udp.dport, 12346)
+            self.assertEqual(hanat.version, 1)
+            self.assertEqual(hanat.count, 2)
+            seq = hanat.sequence_number
+            for event in hanat.events:
+                self.assertEqual(event.event_type, 3)
+                self.assertEqual(event.out_addr, self.nat_addr)
+                self.assertEqual(event.fib_index, 0)
+                self.assertEqual(event.total_pkts, 2)
+                self.assertGreater(event.total_bytes, 0)
+
+        ack = (Ether(dst=self.pg3.local_mac, src=self.pg3.remote_mac) /
+               IP(src=self.pg3.remote_ip4, dst=self.pg3.local_ip4) /
+               UDP(sport=12346, dport=12345) /
+               HANATStateSync(sequence_number=seq, flags='ACK'))
+        self.pg3.add_stream(ack)
+        self.pg_start()
+        stats = self.statistics.get_counter('/nat44/ha/ack-recv')
+        self.assertEqual(stats[0][0], 2)
+
+    def test_ha_recv(self):
+        """ Receive HA session synchronization events (passive) """
+        self.nat44_add_address(self.nat_addr)
+        self.vapi.nat44_interface_add_del_feature(self.pg0.sw_if_index)
+        self.vapi.nat44_interface_add_del_feature(self.pg1.sw_if_index,
+                                                  is_inside=0)
+        self.vapi.nat_ha_set_listener(self.pg3.local_ip4, port=12345)
+        bind_layers(UDP, HANATStateSync, sport=12345)
+
+        self.tcp_port_out = random.randint(1025, 65535)
+        self.udp_port_out = random.randint(1025, 65535)
+
+        # send HA session add events to failover/passive
+        p = (Ether(dst=self.pg3.local_mac, src=self.pg3.remote_mac) /
+             IP(src=self.pg3.remote_ip4, dst=self.pg3.local_ip4) /
+             UDP(sport=12346, dport=12345) /
+             HANATStateSync(sequence_number=1, events=[
+                 Event(event_type='add', protocol='tcp',
+                       in_addr=self.pg0.remote_ip4, out_addr=self.nat_addr,
+                       in_port=self.tcp_port_in, out_port=self.tcp_port_out,
+                       eh_addr=self.pg1.remote_ip4,
+                       ehn_addr=self.pg1.remote_ip4,
+                       eh_port=self.tcp_external_port,
+                       ehn_port=self.tcp_external_port, fib_index=0),
+                 Event(event_type='add', protocol='udp',
+                       in_addr=self.pg0.remote_ip4, out_addr=self.nat_addr,
+                       in_port=self.udp_port_in, out_port=self.udp_port_out,
+                       eh_addr=self.pg1.remote_ip4,
+                       ehn_addr=self.pg1.remote_ip4,
+                       eh_port=self.udp_external_port,
+                       ehn_port=self.udp_external_port, fib_index=0)]))
+
+        self.pg3.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        # receive ACK
+        capture = self.pg3.get_capture(1)
+        p = capture[0]
+        try:
+            hanat = p[HANATStateSync]
+        except IndexError:
+            self.logger.error(ppp("Invalid packet:", p))
+            raise
+        else:
+            self.assertEqual(hanat.sequence_number, 1)
+            self.assertEqual(hanat.flags, 'ACK')
+            self.assertEqual(hanat.version, 1)
+            self.assertEqual(hanat.thread_index, 0)
+        stats = self.statistics.get_counter('/nat44/ha/ack-send')
+        self.assertEqual(stats[0][0], 1)
+        stats = self.statistics.get_counter('/nat44/ha/add-event-recv')
+        self.assertEqual(stats[0][0], 2)
+        users = self.statistics.get_counter('/nat44/total-users')
+        self.assertEqual(users[0][0], 1)
+        sessions = self.statistics.get_counter('/nat44/total-sessions')
+        self.assertEqual(sessions[0][0], 2)
+        users = self.vapi.nat44_user_dump()
+        self.assertEqual(len(users), 1)
+        self.assertEqual(users[0].ip_address, self.pg0.remote_ip4n)
+        # there should be 2 sessions created by HA
+        sessions = self.vapi.nat44_user_session_dump(users[0].ip_address,
+                                                     users[0].vrf_id)
+        self.assertEqual(len(sessions), 2)
+        for session in sessions:
+            self.assertEqual(session.inside_ip_address, self.pg0.remote_ip4n)
+            self.assertEqual(session.outside_ip_address, self.nat_addr_n)
+            self.assertIn(session.inside_port,
+                          [self.tcp_port_in, self.udp_port_in])
+            self.assertIn(session.outside_port,
+                          [self.tcp_port_out, self.udp_port_out])
+            self.assertIn(session.protocol, [IP_PROTOS.tcp, IP_PROTOS.udp])
+
+        # send HA session delete event to failover/passive
+        p = (Ether(dst=self.pg3.local_mac, src=self.pg3.remote_mac) /
+             IP(src=self.pg3.remote_ip4, dst=self.pg3.local_ip4) /
+             UDP(sport=12346, dport=12345) /
+             HANATStateSync(sequence_number=2, events=[
+                 Event(event_type='del', protocol='udp',
+                       in_addr=self.pg0.remote_ip4, out_addr=self.nat_addr,
+                       in_port=self.udp_port_in, out_port=self.udp_port_out,
+                       eh_addr=self.pg1.remote_ip4,
+                       ehn_addr=self.pg1.remote_ip4,
+                       eh_port=self.udp_external_port,
+                       ehn_port=self.udp_external_port, fib_index=0)]))
+
+        self.pg3.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        # receive ACK
+        capture = self.pg3.get_capture(1)
+        p = capture[0]
+        try:
+            hanat = p[HANATStateSync]
+        except IndexError:
+            self.logger.error(ppp("Invalid packet:", p))
+            raise
+        else:
+            self.assertEqual(hanat.sequence_number, 2)
+            self.assertEqual(hanat.flags, 'ACK')
+            self.assertEqual(hanat.version, 1)
+        users = self.vapi.nat44_user_dump()
+        self.assertEqual(len(users), 1)
+        self.assertEqual(users[0].ip_address, self.pg0.remote_ip4n)
+        # only 1 session should remain, the other was deleted by HA
+        sessions = self.vapi.nat44_user_session_dump(users[0].ip_address,
+                                                     users[0].vrf_id)
+        self.assertEqual(len(sessions), 1)
+        stats = self.statistics.get_counter('/nat44/ha/del-event-recv')
+        self.assertEqual(stats[0][0], 1)
+
+        stats = self.statistics.get_counter('/err/nat-ha/pkts-processed')
+        self.assertEqual(stats, 2)
+
+        # send HA session refresh event to failover/passive
+        p = (Ether(dst=self.pg3.local_mac, src=self.pg3.remote_mac) /
+             IP(src=self.pg3.remote_ip4, dst=self.pg3.local_ip4) /
+             UDP(sport=12346, dport=12345) /
+             HANATStateSync(sequence_number=3, events=[
+                 Event(event_type='refresh', protocol='tcp',
+                       in_addr=self.pg0.remote_ip4, out_addr=self.nat_addr,
+                       in_port=self.tcp_port_in, out_port=self.tcp_port_out,
+                       eh_addr=self.pg1.remote_ip4,
+                       ehn_addr=self.pg1.remote_ip4,
+                       eh_port=self.tcp_external_port,
+                       ehn_port=self.tcp_external_port, fib_index=0,
+                       total_bytes=1024, total_pkts=2)]))
+        self.pg3.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        # receive ACK
+        capture = self.pg3.get_capture(1)
+        p = capture[0]
+        try:
+            hanat = p[HANATStateSync]
+        except IndexError:
+            self.logger.error(ppp("Invalid packet:", p))
+            raise
+        else:
+            self.assertEqual(hanat.sequence_number, 3)
+            self.assertEqual(hanat.flags, 'ACK')
+            self.assertEqual(hanat.version, 1)
+        users = self.vapi.nat44_user_dump()
+        self.assertEqual(len(users), 1)
+        self.assertEqual(users[0].ip_address, self.pg0.remote_ip4n)
+        sessions = self.vapi.nat44_user_session_dump(users[0].ip_address,
+                                                     users[0].vrf_id)
+        self.assertEqual(len(sessions), 1)
+        session = sessions[0]
+        self.assertEqual(session.total_bytes, 1024)
+        self.assertEqual(session.total_pkts, 2)
+        stats = self.statistics.get_counter('/nat44/ha/refresh-event-recv')
+        self.assertEqual(stats[0][0], 1)
+
+        stats = self.statistics.get_counter('/err/nat-ha/pkts-processed')
+        self.assertEqual(stats, 3)
+
+        # send packet to test session created by HA
+        p = (Ether(dst=self.pg1.local_mac, src=self.pg1.remote_mac) /
+             IP(src=self.pg1.remote_ip4, dst=self.nat_addr) /
+             TCP(sport=self.tcp_external_port, dport=self.tcp_port_out))
+        self.pg1.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg0.get_capture(1)
+        p = capture[0]
+        try:
+            ip = p[IP]
+            tcp = p[TCP]
+        except IndexError:
+            self.logger.error(ppp("Invalid packet:", p))
+            raise
+        else:
+            self.assertEqual(ip.src, self.pg1.remote_ip4)
+            self.assertEqual(ip.dst, self.pg0.remote_ip4)
+            self.assertEqual(tcp.sport, self.tcp_external_port)
+            self.assertEqual(tcp.dport, self.tcp_port_in)
+
     def tearDown(self):
         super(TestNAT44, self).tearDown()
         if not self.vpp_dead:
@@ -3816,6 +4164,7 @@
             self.logger.info(self.vapi.cli("show nat timeouts"))
             self.logger.info(
                 self.vapi.cli("show nat addr-port-assignment-alg"))
+            self.logger.info(self.vapi.cli("show nat ha"))
             self.clear_nat44()
             self.vapi.cli("clear logging")