ipsec: Fix setting the hi-sequence number for decrypt

Type: fix

two problems;
 1 - just because anti-replay is not enabled doesn't mean the high sequence
number should not be used.
   - fix, there needs to be some means to detect a wrapped packet, so we
use a window size of 2^30.
 2 - The SA object was used as a scratch pad for the high-sequence
number used during decryption. That means that once the batch has been
processed the high-sequence number used is lost. This means it is not
possible to distinguish this case:
      if (seq < IPSEC_SA_ANTI_REPLAY_WINDOW_LOWER_BOUND (tl))
	{
	  ...
	  if (post_decrypt)
	    {
	      if (hi_seq_used == sa->seq_hi)
		/* the high sequence number used to successfully decrypt this
		 * packet is the same as the last-sequence number of the SA.
		 * that means this packet did not cause a wrap.
		 * this packet is thus out of window and should be dropped */
		return 1;
	      else
		/* The packet decrypted with a different high sequence number
		 * to the SA, that means it is the wrap packet and should be
		 * accepted */
		return 0;
	    }
  - fix: don't use the SA as a scratch pad, use the 'packet_data' - the
same place that is used as the scratch pad for the low sequence number.

other consequences:
 - An SA doesn't have seq and last_seq, it has only seq; the sequence
number of the last packet tx'd or rx'd.
 - there's 64bits of space available on the SA's first cache line. move
the AES CTR mode IV there.
 - test the ESN/AR combinations to catch the bugs this fixes. This
doubles the amount of tests, but without AR on they only run for 2
seconds. In the AR tests, the time taken to wait for packets that won't
arrive is dropped from 1 to 0.2 seconds thus reducing the runtime of
these tests from 10-15 to about 5 seconds.

Signed-off-by: Neale Ranns <neale@graphiant.com>
Change-Id: Iaac78905289a272dc01930d70decd8109cf5e7a5
diff --git a/src/vnet/ipsec/ah_decrypt.c b/src/vnet/ipsec/ah_decrypt.c
index d192fb6..182ed3d 100644
--- a/src/vnet/ipsec/ah_decrypt.c
+++ b/src/vnet/ipsec/ah_decrypt.c
@@ -98,6 +98,7 @@
   };
   u32 sa_index;
   u32 seq;
+  u32 seq_hi;
   u8 icv_padding_len;
   u8 icv_size;
   u8 ip_hdr_size;
@@ -221,7 +222,8 @@
       pd->seq = clib_host_to_net_u32 (ah0->seq_no);
 
       /* anti-replay check */
-      if (ipsec_sa_anti_replay_check (sa0, pd->seq))
+      if (ipsec_sa_anti_replay_and_sn_advance (sa0, pd->seq, ~0, false,
+					       &pd->seq_hi))
 	{
 	  b[0]->error = node->errors[AH_DECRYPT_ERROR_REPLAY];
 	  next[0] = AH_DECRYPT_NEXT_DROP;
@@ -257,7 +259,7 @@
 	  op->user_data = b - bufs;
 	  if (ipsec_sa_is_set_USE_ESN (sa0))
 	    {
-	      u32 seq_hi = clib_host_to_net_u32 (sa0->seq_hi);
+	      u32 seq_hi = clib_host_to_net_u32 (pd->seq_hi);
 
 	      op->len += sizeof (seq_hi);
 	      clib_memcpy (op->src + b[0]->current_length, &seq_hi,
@@ -322,13 +324,14 @@
       if (PREDICT_TRUE (sa0->integ_alg != IPSEC_INTEG_ALG_NONE))
 	{
 	  /* redo the anit-reply check. see esp_decrypt for details */
-	  if (ipsec_sa_anti_replay_check (sa0, pd->seq))
+	  if (ipsec_sa_anti_replay_and_sn_advance (sa0, pd->seq, pd->seq_hi,
+						   true, NULL))
 	    {
 	      b[0]->error = node->errors[AH_DECRYPT_ERROR_REPLAY];
 	      next[0] = AH_DECRYPT_NEXT_DROP;
 	      goto trace;
 	    }
-	  ipsec_sa_anti_replay_advance (sa0, pd->seq);
+	  ipsec_sa_anti_replay_advance (sa0, pd->seq, pd->seq_hi);
 	}
 
       u16 ah_hdr_len = sizeof (ah_header_t) + pd->icv_size
diff --git a/src/vnet/ipsec/esp.h b/src/vnet/ipsec/esp.h
index a0643c3..d179233 100644
--- a/src/vnet/ipsec/esp.h
+++ b/src/vnet/ipsec/esp.h
@@ -118,7 +118,8 @@
 }
 
 always_inline u16
-esp_aad_fill (u8 * data, const esp_header_t * esp, const ipsec_sa_t * sa)
+esp_aad_fill (u8 *data, const esp_header_t *esp, const ipsec_sa_t *sa,
+	      u32 seq_hi)
 {
   esp_aead_t *aad;
 
@@ -128,7 +129,7 @@
   if (ipsec_sa_is_set_USE_ESN (sa))
     {
       /* SPI, seq-hi, seq-low */
-      aad->data[1] = (u32) clib_host_to_net_u32 (sa->seq_hi);
+      aad->data[1] = (u32) clib_host_to_net_u32 (seq_hi);
       aad->data[2] = esp->seq;
       return 12;
     }
@@ -199,7 +200,7 @@
   i16 current_length;
   u16 hdr_sz;
   u16 is_chain;
-  u32 protect_index;
+  u32 seq_hi;
 } esp_decrypt_packet_data_t;
 
 STATIC_ASSERT_SIZEOF (esp_decrypt_packet_data_t, 3 * sizeof (u64));
diff --git a/src/vnet/ipsec/esp_decrypt.c b/src/vnet/ipsec/esp_decrypt.c
index ec6d981..b700f2c 100644
--- a/src/vnet/ipsec/esp_decrypt.c
+++ b/src/vnet/ipsec/esp_decrypt.c
@@ -92,10 +92,14 @@
   u32 seq;
   u32 sa_seq;
   u32 sa_seq_hi;
+  u32 pkt_seq_hi;
   ipsec_crypto_alg_t crypto_alg;
   ipsec_integ_alg_t integ_alg;
 } esp_decrypt_trace_t;
 
+/* The number of bytes in the hi-sequence number */
+#define N_HI_ESN_BYTES 4
+
 /* packet trace format function */
 static u8 *
 format_esp_decrypt_trace (u8 * s, va_list * args)
@@ -104,11 +108,11 @@
   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
   esp_decrypt_trace_t *t = va_arg (*args, esp_decrypt_trace_t *);
 
-  s =
-    format (s,
-	    "esp: crypto %U integrity %U pkt-seq %d sa-seq %u sa-seq-hi %u",
-	    format_ipsec_crypto_alg, t->crypto_alg, format_ipsec_integ_alg,
-	    t->integ_alg, t->seq, t->sa_seq, t->sa_seq_hi);
+  s = format (s,
+	      "esp: crypto %U integrity %U pkt-seq %d sa-seq %u sa-seq-hi %u "
+	      "pkt-seq-hi %u",
+	      format_ipsec_crypto_alg, t->crypto_alg, format_ipsec_integ_alg,
+	      t->integ_alg, t->seq, t->sa_seq, t->sa_seq_hi, t->pkt_seq_hi);
   return s;
 }
 
@@ -235,40 +239,41 @@
   return lb_curr;
 }
 
-static_always_inline i16
-esp_insert_esn (vlib_main_t * vm, ipsec_sa_t * sa,
-		esp_decrypt_packet_data2_t * pd2, u32 * data_len,
-		u8 ** digest, u16 * len, vlib_buffer_t * b, u8 * payload)
+static_always_inline u16
+esp_insert_esn (vlib_main_t *vm, ipsec_sa_t *sa, esp_decrypt_packet_data_t *pd,
+		esp_decrypt_packet_data2_t *pd2, u32 *data_len, u8 **digest,
+		u16 *len, vlib_buffer_t *b, u8 *payload)
 {
   if (!ipsec_sa_is_set_USE_ESN (sa))
     return 0;
-
   /* shift ICV by 4 bytes to insert ESN */
-  u32 seq_hi = clib_host_to_net_u32 (sa->seq_hi);
-  u8 tmp[ESP_MAX_ICV_SIZE], sz = sizeof (sa->seq_hi);
+  u32 seq_hi = clib_host_to_net_u32 (pd->seq_hi);
+  u8 tmp[ESP_MAX_ICV_SIZE];
 
   if (pd2->icv_removed)
     {
       u16 space_left = vlib_buffer_space_left_at_end (vm, pd2->lb);
-      if (space_left >= sz)
+      if (space_left >= N_HI_ESN_BYTES)
 	{
-	  clib_memcpy_fast (vlib_buffer_get_tail (pd2->lb), &seq_hi, sz);
-	  *data_len += sz;
+	  clib_memcpy_fast (vlib_buffer_get_tail (pd2->lb), &seq_hi,
+			    N_HI_ESN_BYTES);
+	  *data_len += N_HI_ESN_BYTES;
 	}
       else
-	return sz;
+	return N_HI_ESN_BYTES;
 
       len[0] = b->current_length;
     }
   else
     {
       clib_memcpy_fast (tmp, payload + len[0], ESP_MAX_ICV_SIZE);
-      clib_memcpy_fast (payload + len[0], &seq_hi, sz);
-      clib_memcpy_fast (payload + len[0] + sz, tmp, ESP_MAX_ICV_SIZE);
-      *data_len += sz;
-      *digest += sz;
+      clib_memcpy_fast (payload + len[0], &seq_hi, N_HI_ESN_BYTES);
+      clib_memcpy_fast (payload + len[0] + N_HI_ESN_BYTES, tmp,
+			ESP_MAX_ICV_SIZE);
+      *data_len += N_HI_ESN_BYTES;
+      *digest += N_HI_ESN_BYTES;
     }
-  return sz;
+  return N_HI_ESN_BYTES;
 }
 
 static_always_inline u8 *
@@ -284,14 +289,14 @@
 
   if (ipsec_sa_is_set_USE_ESN (sa))
     {
-      u8 sz = sizeof (sa->seq_hi);
-      u32 seq_hi = clib_host_to_net_u32 (sa->seq_hi);
+      u32 seq_hi = clib_host_to_net_u32 (pd->seq_hi);
       u16 space_left = vlib_buffer_space_left_at_end (vm, pd2->lb);
 
-      if (space_left >= sz)
+      if (space_left >= N_HI_ESN_BYTES)
 	{
-	  clib_memcpy_fast (vlib_buffer_get_tail (pd2->lb), &seq_hi, sz);
-	  *len += sz;
+	  clib_memcpy_fast (vlib_buffer_get_tail (pd2->lb), &seq_hi,
+			    N_HI_ESN_BYTES);
+	  *len += N_HI_ESN_BYTES;
 	}
       else
 	{
@@ -299,7 +304,8 @@
 	   * (with ICV data) */
 	  ASSERT (pd2->icv_removed);
 	  vlib_buffer_t *tmp = vlib_get_buffer (vm, pd2->free_buffer_index);
-	  clib_memcpy_fast (vlib_buffer_get_current (tmp) - sz, &seq_hi, sz);
+	  clib_memcpy_fast (vlib_buffer_get_current (tmp) - N_HI_ESN_BYTES,
+			    &seq_hi, N_HI_ESN_BYTES);
 	  extra_esn[0] = 1;
 	}
     }
@@ -307,11 +313,12 @@
 }
 
 static_always_inline int
-esp_decrypt_chain_integ (vlib_main_t * vm, ipsec_per_thread_data_t * ptd,
-			 esp_decrypt_packet_data2_t * pd2,
-			 ipsec_sa_t * sa0, vlib_buffer_t * b, u8 icv_sz,
-			 u8 * start_src, u32 start_len,
-			 u8 ** digest, u16 * n_ch, u32 * integ_total_len)
+esp_decrypt_chain_integ (vlib_main_t *vm, ipsec_per_thread_data_t *ptd,
+			 const esp_decrypt_packet_data_t *pd,
+			 esp_decrypt_packet_data2_t *pd2, ipsec_sa_t *sa0,
+			 vlib_buffer_t *b, u8 icv_sz, u8 *start_src,
+			 u32 start_len, u8 **digest, u16 *n_ch,
+			 u32 *integ_total_len)
 {
   vnet_crypto_op_chunk_t *ch;
   vlib_buffer_t *cb = vlib_get_buffer (vm, b->next_buffer);
@@ -334,19 +341,19 @@
 	    ch->len = cb->current_length - icv_sz;
 	  if (ipsec_sa_is_set_USE_ESN (sa0))
 	    {
-	      u32 seq_hi = clib_host_to_net_u32 (sa0->seq_hi);
-	      u8 tmp[ESP_MAX_ICV_SIZE], sz = sizeof (sa0->seq_hi);
+	      u32 seq_hi = clib_host_to_net_u32 (pd->seq_hi);
+	      u8 tmp[ESP_MAX_ICV_SIZE];
 	      u8 *esn;
 	      vlib_buffer_t *tmp_b;
 	      u16 space_left = vlib_buffer_space_left_at_end (vm, pd2->lb);
-	      if (space_left < sz)
+	      if (space_left < N_HI_ESN_BYTES)
 		{
 		  if (pd2->icv_removed)
 		    {
 		      /* use pre-data area from the last bufer
 		         that was removed from the chain */
 		      tmp_b = vlib_get_buffer (vm, pd2->free_buffer_index);
-		      esn = tmp_b->data - sz;
+		      esn = tmp_b->data - N_HI_ESN_BYTES;
 		    }
 		  else
 		    {
@@ -358,28 +365,29 @@
 		      esn = tmp_b->data;
 		      pd2->free_buffer_index = tmp_bi;
 		    }
-		  clib_memcpy_fast (esn, &seq_hi, sz);
+		  clib_memcpy_fast (esn, &seq_hi, N_HI_ESN_BYTES);
 
 		  vec_add2 (ptd->chunks, ch, 1);
 		  n_chunks += 1;
 		  ch->src = esn;
-		  ch->len = sz;
+		  ch->len = N_HI_ESN_BYTES;
 		}
 	      else
 		{
 		  if (pd2->icv_removed)
 		    {
-		      clib_memcpy_fast (vlib_buffer_get_tail
-					(pd2->lb), &seq_hi, sz);
+		      clib_memcpy_fast (vlib_buffer_get_tail (pd2->lb),
+					&seq_hi, N_HI_ESN_BYTES);
 		    }
 		  else
 		    {
 		      clib_memcpy_fast (tmp, *digest, ESP_MAX_ICV_SIZE);
-		      clib_memcpy_fast (*digest, &seq_hi, sz);
-		      clib_memcpy_fast (*digest + sz, tmp, ESP_MAX_ICV_SIZE);
-		      *digest += sz;
+		      clib_memcpy_fast (*digest, &seq_hi, N_HI_ESN_BYTES);
+		      clib_memcpy_fast (*digest + N_HI_ESN_BYTES, tmp,
+					ESP_MAX_ICV_SIZE);
+		      *digest += N_HI_ESN_BYTES;
 		    }
-		  ch->len += sz;
+		  ch->len += N_HI_ESN_BYTES;
 		}
 	    }
 	  total_len += ch->len;
@@ -540,7 +548,7 @@
 
 	  op->flags |= VNET_CRYPTO_OP_FLAG_CHAINED_BUFFERS;
 	  op->chunk_index = vec_len (ptd->chunks);
-	  if (esp_decrypt_chain_integ (vm, ptd, pd2, sa0, b, icv_sz,
+	  if (esp_decrypt_chain_integ (vm, ptd, pd, pd2, sa0, b, icv_sz,
 				       payload, pd->current_length,
 				       &op->digest, &op->n_chunks, 0) < 0)
 	    {
@@ -550,7 +558,7 @@
 	    }
 	}
       else
-	esp_insert_esn (vm, sa0, pd2, &op->len, &op->digest, &len, b,
+	esp_insert_esn (vm, sa0, pd, pd2, &op->len, &op->digest, &len, b,
 			payload);
     out:
       vec_add_aligned (*(integ_ops[0]), op, 1, CLIB_CACHE_LINE_BYTES);
@@ -576,7 +584,7 @@
 	      /* constuct aad in a scratch space in front of the nonce */
 	      esp_header_t *esp0 = (esp_header_t *) (payload - esp_sz);
 	      op->aad = (u8 *) nonce - sizeof (esp_aead_t);
-	      op->aad_len = esp_aad_fill (op->aad, esp0, sa0);
+	      op->aad_len = esp_aad_fill (op->aad, esp0, sa0, pd->seq_hi);
 	      op->tag = payload + len;
 	      op->tag_len = 16;
 	    }
@@ -617,7 +625,6 @@
 				 vlib_buffer_t *b, u16 *next, u16 async_next)
 {
   const u8 esp_sz = sizeof (esp_header_t);
-  u32 current_protect_index = vnet_buffer (b)->ipsec.protect_index;
   esp_decrypt_packet_data_t *async_pd = &(esp_post_data (b))->decrypt_data;
   esp_decrypt_packet_data2_t *async_pd2 = esp_post_data2 (b);
   u8 *tag = payload + len, *iv = payload + esp_sz, *aad = 0;
@@ -671,16 +678,16 @@
 	    tag = vlib_buffer_get_tail (pd2->lb) - icv_sz;
 
 	  flags |= VNET_CRYPTO_OP_FLAG_CHAINED_BUFFERS;
-	  if (esp_decrypt_chain_integ (vm, ptd, pd2, sa0, b, icv_sz, payload,
-				       pd->current_length, &tag,
-				       0, &integ_len) < 0)
+	  if (esp_decrypt_chain_integ (vm, ptd, pd, pd2, sa0, b, icv_sz,
+				       payload, pd->current_length, &tag, 0,
+				       &integ_len) < 0)
 	    {
 	      /* allocate buffer failed, will not add to frame and drop */
 	      return (ESP_DECRYPT_ERROR_NO_BUFFERS);
 	    }
 	}
       else
-	esp_insert_esn (vm, sa0, pd2, &integ_len, &tag, &len, b, payload);
+	esp_insert_esn (vm, sa0, pd, pd2, &integ_len, &tag, &len, b, payload);
     }
   else
     key_index = sa0->crypto_key_index;
@@ -701,7 +708,7 @@
 	  /* constuct aad in a scratch space in front of the nonce */
 	  esp_header_t *esp0 = (esp_header_t *) (payload - esp_sz);
 	  aad = (u8 *) nonce - sizeof (esp_aead_t);
-	  esp_aad_fill (aad, esp0, sa0);
+	  esp_aad_fill (aad, esp0, sa0, pd->seq_hi);
 	  tag = payload + len;
 	}
       else
@@ -730,7 +737,6 @@
 
   *async_pd = *pd;
   *async_pd2 = *pd2;
-  pd->protect_index = current_protect_index;
 
   /* for AEAD integ_len - crypto_len will be negative, it is ok since it
    * is ignored by the engine. */
@@ -776,14 +782,15 @@
    * a sequence s, s+1, s+2, s+3, ... s+n and nothing will prevent any
    * implementation, sequential or batching, from decrypting these.
    */
-  if (ipsec_sa_anti_replay_check (sa0, pd->seq))
+  if (ipsec_sa_anti_replay_and_sn_advance (sa0, pd->seq, pd->seq_hi, true,
+					   NULL))
     {
       b->error = node->errors[ESP_DECRYPT_ERROR_REPLAY];
       next[0] = ESP_DECRYPT_NEXT_DROP;
       return;
     }
 
-  ipsec_sa_anti_replay_advance (sa0, pd->seq);
+  ipsec_sa_anti_replay_advance (sa0, pd->seq, pd->seq_hi);
 
   if (pd->is_chain)
     {
@@ -968,12 +975,8 @@
 	       */
 	      const ipsec_tun_protect_t *itp;
 
-	      if (is_async)
-		itp = ipsec_tun_protect_get (pd->protect_index);
-	      else
-		itp =
-		  ipsec_tun_protect_get (vnet_buffer (b)->
-					 ipsec.protect_index);
+	      itp =
+		ipsec_tun_protect_get (vnet_buffer (b)->ipsec.protect_index);
 
 	      if (PREDICT_TRUE (next_header == IP_PROTOCOL_IP_IN_IP))
 		{
@@ -1148,7 +1151,8 @@
       pd->current_length = b[0]->current_length;
 
       /* anti-reply check */
-      if (ipsec_sa_anti_replay_check (sa0, pd->seq))
+      if (ipsec_sa_anti_replay_and_sn_advance (sa0, pd->seq, ~0, false,
+					       &pd->seq_hi))
 	{
 	  err = ESP_DECRYPT_ERROR_REPLAY;
 	  esp_set_next_index (b[0], node, err, n_noop, noop_nexts,
@@ -1306,8 +1310,9 @@
 	  tr->crypto_alg = sa0->crypto_alg;
 	  tr->integ_alg = sa0->integ_alg;
 	  tr->seq = pd->seq;
-	  tr->sa_seq = sa0->last_seq;
+	  tr->sa_seq = sa0->seq;
 	  tr->sa_seq_hi = sa0->seq_hi;
+	  tr->pkt_seq_hi = pd->seq_hi;
 	}
 
       /* next */
@@ -1374,7 +1379,7 @@
 	  tr->crypto_alg = sa0->crypto_alg;
 	  tr->integ_alg = sa0->integ_alg;
 	  tr->seq = pd->seq;
-	  tr->sa_seq = sa0->last_seq;
+	  tr->sa_seq = sa0->seq;
 	  tr->sa_seq_hi = sa0->seq_hi;
 	}
 
diff --git a/src/vnet/ipsec/esp_encrypt.c b/src/vnet/ipsec/esp_encrypt.c
index 68aeb60..da9c56a 100644
--- a/src/vnet/ipsec/esp_encrypt.c
+++ b/src/vnet/ipsec/esp_encrypt.c
@@ -379,7 +379,7 @@
 always_inline void
 esp_prepare_sync_op (vlib_main_t *vm, ipsec_per_thread_data_t *ptd,
 		     vnet_crypto_op_t **crypto_ops,
-		     vnet_crypto_op_t **integ_ops, ipsec_sa_t *sa0,
+		     vnet_crypto_op_t **integ_ops, ipsec_sa_t *sa0, u32 seq_hi,
 		     u8 *payload, u16 payload_len, u8 iv_sz, u8 icv_sz, u32 bi,
 		     vlib_buffer_t **b, vlib_buffer_t *lb, u32 hdr_len,
 		     esp_header_t *esp)
@@ -408,7 +408,7 @@
 	    {
 	      /* constuct aad in a scratch space in front of the nonce */
 	      op->aad = (u8 *) nonce - sizeof (esp_aead_t);
-	      op->aad_len = esp_aad_fill (op->aad, esp, sa0);
+	      op->aad_len = esp_aad_fill (op->aad, esp, sa0, seq_hi);
 	      op->tag = payload + op->len;
 	      op->tag_len = 16;
 	    }
@@ -465,8 +465,8 @@
 	}
       else if (ipsec_sa_is_set_USE_ESN (sa0))
 	{
-	  u32 seq_hi = clib_net_to_host_u32 (sa0->seq_hi);
-	  clib_memcpy_fast (op->digest, &seq_hi, sizeof (seq_hi));
+	  u32 tmp = clib_net_to_host_u32 (seq_hi);
+	  clib_memcpy_fast (op->digest, &tmp, sizeof (seq_hi));
 	  op->len += sizeof (seq_hi);
 	}
     }
@@ -508,7 +508,7 @@
 	{
 	  /* constuct aad in a scratch space in front of the nonce */
 	  aad = (u8 *) nonce - sizeof (esp_aead_t);
-	  esp_aad_fill (aad, esp, sa);
+	  esp_aad_fill (aad, esp, sa, sa->seq_hi);
 	  key_index = sa->crypto_key_index;
 	}
       else
@@ -956,9 +956,9 @@
 				   async_next_node, lb);
 	}
       else
-	esp_prepare_sync_op (vm, ptd, crypto_ops, integ_ops, sa0, payload,
-			     payload_len, iv_sz, icv_sz, n_sync, b, lb,
-			     hdr_len, esp);
+	esp_prepare_sync_op (vm, ptd, crypto_ops, integ_ops, sa0, sa0->seq_hi,
+			     payload, payload_len, iv_sz, icv_sz, n_sync, b,
+			     lb, hdr_len, esp);
 
       vlib_buffer_advance (b[0], 0LL - hdr_len);
 
diff --git a/src/vnet/ipsec/ipsec_api.c b/src/vnet/ipsec/ipsec_api.c
index 73f4474..11bfa41 100644
--- a/src/vnet/ipsec/ipsec_api.c
+++ b/src/vnet/ipsec/ipsec_api.c
@@ -826,11 +826,11 @@
     }
 
   mp->seq_outbound = clib_host_to_net_u64 (((u64) sa->seq));
-  mp->last_seq_inbound = clib_host_to_net_u64 (((u64) sa->last_seq));
+  mp->last_seq_inbound = clib_host_to_net_u64 (((u64) sa->seq));
   if (ipsec_sa_is_set_USE_ESN (sa))
     {
       mp->seq_outbound |= (u64) (clib_host_to_net_u32 (sa->seq_hi));
-      mp->last_seq_inbound |= (u64) (clib_host_to_net_u32 (sa->last_seq_hi));
+      mp->last_seq_inbound |= (u64) (clib_host_to_net_u32 (sa->seq_hi));
     }
   if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa))
     mp->replay_window = clib_host_to_net_u64 (sa->replay_window);
@@ -913,11 +913,11 @@
   mp->entry.dscp = ip_dscp_encode (sa->tunnel.t_dscp);
 
   mp->seq_outbound = clib_host_to_net_u64 (((u64) sa->seq));
-  mp->last_seq_inbound = clib_host_to_net_u64 (((u64) sa->last_seq));
+  mp->last_seq_inbound = clib_host_to_net_u64 (((u64) sa->seq));
   if (ipsec_sa_is_set_USE_ESN (sa))
     {
       mp->seq_outbound |= (u64) (clib_host_to_net_u32 (sa->seq_hi));
-      mp->last_seq_inbound |= (u64) (clib_host_to_net_u32 (sa->last_seq_hi));
+      mp->last_seq_inbound |= (u64) (clib_host_to_net_u32 (sa->seq_hi));
     }
   if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa))
     mp->replay_window = clib_host_to_net_u64 (sa->replay_window);
@@ -993,11 +993,11 @@
     }
 
   mp->seq_outbound = clib_host_to_net_u64 (((u64) sa->seq));
-  mp->last_seq_inbound = clib_host_to_net_u64 (((u64) sa->last_seq));
+  mp->last_seq_inbound = clib_host_to_net_u64 (((u64) sa->seq));
   if (ipsec_sa_is_set_USE_ESN (sa))
     {
       mp->seq_outbound |= (u64) (clib_host_to_net_u32 (sa->seq_hi));
-      mp->last_seq_inbound |= (u64) (clib_host_to_net_u32 (sa->last_seq_hi));
+      mp->last_seq_inbound |= (u64) (clib_host_to_net_u32 (sa->seq_hi));
     }
   if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa))
     mp->replay_window = clib_host_to_net_u64 (sa->replay_window);
diff --git a/src/vnet/ipsec/ipsec_format.c b/src/vnet/ipsec/ipsec_format.c
index b67c11d..5f7caab 100644
--- a/src/vnet/ipsec/ipsec_format.c
+++ b/src/vnet/ipsec/ipsec_format.c
@@ -293,9 +293,8 @@
   s = format (s, "\n   salt 0x%x", clib_net_to_host_u32 (sa->salt));
   s = format (s, "\n   thread-index:%d", sa->thread_index);
   s = format (s, "\n   seq %u seq-hi %u", sa->seq, sa->seq_hi);
-  s = format (s, "\n   last-seq %u last-seq-hi %u window %U",
-	      sa->last_seq, sa->last_seq_hi,
-	      format_ipsec_replay_window, sa->replay_window);
+  s = format (s, "\n   window %U", format_ipsec_replay_window,
+	      sa->replay_window);
   s = format (s, "\n   crypto alg %U",
 	      format_ipsec_crypto_alg, sa->crypto_alg);
   if (sa->crypto_alg && (flags & IPSEC_FORMAT_INSECURE))
diff --git a/src/vnet/ipsec/ipsec_sa.h b/src/vnet/ipsec/ipsec_sa.h
index 7827ef1..14461ad 100644
--- a/src/vnet/ipsec/ipsec_sa.h
+++ b/src/vnet/ipsec/ipsec_sa.h
@@ -131,9 +131,8 @@
   u32 spi;
   u32 seq;
   u32 seq_hi;
-  u32 last_seq;
-  u32 last_seq_hi;
   u64 replay_window;
+  u64 ctr_iv_counter;
   dpo_id_t dpo;
 
   vnet_crypto_key_index_t crypto_key_index;
@@ -162,7 +161,6 @@
 
   CLIB_CACHE_LINE_ALIGN_MARK (cacheline1);
 
-  u64 ctr_iv_counter;
   union
   {
     ip4_header_t ip4_hdr;
@@ -312,45 +310,104 @@
  */
 #define IPSEC_SA_ANTI_REPLAY_WINDOW_LOWER_BOUND(_tl) (_tl - IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE + 1)
 
+always_inline int
+ipsec_sa_anti_replay_check (const ipsec_sa_t *sa, u32 seq)
+{
+  if (ipsec_sa_is_set_USE_ANTI_REPLAY (sa) &&
+      sa->replay_window & (1ULL << (sa->seq - seq)))
+    return 1;
+  else
+    return 0;
+}
+
 /*
  * Anti replay check.
  *  inputs need to be in host byte order.
+ *
+ * The function runs in two contexts. pre and post decrypt.
+ * Pre-decrypt it:
+ *  1 - determines if a packet is a replay - a simple check in the window
+ *  2 - returns the hi-seq number that should be used to decrypt.
+ * post-decrypt:
+ *  Checks whether the packet is a replay or falls out of window
+ *
+ * This function should be called even without anti-replay enabled to ensure
+ * the high sequence number is set.
  */
 always_inline int
-ipsec_sa_anti_replay_check (ipsec_sa_t * sa, u32 seq)
+ipsec_sa_anti_replay_and_sn_advance (const ipsec_sa_t *sa, u32 seq,
+				     u32 hi_seq_used, bool post_decrypt,
+				     u32 *hi_seq_req)
 {
-  u32 diff, tl, th;
-
-  if ((sa->flags & IPSEC_SA_FLAG_USE_ANTI_REPLAY) == 0)
-    return 0;
+  ASSERT ((post_decrypt == false) == (hi_seq_req != 0));
 
   if (!ipsec_sa_is_set_USE_ESN (sa))
     {
-      if (PREDICT_TRUE (seq > sa->last_seq))
+      if (hi_seq_req)
+	/* no ESN, therefore the hi-seq is always 0 */
+	*hi_seq_req = 0;
+
+      if (!ipsec_sa_is_set_USE_ANTI_REPLAY (sa))
 	return 0;
 
-      diff = sa->last_seq - seq;
+      if (PREDICT_TRUE (seq > sa->seq))
+	return 0;
+
+      u32 diff = sa->seq - seq;
 
       if (IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE > diff)
-	return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+	return ((sa->replay_window & (1ULL << diff)) ? 1 : 0);
       else
 	return 1;
 
       return 0;
     }
 
-  tl = sa->last_seq;
-  th = sa->last_seq_hi;
-  diff = tl - seq;
-
-  if (PREDICT_TRUE (tl >= (IPSEC_SA_ANTI_REPLAY_WINDOW_MAX_INDEX)))
+  if (!ipsec_sa_is_set_USE_ANTI_REPLAY (sa))
+    {
+      /* there's no AR configured for this SA, but in order
+       * to know whether a packet has wrapped the hi ESN we need
+       * to know whether it is out of window. if we use the default
+       * lower bound then we are effectively forcing AR because
+       * out of window packets will get the increased hi seq number
+       * and will thus fail to decrypt. IOW we need a window to know
+       * if the SN has wrapped, but we don't want a window to check for
+       * anti replay. to resolve the contradiction we use a huge window.
+       * if the packet is not within 2^30 of the current SN, we'll consider
+       * it a wrap.
+       */
+      if (hi_seq_req)
+	{
+	  if (seq >= sa->seq)
+	    /* The packet's sequence number is larger than the SA's.
+	     * that can't be a wrap - unless we lost more than
+	     * 2^32 packets ... how could we know? */
+	    *hi_seq_req = sa->seq_hi;
+	  else
+	    {
+	      /* The packet's SN is less than the SAs, so either the SN has
+	       * wrapped or the SN is just old. */
+	      if (sa->seq - seq > (1 << 30))
+		/* It's really really really old => it wrapped */
+		*hi_seq_req = sa->seq_hi + 1;
+	      else
+		*hi_seq_req = sa->seq_hi;
+	    }
+	}
+      /*
+       * else
+       *   this is post-decrypt and since it decrypted we accept it
+       */
+      return 0;
+    }
+  if (PREDICT_TRUE (sa->seq >= (IPSEC_SA_ANTI_REPLAY_WINDOW_MAX_INDEX)))
     {
       /*
        * the last sequence number VPP recieved is more than one
        * window size greater than zero.
        * Case A from RFC4303 Appendix A.
        */
-      if (seq < IPSEC_SA_ANTI_REPLAY_WINDOW_LOWER_BOUND (tl))
+      if (seq < IPSEC_SA_ANTI_REPLAY_WINDOW_LOWER_BOUND (sa->seq))
 	{
 	  /*
 	   * the received sequence number is lower than the lower bound
@@ -358,8 +415,28 @@
 	   * the high sequence number has wrapped. if it decrypts corrently
 	   * then it's the latter.
 	   */
-	  sa->seq_hi = th + 1;
-	  return 0;
+	  if (post_decrypt)
+	    {
+	      if (hi_seq_used == sa->seq_hi)
+		/* the high sequence number used to successfully decrypt this
+		 * packet is the same as the last-sequence number of the SA.
+		 * that means this packet did not cause a wrap.
+		 * this packet is thus out of window and should be dropped */
+		return 1;
+	      else
+		/* The packet decrypted with a different high sequence number
+		 * to the SA, that means it is the wrap packet and should be
+		 * accepted */
+		return 0;
+	    }
+	  else
+	    {
+	      /* pre-decrypt it might be the packet that causes a wrap, we
+	       * need to decrypt to find out */
+	      if (hi_seq_req)
+		*hi_seq_req = sa->seq_hi + 1;
+	      return 0;
+	    }
 	}
       else
 	{
@@ -367,13 +444,14 @@
 	   * the recieved sequence number greater than the low
 	   * end of the window.
 	   */
-	  sa->seq_hi = th;
-	  if (seq <= tl)
+	  if (hi_seq_req)
+	    *hi_seq_req = sa->seq_hi;
+	  if (seq <= sa->seq)
 	    /*
 	     * The recieved seq number is within bounds of the window
 	     * check if it's a duplicate
 	     */
-	    return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+	    return (ipsec_sa_anti_replay_check (sa, seq));
 	  else
 	    /*
 	     * The received sequence number is greater than the window
@@ -393,19 +471,20 @@
        * RHS will be a larger number.
        * Case B from RFC4303 Appendix A.
        */
-      if (seq < IPSEC_SA_ANTI_REPLAY_WINDOW_LOWER_BOUND (tl))
+      if (seq < IPSEC_SA_ANTI_REPLAY_WINDOW_LOWER_BOUND (sa->seq))
 	{
 	  /*
 	   * the sequence number is less than the lower bound.
 	   */
-	  if (seq <= tl)
+	  if (seq <= sa->seq)
 	    {
 	      /*
 	       * the packet is within the window upper bound.
 	       * check for duplicates.
 	       */
-	      sa->seq_hi = th;
-	      return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+	      if (hi_seq_req)
+		*hi_seq_req = sa->seq_hi;
+	      return (ipsec_sa_anti_replay_check (sa, seq));
 	    }
 	  else
 	    {
@@ -418,7 +497,8 @@
 	       * wrapped the high sequence again. If it were the latter then
 	       * we've lost close to 2^32 packets.
 	       */
-	      sa->seq_hi = th;
+	      if (hi_seq_req)
+		*hi_seq_req = sa->seq_hi;
 	      return 0;
 	    }
 	}
@@ -431,73 +511,79 @@
 	   * However, since TL is the other side of 0 to the received
 	   * packet, the SA has moved on to a higher sequence number.
 	   */
-	  sa->seq_hi = th - 1;
-	  return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
+	  if (hi_seq_req)
+	    *hi_seq_req = sa->seq_hi - 1;
+	  return (ipsec_sa_anti_replay_check (sa, seq));
 	}
     }
 
+  /* unhandled case */
+  ASSERT (0);
   return 0;
 }
 
 /*
  * Anti replay window advance
  *  inputs need to be in host byte order.
+ * This function both advances the anti-replay window and the sequence number
+ * We always need to move on the SN but the window updates are only needed
+ * if AR is on.
+ * However, updating the window is trivial, so we do it anyway to save
+ * the branch cost.
  */
 always_inline void
-ipsec_sa_anti_replay_advance (ipsec_sa_t * sa, u32 seq)
+ipsec_sa_anti_replay_advance (ipsec_sa_t *sa, u32 seq, u32 hi_seq)
 {
   u32 pos;
-  if (PREDICT_TRUE (sa->flags & IPSEC_SA_FLAG_USE_ANTI_REPLAY) == 0)
-    return;
 
-  if (PREDICT_TRUE (sa->flags & IPSEC_SA_FLAG_USE_ESN))
+  if (ipsec_sa_is_set_USE_ESN (sa))
     {
-      int wrap = sa->seq_hi - sa->last_seq_hi;
+      int wrap = hi_seq - sa->seq_hi;
 
-      if (wrap == 0 && seq > sa->last_seq)
+      if (wrap == 0 && seq > sa->seq)
 	{
-	  pos = seq - sa->last_seq;
+	  pos = seq - sa->seq;
 	  if (pos < IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE)
 	    sa->replay_window = ((sa->replay_window) << pos) | 1;
 	  else
 	    sa->replay_window = 1;
-	  sa->last_seq = seq;
+	  sa->seq = seq;
 	}
       else if (wrap > 0)
 	{
-	  pos = ~seq + sa->last_seq + 1;
+	  pos = ~seq + sa->seq + 1;
 	  if (pos < IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE)
 	    sa->replay_window = ((sa->replay_window) << pos) | 1;
 	  else
 	    sa->replay_window = 1;
-	  sa->last_seq = seq;
-	  sa->last_seq_hi = sa->seq_hi;
+	  sa->seq = seq;
+	  sa->seq_hi = hi_seq;
 	}
       else if (wrap < 0)
 	{
-	  pos = ~seq + sa->last_seq + 1;
+	  pos = ~seq + sa->seq + 1;
 	  sa->replay_window |= (1ULL << pos);
 	}
       else
 	{
-	  pos = sa->last_seq - seq;
+	  pos = sa->seq - seq;
 	  sa->replay_window |= (1ULL << pos);
 	}
     }
   else
     {
-      if (seq > sa->last_seq)
+      if (seq > sa->seq)
 	{
-	  pos = seq - sa->last_seq;
+	  pos = seq - sa->seq;
 	  if (pos < IPSEC_SA_ANTI_REPLAY_WINDOW_SIZE)
 	    sa->replay_window = ((sa->replay_window) << pos) | 1;
 	  else
 	    sa->replay_window = 1;
-	  sa->last_seq = seq;
+	  sa->seq = seq;
 	}
       else
 	{
-	  pos = sa->last_seq - seq;
+	  pos = sa->seq - seq;
 	  sa->replay_window |= (1ULL << pos);
 	}
     }