ip: respect buffer boundary when searching for ipv6 headers

Type: fix

Change-Id: I5a5461652f8115fa1270e20f748178fb5f5450f2
Signed-off-by: Klement Sekera <ksekera@cisco.com>
diff --git a/src/examples/srv6-sample-localsid/node.c b/src/examples/srv6-sample-localsid/node.c
index 3ac7108..e3a3259 100644
--- a/src/examples/srv6-sample-localsid/node.c
+++ b/src/examples/srv6-sample-localsid/node.c
@@ -188,7 +188,6 @@
       vlib_buffer_t * b0;
       ip6_header_t * ip0 = 0;
       ip6_sr_header_t * sr0;
-      ip6_ext_header_t *prev0
       u32 next0 = SRV6_SAMPLE_LOCALSID_NEXT_IP6LOOKUP;
       ip6_sr_localsid_t *ls0;
       srv6_localsid_sample_per_sid_memory_t *ls0_mem;
@@ -209,7 +208,7 @@
       ls0_mem = ls0->plugin_mem;
 
       /* SRH processing */
-      ip6_ext_header_find_t (ip0, prev0, sr0, IP_PROTOCOL_IPV6_ROUTE);
+      sr0 = ip6_ext_header_find (vm, b0, ip0, IP_PROTOCOL_IPV6_ROUTE, NULL);
       end_decaps_srh_processing (node, b0, ip0, sr0, ls0, &next0);
 
       /* ==================================================================== */
diff --git a/src/vnet/ip/ip6_packet.h b/src/vnet/ip/ip6_packet.h
index c1bd2aa..ed96ece 100644
--- a/src/vnet/ip/ip6_packet.h
+++ b/src/vnet/ip/ip6_packet.h
@@ -510,6 +510,7 @@
   /* Length of this header plus option data in 8 byte units. */
   u8 n_data_u64s;
 }) ip6_ext_header_t;
+/* *INDENT-ON* */
 
 #define foreach_ext_hdr_type \
   _(IP6_HOP_BY_HOP_OPTIONS) \
@@ -522,12 +523,13 @@
   _(HIP) \
   _(SHIM6)
 
-always_inline u8 ip6_ext_hdr(u8 nexthdr)
+always_inline u8
+ip6_ext_hdr (u8 nexthdr)
 {
 #ifdef CLIB_HAVE_VEC128
   static const u8x16 ext_hdr_types = {
 #define _(x) IP_PROTOCOL_##x,
- foreach_ext_hdr_type
+    foreach_ext_hdr_type
 #undef _
   };
 
@@ -536,9 +538,9 @@
   /*
    * find out if nexthdr is an extension header or a protocol
    */
-  return   0
+  return 0
 #define _(x) || (nexthdr == IP_PROTOCOL_##x)
- foreach_ext_hdr_type;
+    foreach_ext_hdr_type;
 #undef _
 #endif
 }
@@ -547,37 +549,79 @@
 #define ip6_ext_authhdr_len(p) ((((ip6_ext_header_t *)(p))->n_data_u64s+2) << 2)
 
 always_inline void *
-ip6_ext_next_header (ip6_ext_header_t *ext_hdr )
-{ return (void *)((u8 *) ext_hdr + ip6_ext_header_len(ext_hdr)); }
-
-/*
- * Macro to find the IPv6 ext header of type t
- * I is the IPv6 header
- * P is the previous IPv6 ext header (NULL if none)
- * M is the matched IPv6 ext header of type t
- */
-#define ip6_ext_header_find_t(i, p, m, t)               \
-if ((i)->protocol == t)                                 \
-{                                                       \
-  (m) = (void *)((i)+1);                                \
-  (p) = NULL;                                           \
-}                                                       \
-else                                                    \
-{                                                       \
-  (m) = NULL;                                           \
-  (p) = (void *)((i)+1);                                \
-  while (ip6_ext_hdr((p)->next_hdr) &&                  \
-    ((ip6_ext_header_t *)(p))->next_hdr != (t))         \
-  {                                                     \
-    (p) = ip6_ext_next_header((p));                     \
-  }                                                     \
-  if ( ((p)->next_hdr) == (t))                          \
-  {                                                     \
-    (m) = (void *)(ip6_ext_next_header((p)));           \
-  }                                                     \
+ip6_ext_next_header (ip6_ext_header_t * ext_hdr)
+{
+  return (void *) ((u8 *) ext_hdr + ip6_ext_header_len (ext_hdr));
 }
 
+always_inline int
+vlib_object_within_buffer_data (vlib_main_t * vm, vlib_buffer_t * b,
+				void *obj, size_t len)
+{
+  u8 *o = obj;
+  if (o < b->data ||
+      o + len > b->data + vlib_buffer_get_default_data_size (vm))
+    return 0;
+  return 1;
+}
 
+/*
+ * find ipv6 extension header within ipv6 header within buffer b
+ *
+ * @param vm
+ * @param b buffer to limit search to
+ * @param ip6_header ipv6 header
+ * @param header_type extension header type to search for
+ * @param[out] prev_ext_header address of header preceding found header
+ */
+always_inline void *
+ip6_ext_header_find (vlib_main_t * vm, vlib_buffer_t * b,
+		     ip6_header_t * ip6_header, u8 header_type,
+		     ip6_ext_header_t ** prev_ext_header)
+{
+  ip6_ext_header_t *prev = NULL;
+  ip6_ext_header_t *result = NULL;
+  if ((ip6_header)->protocol == header_type)
+    {
+      result = (void *) (ip6_header + 1);
+      if (!vlib_object_within_buffer_data (vm, b, result,
+					   ip6_ext_header_len (result)))
+	{
+	  result = NULL;
+	}
+    }
+  else
+    {
+      result = NULL;
+      prev = (void *) (ip6_header + 1);
+      while (ip6_ext_hdr (prev->next_hdr) && prev->next_hdr != header_type)
+	{
+	  prev = ip6_ext_next_header (prev);
+	  if (!vlib_object_within_buffer_data (vm, b, prev,
+					       ip6_ext_header_len (prev)))
+	    {
+	      prev = NULL;
+	      break;
+	    }
+	}
+      if (prev && (prev->next_hdr == header_type))
+	{
+	  result = ip6_ext_next_header (prev);
+	  if (!vlib_object_within_buffer_data (vm, b, result,
+					       ip6_ext_header_len (result)))
+	    {
+	      result = NULL;
+	    }
+	}
+    }
+  if (prev_ext_header)
+    {
+      *prev_ext_header = prev;
+    }
+  return result;
+}
+
+/* *INDENT-OFF* */
 typedef CLIB_PACKED (struct {
   u8 next_hdr;
   /* Length of this header plus option data in 8 byte units. */
diff --git a/src/vnet/ip/reass/ip6_full_reass.c b/src/vnet/ip/reass/ip6_full_reass.c
index 7b11e78..ef10149 100644
--- a/src/vnet/ip/reass/ip6_full_reass.c
+++ b/src/vnet/ip/reass/ip6_full_reass.c
@@ -688,8 +688,9 @@
   ip6_header_t *ip = vlib_buffer_get_current (first_b);
   u16 ip6_frag_hdr_offset = first_b_vnb->ip.reass.ip6_frag_hdr_offset;
   ip6_ext_header_t *prev_hdr;
-  ip6_ext_header_find_t (ip, prev_hdr, frag_hdr,
-			 IP_PROTOCOL_IPV6_FRAGMENTATION);
+  frag_hdr =
+    ip6_ext_header_find (vm, first_b, ip, IP_PROTOCOL_IPV6_FRAGMENTATION,
+			 &prev_hdr);
   if (prev_hdr)
     {
       prev_hdr->next_hdr = frag_hdr->next_hdr;
@@ -1040,8 +1041,10 @@
 	  ip6_ext_header_t *prev_hdr;
 	  if (ip6_ext_hdr (ip0->protocol))
 	    {
-	      ip6_ext_header_find_t (ip0, prev_hdr, frag_hdr,
-				     IP_PROTOCOL_IPV6_FRAGMENTATION);
+	      frag_hdr =
+		ip6_ext_header_find (vm, b0, ip0,
+				     IP_PROTOCOL_IPV6_FRAGMENTATION,
+				     &prev_hdr);
 	    }
 	  if (!frag_hdr)
 	    {
diff --git a/src/vnet/ipsec/ah_decrypt.c b/src/vnet/ipsec/ah_decrypt.c
index bbe6b64..f46fa6e 100644
--- a/src/vnet/ipsec/ah_decrypt.c
+++ b/src/vnet/ipsec/ah_decrypt.c
@@ -184,7 +184,8 @@
       if (is_ip6)
 	{
 	  ip6_ext_header_t *prev = NULL;
-	  ip6_ext_header_find_t (ih6, prev, ah0, IP_PROTOCOL_IPSEC_AH);
+	  ah0 =
+	    ip6_ext_header_find (vm, b[0], ih6, IP_PROTOCOL_IPSEC_AH, &prev);
 	  pd->ip_hdr_size = sizeof (ip6_header_t);
 	  ASSERT ((u8 *) ah0 - (u8 *) ih6 == pd->ip_hdr_size);
 	}
diff --git a/src/vnet/srv6/sr_localsid.c b/src/vnet/srv6/sr_localsid.c
index 6d7c26b..c592f79 100755
--- a/src/vnet/srv6/sr_localsid.c
+++ b/src/vnet/srv6/sr_localsid.c
@@ -901,7 +901,6 @@
 	  u32 bi0, bi1, bi2, bi3;
 	  vlib_buffer_t *b0, *b1, *b2, *b3;
 	  ip6_header_t *ip0, *ip1, *ip2, *ip3;
-	  ip6_ext_header_t *prev0, *prev1, *prev2, *prev3;
 	  ip6_sr_header_t *sr0, *sr1, *sr2, *sr3;
 	  u32 next0, next1, next2, next3;
 	  next0 = next1 = next2 = next3 = SR_LOCALSID_NEXT_IP6_LOOKUP;
@@ -960,10 +959,14 @@
 	  ip2 = vlib_buffer_get_current (b2);
 	  ip3 = vlib_buffer_get_current (b3);
 
-	  ip6_ext_header_find_t (ip0, prev0, sr0, IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip1, prev1, sr1, IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip2, prev2, sr2, IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip3, prev3, sr3, IP_PROTOCOL_IPV6_ROUTE);
+	  sr0 =
+	    ip6_ext_header_find (vm, b0, ip0, IP_PROTOCOL_IPV6_ROUTE, NULL);
+	  sr1 =
+	    ip6_ext_header_find (vm, b1, ip1, IP_PROTOCOL_IPV6_ROUTE, NULL);
+	  sr2 =
+	    ip6_ext_header_find (vm, b2, ip2, IP_PROTOCOL_IPV6_ROUTE, NULL);
+	  sr3 =
+	    ip6_ext_header_find (vm, b3, ip3, IP_PROTOCOL_IPV6_ROUTE, NULL);
 
 	  end_decaps_srh_processing (node, b0, ip0, sr0, ls0, &next0);
 	  end_decaps_srh_processing (node, b1, ip1, sr1, ls1, &next1);
@@ -1097,7 +1100,6 @@
 	  u32 bi0;
 	  vlib_buffer_t *b0;
 	  ip6_header_t *ip0;
-	  ip6_ext_header_t *prev0;
 	  ip6_sr_header_t *sr0;
 	  u32 next0 = SR_LOCALSID_NEXT_IP6_LOOKUP;
 	  ip6_sr_localsid_t *ls0;
@@ -1118,7 +1120,8 @@
 			       vnet_buffer (b0)->ip.adj_index[VLIB_TX]);
 
 	  /* Find SRH as well as previous header */
-	  ip6_ext_header_find_t (ip0, prev0, sr0, IP_PROTOCOL_IPV6_ROUTE);
+	  sr0 =
+	    ip6_ext_header_find (vm, b0, ip0, IP_PROTOCOL_IPV6_ROUTE, NULL);
 
 	  /* SRH processing and End variants */
 	  end_decaps_srh_processing (node, b0, ip0, sr0, ls0, &next0);
@@ -1250,10 +1253,14 @@
 	  ip2 = vlib_buffer_get_current (b2);
 	  ip3 = vlib_buffer_get_current (b3);
 
-	  ip6_ext_header_find_t (ip0, prev0, sr0, IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip1, prev1, sr1, IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip2, prev2, sr2, IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip3, prev3, sr3, IP_PROTOCOL_IPV6_ROUTE);
+	  sr0 =
+	    ip6_ext_header_find (vm, b0, ip0, IP_PROTOCOL_IPV6_ROUTE, &prev0);
+	  sr1 =
+	    ip6_ext_header_find (vm, b1, ip1, IP_PROTOCOL_IPV6_ROUTE, &prev1);
+	  sr2 =
+	    ip6_ext_header_find (vm, b2, ip2, IP_PROTOCOL_IPV6_ROUTE, &prev2);
+	  sr3 =
+	    ip6_ext_header_find (vm, b3, ip3, IP_PROTOCOL_IPV6_ROUTE, &prev3);
 
 	  ls0 =
 	    pool_elt_at_index (sm->localsids,
@@ -1418,7 +1425,8 @@
 
 	  b0 = vlib_get_buffer (vm, bi0);
 	  ip0 = vlib_buffer_get_current (b0);
-	  ip6_ext_header_find_t (ip0, prev0, sr0, IP_PROTOCOL_IPV6_ROUTE);
+	  sr0 =
+	    ip6_ext_header_find (vm, b0, ip0, IP_PROTOCOL_IPV6_ROUTE, &prev0);
 
 	  /* Lookup the SR End behavior based on IP DA (adj) */
 	  ls0 =
diff --git a/src/vnet/srv6/sr_policy_rewrite.c b/src/vnet/srv6/sr_policy_rewrite.c
index aa2f067..feac151 100755
--- a/src/vnet/srv6/sr_policy_rewrite.c
+++ b/src/vnet/srv6/sr_policy_rewrite.c
@@ -2921,7 +2921,6 @@
 	  ip6_header_t *ip0, *ip1, *ip2, *ip3;
 	  ip6_header_t *ip0_encap, *ip1_encap, *ip2_encap, *ip3_encap;
 	  ip6_sr_header_t *sr0, *sr1, *sr2, *sr3;
-	  ip6_ext_header_t *prev0, *prev1, *prev2, *prev3;
 	  ip6_sr_sl_t *sl0, *sl1, *sl2, *sl3;
 
 	  /* Prefetch next iteration. */
@@ -2985,14 +2984,18 @@
 	  ip2_encap = vlib_buffer_get_current (b2);
 	  ip3_encap = vlib_buffer_get_current (b3);
 
-	  ip6_ext_header_find_t (ip0_encap, prev0, sr0,
-				 IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip1_encap, prev1, sr1,
-				 IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip2_encap, prev2, sr2,
-				 IP_PROTOCOL_IPV6_ROUTE);
-	  ip6_ext_header_find_t (ip3_encap, prev3, sr3,
-				 IP_PROTOCOL_IPV6_ROUTE);
+	  sr0 =
+	    ip6_ext_header_find (vm, b0, ip0_encap, IP_PROTOCOL_IPV6_ROUTE,
+				 NULL);
+	  sr1 =
+	    ip6_ext_header_find (vm, b1, ip1_encap, IP_PROTOCOL_IPV6_ROUTE,
+				 NULL);
+	  sr2 =
+	    ip6_ext_header_find (vm, b2, ip2_encap, IP_PROTOCOL_IPV6_ROUTE,
+				 NULL);
+	  sr3 =
+	    ip6_ext_header_find (vm, b3, ip3_encap, IP_PROTOCOL_IPV6_ROUTE,
+				 NULL);
 
 	  end_bsid_encaps_srh_processing (node, b0, ip0_encap, sr0, &next0);
 	  end_bsid_encaps_srh_processing (node, b1, ip1_encap, sr1, &next1);
@@ -3078,7 +3081,6 @@
 	  u32 bi0;
 	  vlib_buffer_t *b0;
 	  ip6_header_t *ip0 = 0, *ip0_encap = 0;
-	  ip6_ext_header_t *prev0;
 	  ip6_sr_header_t *sr0;
 	  ip6_sr_sl_t *sl0;
 	  u32 next0 = SR_POLICY_REWRITE_NEXT_IP6_LOOKUP;
@@ -3098,8 +3100,9 @@
 		  vec_len (sl0->rewrite));
 
 	  ip0_encap = vlib_buffer_get_current (b0);
-	  ip6_ext_header_find_t (ip0_encap, prev0, sr0,
-				 IP_PROTOCOL_IPV6_ROUTE);
+	  sr0 =
+	    ip6_ext_header_find (vm, b0, ip0_encap, IP_PROTOCOL_IPV6_ROUTE,
+				 NULL);
 	  end_bsid_encaps_srh_processing (node, b0, ip0_encap, sr0, &next0);
 
 	  clib_memcpy_fast (((u8 *) ip0_encap) - vec_len (sl0->rewrite),
diff --git a/test/test_reassembly.py b/test/test_reassembly.py
index 0b7073c..407b626 100644
--- a/test/test_reassembly.py
+++ b/test/test_reassembly.py
@@ -10,8 +10,8 @@
 from scapy.packet import Raw
 from scapy.layers.l2 import Ether, GRE
 from scapy.layers.inet import IP, UDP, ICMP
-from scapy.layers.inet6 import IPv6, IPv6ExtHdrFragment, ICMPv6ParamProblem,\
-    ICMPv6TimeExceeded
+from scapy.layers.inet6 import HBHOptUnknown, ICMPv6ParamProblem,\
+        ICMPv6TimeExceeded, IPv6, IPv6ExtHdrFragment, IPv6ExtHdrHopByHop
 from framework import VppTestCase, VppTestRunner
 from util import ppp, ppc, fragment_rfc791, fragment_rfc8200
 from vpp_gre_interface import VppGreInterface
@@ -818,6 +818,23 @@
         self.verify_capture(packets)
         self.src_if.assert_nothing_captured()
 
+    def test_buffer_boundary(self):
+        """ fragment header crossing buffer boundary """
+
+        p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
+             IPv6(src=self.src_if.remote_ip6,
+                  dst=self.src_if.local_ip6) /
+             IPv6ExtHdrHopByHop(
+                 options=[HBHOptUnknown(otype=0xff, optlen=0)] * 1000) /
+             IPv6ExtHdrFragment(m=1) /
+             UDP(sport=1234, dport=5678) /
+             Raw())
+        self.pg_enable_capture()
+        self.src_if.add_stream([p])
+        self.pg_start()
+        self.src_if.assert_nothing_captured()
+        self.dst_if.assert_nothing_captured()
+
     def test_reversed(self):
         """ reverse order reassembly """