L2-VTR: add vtr tests

re-enable l2 fib flush tests
reorder l2bd multi instance tests - move flags test as last
enabling of uu-flood will now flood when entry is stale

Change-Id: I052663ec3eb4acee5f296fb7525dd535924e0003
Signed-off-by: Eyal Bari <ebari@cisco.com>
diff --git a/test/test_l2bd_multi_instance.py b/test/test_l2bd_multi_instance.py
index 1cf1b13..0bb9e59 100644
--- a/test/test_l2bd_multi_instance.py
+++ b/test/test_l2bd_multi_instance.py
@@ -403,7 +403,39 @@
         self.run_verify_test()
 
     def test_l2bd_inst_02(self):
-        """ L2BD Multi-instance test 2 - update data of 5 BDs
+        """ L2BD Multi-instance test 2 - delete 2 BDs
+        """
+        # Config 3
+        # Delete 2 BDs
+        self.delete_bd(2)
+
+        # Verify 3
+        for bd_id in self.bd_deleted_list:
+            self.assertEqual(self.verify_bd(bd_id), 0)
+        for bd_id in self.bd_list:
+            self.assertEqual(self.verify_bd(bd_id), 1)
+
+        # Test 3
+        self.run_verify_test()
+
+    def test_l2bd_inst_03(self):
+        """ L2BD Multi-instance test 3 - add 2 BDs
+        """
+        # Config 4
+        # Create 5 BDs, put interfaces to these BDs and send MAC learning
+        # packets
+        self.create_bd_and_mac_learn(2)
+
+        # Verify 4
+        for bd_id in self.bd_list:
+            self.assertEqual(self.verify_bd(bd_id), 1)
+
+        # Test 4
+        # self.vapi.cli("clear trace")
+        self.run_verify_test()
+
+    def test_l2bd_inst_04(self):
+        """ L2BD Multi-instance test 4 - update data of 5 BDs
         """
         # Config 2
         # Update data of 5 BDs (disable learn, forward, flood, uu-flood)
@@ -428,40 +460,8 @@
         self.verify_bd(self.bd_list[4], learn=False, forward=True,
                        flood=True, uu_flood=True)
 
-    def test_l2bd_inst_03(self):
-        """ L2BD Multi-instance 3 - delete 2 BDs
-        """
-        # Config 3
-        # Delete 2 BDs
-        self.delete_bd(2)
-
-        # Verify 3
-        for bd_id in self.bd_deleted_list:
-            self.assertEqual(self.verify_bd(bd_id), 0)
-        for bd_id in self.bd_list:
-            self.assertEqual(self.verify_bd(bd_id), 1)
-
-        # Test 3
-        self.run_verify_test()
-
-    def test_l2bd_inst_04(self):
-        """ L2BD Multi-instance test 4 - add 2 BDs
-        """
-        # Config 4
-        # Create 5 BDs, put interfaces to these BDs and send MAC learning
-        # packets
-        self.create_bd_and_mac_learn(2)
-
-        # Verify 4
-        for bd_id in self.bd_list:
-            self.assertEqual(self.verify_bd(bd_id), 1)
-
-        # Test 4
-        # self.vapi.cli("clear trace")
-        self.run_verify_test()
-
     def test_l2bd_inst_05(self):
-        """ L2BD Multi-instance 5 - delete 5 BDs
+        """ L2BD Multi-instance test 5 - delete 5 BDs
         """
         # Config 5
         # Delete 5 BDs
diff --git a/test/test_vtr.py b/test/test_vtr.py
new file mode 100644
index 0000000..02df2ce
--- /dev/null
+++ b/test/test_vtr.py
@@ -0,0 +1,332 @@
+#!/usr/bin/env python
+
+import unittest
+import random
+
+from scapy.packet import Raw
+from scapy.layers.l2 import Ether, Dot1Q
+from scapy.layers.inet import IP, UDP
+
+from util import Host
+from framework import VppTestCase, VppTestRunner
+from vpp_sub_interface import VppDot1QSubint, VppDot1ADSubint
+from vpp_papi_provider import L2_VTR_OP
+from collections import namedtuple
+
+Tag = namedtuple('Tag', ['dot1', 'vlan'])
+DOT1AD = 0x88A8
+DOT1Q = 0x8100
+
+
+class TestVtr(VppTestCase):
+    """ VTR Test Case """
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestVtr, cls).setUpClass()
+
+        # Test variables
+        cls.bd_id = 1
+        cls.mac_entries_count = 5
+        cls.Atag = 100
+        cls.Btag = 200
+        cls.dot1ad_sub_id = 20
+
+        try:
+            ifs = range(3)
+            cls.create_pg_interfaces(ifs)
+
+            cls.sub_interfaces = [
+                VppDot1ADSubint(cls, cls.pg1, cls.dot1ad_sub_id,
+                                cls.Btag, cls.Atag),
+                VppDot1QSubint(cls, cls.pg2, cls.Btag)]
+
+            interfaces = list(cls.pg_interfaces)
+            interfaces.extend(cls.sub_interfaces)
+
+            # Create BD with MAC learning enabled and put interfaces and
+            #  sub-interfaces to this BD
+            for pg_if in cls.pg_interfaces:
+                sw_if_index = pg_if.sub_if.sw_if_index \
+                    if hasattr(pg_if, 'sub_if') else pg_if.sw_if_index
+                cls.vapi.sw_interface_set_l2_bridge(sw_if_index,
+                                                    bd_id=cls.bd_id)
+
+            # setup all interfaces
+            for i in interfaces:
+                i.admin_up()
+
+            # mapping between packet-generator index and lists of test hosts
+            cls.hosts_by_pg_idx = dict()
+
+            # create test host entries and inject packets to learn MAC entries
+            # in the bridge-domain
+            cls.create_hosts_and_learn(cls.mac_entries_count)
+            cls.logger.info(cls.vapi.ppcli("show l2fib"))
+
+        except Exception:
+            super(TestVtr, cls).tearDownClass()
+            raise
+
+    def setUp(self):
+        """
+        Clear trace and packet infos before running each test.
+        """
+        super(TestVtr, self).setUp()
+        self.reset_packet_infos()
+
+    def tearDown(self):
+        """
+        Show various debug prints after each test.
+        """
+        super(TestVtr, self).tearDown()
+        if not self.vpp_dead:
+            self.logger.info(self.vapi.ppcli("show l2fib verbose"))
+            self.logger.info(self.vapi.ppcli("show bridge-domain %s detail" %
+                                             self.bd_id))
+
+    @classmethod
+    def create_hosts_and_learn(cls, count):
+        for pg_if in cls.pg_interfaces:
+            cls.hosts_by_pg_idx[pg_if.sw_if_index] = []
+            hosts = cls.hosts_by_pg_idx[pg_if.sw_if_index]
+            packets = []
+            for j in range(1, count + 1):
+                host = Host(
+                    "00:00:00:ff:%02x:%02x" % (pg_if.sw_if_index, j),
+                    "172.17.1%02x.%u" % (pg_if.sw_if_index, j))
+                packet = (Ether(dst="ff:ff:ff:ff:ff:ff", src=host.mac))
+                hosts.append(host)
+                if hasattr(pg_if, 'sub_if'):
+                    packet = pg_if.sub_if.add_dot1_layer(packet)
+                packets.append(packet)
+            pg_if.add_stream(packets)
+        cls.logger.info("Sending broadcast eth frames for MAC learning")
+        cls.pg_enable_capture(cls.pg_interfaces)
+        cls.pg_start()
+
+    def create_packet(self, src_if, dst_if, do_dot1=True):
+        packet_sizes = [64, 512, 1518, 9018]
+        dst_host = random.choice(self.hosts_by_pg_idx[dst_if.sw_if_index])
+        src_host = random.choice(self.hosts_by_pg_idx[src_if.sw_if_index])
+        pkt_info = self.create_packet_info(src_if, dst_if)
+        payload = self.info_to_payload(pkt_info)
+        p = (Ether(dst=dst_host.mac, src=src_host.mac) /
+             IP(src=src_host.ip4, dst=dst_host.ip4) /
+             UDP(sport=1234, dport=1234) /
+             Raw(payload))
+        pkt_info.data = p.copy()
+        if do_dot1 and hasattr(src_if, 'sub_if'):
+            p = src_if.sub_if.add_dot1_layer(p)
+        size = random.choice(packet_sizes)
+        self.extend_packet(p, size)
+        return p
+
+    def _add_tag(self, packet, vlan, tag_type):
+        payload = packet.payload
+        inner_type = packet.type
+        packet.remove_payload()
+        packet.add_payload(Dot1Q(vlan=vlan) / payload)
+        packet.payload.type = inner_type
+        packet.payload.vlan = vlan
+        packet.type = tag_type
+        return packet
+
+    def _remove_tag(self, packet, vlan=None, tag_type=None):
+        if tag_type:
+            self.assertEqual(packet.type, tag_type)
+
+        payload = packet.payload
+        if vlan:
+            self.assertEqual(payload.vlan, vlan)
+        inner_type = payload.type
+        payload = payload.payload
+        packet.remove_payload()
+        packet.add_payload(payload)
+        packet.type = inner_type
+
+    def add_tags(self, packet, tags):
+        for t in reversed(tags):
+            self._add_tag(packet, t.vlan, t.dot1)
+
+    def remove_tags(self, packet, tags):
+        for t in tags:
+            self._remove_tag(packet, t.vlan, t.dot1)
+
+    def vtr_test(self, swif, tags):
+        p = self.create_packet(swif, self.pg0)
+        swif.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        rx = self.pg0.get_capture(1)
+
+        if tags:
+            self.remove_tags(rx[0], tags)
+        self.assertTrue(Dot1Q not in rx[0])
+
+        if not tags:
+            return
+
+        i = VppDot1QSubint(self, self.pg0, tags[0].vlan)
+        self.vapi.sw_interface_set_l2_bridge(
+            i.sw_if_index, bd_id=self.bd_id, enable=1)
+        i.admin_up()
+
+        p = self.create_packet(self.pg0, swif, do_dot1=False)
+        self.add_tags(p, tags)
+        self.pg0.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        rx = swif.get_capture(1)
+        swif.sub_if.remove_dot1_layer(rx[0])
+        self.assertTrue(Dot1Q not in rx[0])
+
+        self.vapi.sw_interface_set_l2_bridge(
+            i.sw_if_index, bd_id=self.bd_id, enable=0)
+        i.remove_vpp_config()
+
+    def test_1ad_vtr_pop_1(self):
+        """ 1AD VTR pop 1 test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_POP_1)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_pop_2(self):
+        """ 1AD VTR pop 2 test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_POP_2)
+        self.vtr_test(self.pg1, [])
+
+    def test_1ad_vtr_push_1ad(self):
+        """ 1AD VTR push 1 1AD test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_PUSH_1, tag=300)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1AD, vlan=300),
+                                 Tag(dot1=DOT1AD, vlan=200),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_push_2ad(self):
+        """ 1AD VTR push 2 1AD test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_PUSH_2, outer=400, inner=300)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1AD, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300),
+                                 Tag(dot1=DOT1AD, vlan=200),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_push_1q(self):
+        """ 1AD VTR push 1 1Q test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_PUSH_1, tag=300, push1q=1)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1Q, vlan=300),
+                                 Tag(dot1=DOT1AD, vlan=200),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_push_2q(self):
+        """ 1AD VTR push 2 1Q test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_PUSH_2,
+                                outer=400, inner=300, push1q=1)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1Q, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300),
+                                 Tag(dot1=DOT1AD, vlan=200),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_translate_1_1ad(self):
+        """ 1AD VTR translate 1 -> 1 1AD test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_TRANSLATE_1_1, tag=300)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1AD, vlan=300),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_translate_1_2ad(self):
+        """ 1AD VTR translate 1 -> 2 1AD test
+        """
+        self.pg1.sub_if.set_vtr(
+            L2_VTR_OP.L2_TRANSLATE_1_2, inner=300, outer=400)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1AD, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_translate_2_1ad(self):
+        """ 1AD VTR translate 2 -> 1 1AD test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_TRANSLATE_2_1, tag=300)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1AD, vlan=300)])
+
+    def test_1ad_vtr_translate_2_2ad(self):
+        """ 1AD VTR translate 2 -> 2 1AD test
+        """
+        self.pg1.sub_if.set_vtr(
+            L2_VTR_OP.L2_TRANSLATE_2_2, inner=300, outer=400)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1AD, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300)])
+
+    def test_1ad_vtr_translate_1_1q(self):
+        """ 1AD VTR translate 1 -> 1 1Q test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_TRANSLATE_1_1, tag=300, push1q=1)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1Q, vlan=300),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_translate_1_2q(self):
+        """ 1AD VTR translate 1 -> 2 1Q test
+        """
+        self.pg1.sub_if.set_vtr(
+            L2_VTR_OP.L2_TRANSLATE_1_2, inner=300, outer=400, push1q=1)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1Q, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300),
+                                 Tag(dot1=DOT1Q, vlan=100)])
+
+    def test_1ad_vtr_translate_2_1q(self):
+        """ 1AD VTR translate 2 -> 1 1Q test
+        """
+        self.pg1.sub_if.set_vtr(L2_VTR_OP.L2_TRANSLATE_2_1, tag=300, push1q=1)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1Q, vlan=300)])
+
+    def test_1ad_vtr_translate_2_2q(self):
+        """ 1AD VTR translate 2 -> 2 1Q test
+        """
+        self.pg1.sub_if.set_vtr(
+            L2_VTR_OP.L2_TRANSLATE_2_2, inner=300, outer=400, push1q=1)
+        self.vtr_test(self.pg1, [Tag(dot1=DOT1Q, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300)])
+
+    def test_1q_vtr_pop_1(self):
+        """ 1Q VTR pop 1 test
+        """
+        self.pg2.sub_if.set_vtr(L2_VTR_OP.L2_POP_1)
+        self.vtr_test(self.pg2, [])
+
+    def test_1q_vtr_push_1(self):
+        """ 1Q VTR push 1 test
+        """
+        self.pg2.sub_if.set_vtr(L2_VTR_OP.L2_PUSH_1, tag=300)
+        self.vtr_test(self.pg2, [Tag(dot1=DOT1AD, vlan=300),
+                                 Tag(dot1=DOT1Q, vlan=200)])
+
+    def test_1q_vtr_push_2(self):
+        """ 1Q VTR push 2 test
+        """
+        self.pg2.sub_if.set_vtr(L2_VTR_OP.L2_PUSH_2, outer=400, inner=300)
+        self.vtr_test(self.pg2, [Tag(dot1=DOT1AD, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300),
+                                 Tag(dot1=DOT1Q, vlan=200)])
+
+    def test_1q_vtr_translate_1_1(self):
+        """ 1Q VTR translate 1 -> 1 test
+        """
+        self.pg2.sub_if.set_vtr(L2_VTR_OP.L2_TRANSLATE_1_1, tag=300)
+        self.vtr_test(self.pg2, [Tag(dot1=DOT1AD, vlan=300)])
+
+    def test_1q_vtr_translate_1_2(self):
+        """ 1Q VTR translate 1 -> 2 test
+        """
+        self.pg2.sub_if.set_vtr(
+            L2_VTR_OP.L2_TRANSLATE_1_2, inner=300, outer=400)
+        self.vtr_test(self.pg2, [Tag(dot1=DOT1AD, vlan=400),
+                                 Tag(dot1=DOT1Q, vlan=300)])
+
+
+if __name__ == '__main__':
+    unittest.main(testRunner=VppTestRunner)
diff --git a/test/vpp_papi_provider.py b/test/vpp_papi_provider.py
index 4c02a34..5f27e85 100644
--- a/test/vpp_papi_provider.py
+++ b/test/vpp_papi_provider.py
@@ -25,7 +25,15 @@
 
 
 class L2_VTR_OP:
+    L2_DISABLED = 0
+    L2_PUSH_1 = 1
+    L2_PUSH_2 = 2
     L2_POP_1 = 3
+    L2_POP_2 = 4
+    L2_TRANSLATE_1_1 = 5
+    L2_TRANSLATE_1_2 = 6
+    L2_TRANSLATE_2_1 = 7
+    L2_TRANSLATE_2_2 = 8
 
 
 class UnexpectedApiReturnValueError(Exception):
diff --git a/test/vpp_sub_interface.py b/test/vpp_sub_interface.py
index b4c415c..dcd82da 100644
--- a/test/vpp_sub_interface.py
+++ b/test/vpp_sub_interface.py
@@ -1,7 +1,8 @@
-from scapy.layers.l2 import Ether, Dot1Q
+from scapy.layers.l2 import Dot1Q
 from abc import abstractmethod, ABCMeta
 from vpp_interface import VppInterface
 from vpp_pg_interface import VppPGInterface
+from vpp_papi_provider import L2_VTR_OP
 
 
 class VppSubInterface(VppPGInterface):
@@ -17,11 +18,26 @@
         """Sub-interface ID"""
         return self._sub_id
 
+    @property
+    def tag1(self):
+        return self._tag1
+
+    @property
+    def tag2(self):
+        return self._tag2
+
+    @property
+    def vtr(self):
+        return self._vtr
+
     def __init__(self, test, parent, sub_id):
         VppInterface.__init__(self, test)
         self._parent = parent
         self._parent.add_sub_if(self)
         self._sub_id = sub_id
+        self.set_vtr(L2_VTR_OP.L2_DISABLED)
+        self.DOT1AD_TYPE = 0x88A8
+        self.DOT1Q_TYPE = 0x8100
 
     @abstractmethod
     def create_arp_req(self):
@@ -44,6 +60,66 @@
     def remove_vpp_config(self):
         self.test.vapi.delete_subif(self._sw_if_index)
 
+    def _add_tag(self, packet, vlan, tag_type):
+        payload = packet.payload
+        inner_type = packet.type
+        packet.remove_payload()
+        packet.add_payload(Dot1Q(vlan=vlan) / payload)
+        packet.payload.type = inner_type
+        packet.payload.vlan = vlan
+        packet.type = tag_type
+        return packet
+
+    def _remove_tag(self, packet, vlan=None, tag_type=None):
+        if tag_type:
+            self.test.instance().assertEqual(packet.type, tag_type)
+
+        payload = packet.payload
+        if vlan:
+            self.test.instance().assertEqual(payload.vlan, vlan)
+        inner_type = payload.type
+        payload = payload.payload
+        packet.remove_payload()
+        packet.add_payload(payload)
+        packet.type = inner_type
+        return packet
+
+    def add_dot1q_layer(self, packet, vlan):
+        return self._add_tag(packet, vlan, self.DOT1Q_TYPE)
+
+    def add_dot1ad_layer(self, packet, outer, inner):
+        p = self._add_tag(packet, inner, self.DOT1Q_TYPE)
+        return self._add_tag(p, outer, self.DOT1AD_TYPE)
+
+    def remove_dot1q_layer(self, packet, vlan=None):
+        return self._remove_tag(packet, vlan, self.DOT1Q_TYPE)
+
+    def remove_dot1ad_layer(self, packet, outer=None, inner=None):
+        p = self._remove_tag(packet, outer, self.DOT1AD_TYPE)
+        return self._remove_tag(p, inner, self.DOT1Q_TYPE)
+
+    def set_vtr(self, vtr, push1q=0, tag=None, inner=None, outer=None):
+        self._tag1 = 0
+        self._tag2 = 0
+        self._push1q = 0
+
+        if (vtr == L2_VTR_OP.L2_PUSH_1 or
+            vtr == L2_VTR_OP.L2_TRANSLATE_1_1 or
+                vtr == L2_VTR_OP.L2_TRANSLATE_2_1):
+            self._tag1 = tag
+            self._push1q = push1q
+        if (vtr == L2_VTR_OP.L2_PUSH_2 or
+            vtr == L2_VTR_OP.L2_TRANSLATE_1_2 or
+                vtr == L2_VTR_OP.L2_TRANSLATE_2_2):
+            self._tag1 = outer
+            self._tag2 = inner
+            self._push1q = push1q
+
+        self.test.vapi.sw_interface_set_l2_tag_rewrite(
+            self.sw_if_index, vtr, push=self._push1q,
+            tag1=self._tag1, tag2=self._tag2)
+        self._vtr = vtr
+
 
 class VppDot1QSubint(VppSubInterface):
 
@@ -68,20 +144,13 @@
         packet = VppPGInterface.create_ndp_req(self)
         return self.add_dot1_layer(packet)
 
+    # called before sending packet
     def add_dot1_layer(self, packet):
-        payload = packet.payload
-        packet.remove_payload()
-        packet.add_payload(Dot1Q(vlan=self.sub_id) / payload)
-        return packet
+        return self.add_dot1q_layer(packet, self.vlan)
 
+    # called on received packet to "reverse" the add call
     def remove_dot1_layer(self, packet):
-        payload = packet.payload
-        self.test.instance().assertEqual(type(payload), Dot1Q)
-        self.test.instance().assertEqual(payload.vlan, self.vlan)
-        payload = payload.payload
-        packet.remove_payload()
-        packet.add_payload(payload)
-        return packet
+        return self.remove_dot1q_layer(packet, self.vlan)
 
 
 class VppDot1ADSubint(VppSubInterface):
@@ -101,11 +170,9 @@
                                    inner_vlan, dot1ad=1, two_tags=1,
                                    exact_match=1)
         self._sw_if_index = r.sw_if_index
-        super(VppDot1ADSubint, self).__init__(test, parent, sub_id)
-        self.DOT1AD_TYPE = 0x88A8
-        self.DOT1Q_TYPE = 0x8100
         self._outer_vlan = outer_vlan
         self._inner_vlan = inner_vlan
+        super(VppDot1ADSubint, self).__init__(test, parent, sub_id)
 
     def create_arp_req(self):
         packet = VppPGInterface.create_arp_req(self)
@@ -116,25 +183,8 @@
         return self.add_dot1_layer(packet)
 
     def add_dot1_layer(self, packet):
-        payload = packet.payload
-        packet.remove_payload()
-        packet.add_payload(Dot1Q(vlan=self.outer_vlan) /
-                           Dot1Q(vlan=self.inner_vlan) / payload)
-        packet.type = self.DOT1AD_TYPE
-        return packet
+        return self.add_dot1ad_layer(packet, self.outer_vlan, self.inner_vlan)
 
     def remove_dot1_layer(self, packet):
-        self.test.instance().assertEqual(type(packet), Ether)
-        self.test.instance().assertEqual(packet.type, self.DOT1AD_TYPE)
-        packet.type = self.DOT1Q_TYPE
-        packet = Ether(str(packet))
-        payload = packet.payload
-        self.test.instance().assertEqual(type(payload), Dot1Q)
-        self.test.instance().assertEqual(payload.vlan, self.outer_vlan)
-        payload = payload.payload
-        self.test.instance().assertEqual(type(payload), Dot1Q)
-        self.test.instance().assertEqual(payload.vlan, self.inner_vlan)
-        payload = payload.payload
-        packet.remove_payload()
-        packet.add_payload(payload)
-        return packet
+        return self.remove_dot1ad_layer(packet, self.outer_vlan,
+                                        self.inner_vlan)