api: verify message size on receipt

When a message is received, verify that it's sufficiently large to
accommodate any VLAs within the message. To do that, we need a way to
calculate the message size including any VLAs. This patch adds such
functionality to vppapigen and the necessary C code to use it to validate
message size on receipt. Drop messages which are malformed.

Type: improvement
Signed-off-by: Klement Sekera <ksekera@cisco.com>
Change-Id: I2903aa21dee84be6822b064795ba314de46c18f4
diff --git a/src/vlibapi/api_common.h b/src/vlibapi/api_common.h
index 320e7c4..6b36314 100644
--- a/src/vlibapi/api_common.h
+++ b/src/vlibapi/api_common.h
@@ -132,6 +132,7 @@
   void *print_json;		/**< message print function (JSON format)  */
   void *tojson;			/**< binary to JSON convert function */
   void *fromjson;		/**< JSON to binary convert function */
+  void *calc_size;		/**< message size calculation */
   int size;			/**< message size  */
   int traced;			/**< is this message to be traced?  */
   int replay;			/**< is this message to be replayed?  */
@@ -170,17 +171,18 @@
 }
 
 /* api_shared.c prototypes */
-void vl_msg_api_handler (void *the_msg);
-void vl_msg_api_handler_no_free (void *the_msg);
-void vl_msg_api_handler_no_trace_no_free (void *the_msg);
-void vl_msg_api_trace_only (void *the_msg);
+void vl_msg_api_handler (void *the_msg, uword msg_len);
+void vl_msg_api_handler_no_free (void *the_msg, uword msg_len);
+void vl_msg_api_handler_no_trace_no_free (void *the_msg, uword msg_len);
+void vl_msg_api_trace_only (void *the_msg, uword msg_len);
 void vl_msg_api_cleanup_handler (void *the_msg);
 void vl_msg_api_replay_handler (void *the_msg);
-void vl_msg_api_socket_handler (void *the_msg);
+void vl_msg_api_socket_handler (void *the_msg, uword msg_len);
 void vl_msg_api_set_handlers (int msg_id, char *msg_name, void *handler,
 			      void *cleanup, void *endian, void *print,
 			      int msg_size, int traced, void *print_json,
-			      void *tojson, void *fromjson);
+			      void *tojson, void *fromjson,
+			      void *calc_size);
 void vl_msg_api_clean_handlers (int msg_id);
 void vl_msg_api_config (vl_msg_api_msg_config_t *);
 void vl_msg_api_set_cleanup_handler (int msg_id, void *fp);
@@ -251,6 +253,9 @@
   /** Message convert function vector */
   void *(**msg_fromjson_handlers) (cJSON *, int *);
 
+  /** Message calc size function vector */
+  uword (**msg_calc_size_funcs) (void *);
+
   /** Message name vector */
   const char **msg_names;
 
diff --git a/src/vlibapi/api_shared.c b/src/vlibapi/api_shared.c
index dd51ee5..f11344e 100644
--- a/src/vlibapi/api_shared.c
+++ b/src/vlibapi/api_shared.c
@@ -500,8 +500,8 @@
 }
 
 always_inline void
-msg_handler_internal (api_main_t * am,
-		      void *the_msg, int trace_it, int do_it, int free_it)
+msg_handler_internal (api_main_t *am, void *the_msg, uword msg_len,
+		      int trace_it, int do_it, int free_it)
 {
   u16 id = clib_net_to_host_u16 (*((u16 *) the_msg));
   u8 *(*print_fp) (void *, void *);
@@ -545,8 +545,35 @@
 	    }
 	}
 
-      if (do_it)
+      uword calc_size = 0;
+      uword (*calc_size_fp) (void *);
+      calc_size_fp = am->msg_calc_size_funcs[id];
+      ASSERT (NULL != calc_size_fp);
+      if (calc_size_fp)
 	{
+	  calc_size = (*calc_size_fp) (the_msg);
+	  /* malformed input is dropped below, never ASSERTed on */
+	  if (calc_size > msg_len)
+	    {
+	      clib_warning (
+		"Truncated message '%s' (id %u) received, calculated size "
+		"%lu is bigger than actual size %lu, message dropped.",
+		am->msg_names[id], id, calc_size, msg_len);
+	    }
+	}
+      else
+	{
+	  clib_warning ("Message '%s' (id %u) has NULL calc_size_func, cannot "
+			"verify message size is correct",
+			am->msg_names[id], id);
+	}
+
+      /* Don't process a truncated message: byte swaps and the handler could
+       * touch memory beyond its end if it's malicious, e.g. a VLA length
+       * field set to 1M elements while the VLA data itself is absent. */
+      if (do_it && calc_size <= msg_len)
+	{
+
 	  if (!am->is_mp_safe[id])
 	    {
 	      vl_msg_api_barrier_trace_context (am->msg_names[id]);
@@ -569,6 +596,7 @@
 	  if (PREDICT_FALSE (vec_len (am->perf_counter_cbs) != 0))
 	    clib_call_callbacks (am->perf_counter_cbs, am, id,
 				 1 /* after */ );
+
 	  if (!am->is_mp_safe[id])
 	    vl_msg_api_barrier_release ();
 	}
@@ -767,32 +795,30 @@
 }
 
 void
-vl_msg_api_handler (void *the_msg)
+vl_msg_api_handler (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
 
-  msg_handler_internal (am, the_msg,
-			(am->rx_trace
-			 && am->rx_trace->enabled) /* trace_it */ ,
-			1 /* do_it */ , 1 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+			(am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+			1 /* do_it */, 1 /* free_it */);
 }
 
 void
-vl_msg_api_handler_no_free (void *the_msg)
+vl_msg_api_handler_no_free (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
-  msg_handler_internal (am, the_msg,
-			(am->rx_trace
-			 && am->rx_trace->enabled) /* trace_it */ ,
-			1 /* do_it */ , 0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+			(am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+			1 /* do_it */, 0 /* free_it */);
 }
 
 void
-vl_msg_api_handler_no_trace_no_free (void *the_msg)
+vl_msg_api_handler_no_trace_no_free (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
-  msg_handler_internal (am, the_msg, 0 /* trace_it */ , 1 /* do_it */ ,
-			0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len, 0 /* trace_it */, 1 /* do_it */,
+			0 /* free_it */);
 }
 
 /*
@@ -805,14 +831,13 @@
  *
  */
 void
-vl_msg_api_trace_only (void *the_msg)
+vl_msg_api_trace_only (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
 
-  msg_handler_internal (am, the_msg,
-			(am->rx_trace
-			 && am->rx_trace->enabled) /* trace_it */ ,
-			0 /* do_it */ , 0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+			(am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+			0 /* do_it */, 0 /* free_it */);
 }
 
 void
@@ -863,14 +888,13 @@
  * vl_msg_api_socket_handler
  */
 void
-vl_msg_api_socket_handler (void *the_msg)
+vl_msg_api_socket_handler (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
 
-  msg_handler_internal (am, the_msg,
-			(am->rx_trace
-			 && am->rx_trace->enabled) /* trace_it */ ,
-			1 /* do_it */ , 0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+			(am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+			1 /* do_it */, 0 /* free_it */);
 }
 
 #define foreach_msg_api_vector                                                \
@@ -882,6 +906,7 @@
   _ (msg_print_json_handlers)                                                 \
   _ (msg_tojson_handlers)                                                     \
   _ (msg_fromjson_handlers)                                                   \
+  _ (msg_calc_size_funcs)                                                     \
   _ (api_trace_cfg)                                                           \
   _ (message_bounce)                                                          \
   _ (is_mp_safe)                                                              \
@@ -927,6 +952,7 @@
   am->msg_print_json_handlers[c->id] = c->print_json;
   am->msg_tojson_handlers[c->id] = c->tojson;
   am->msg_fromjson_handlers[c->id] = c->fromjson;
+  am->msg_calc_size_funcs[c->id] = c->calc_size;
   am->message_bounce[c->id] = c->message_bounce;
   am->is_mp_safe[c->id] = c->is_mp_safe;
   am->is_autoendian[c->id] = c->is_autoendian;
@@ -948,7 +974,8 @@
 void
 vl_msg_api_set_handlers (int id, char *name, void *handler, void *cleanup,
 			 void *endian, void *print, int size, int traced,
-			 void *print_json, void *tojson, void *fromjson)
+			 void *print_json, void *tojson, void *fromjson,
+			 void *calc_size)
 {
   vl_msg_api_msg_config_t cfg;
   vl_msg_api_msg_config_t *c = &cfg;
@@ -969,6 +996,7 @@
   c->tojson = tojson;
   c->fromjson = fromjson;
   c->print_json = print_json;
+  c->calc_size = calc_size;
   vl_msg_api_config (c);
 }
 
@@ -999,8 +1027,11 @@
 {
   uword msg;
 
-  while (!svm_queue_sub (q, (u8 *) & msg, SVM_Q_WAIT, 0))
-    vl_msg_api_handler ((void *) msg);
+  while (!svm_queue_sub (q, (u8 *) &msg, SVM_Q_WAIT, 0))
+    {
+      msgbuf_t *msgbuf = (msgbuf_t *) ((u8 *) msg - offsetof (msgbuf_t, data));
+      vl_msg_api_handler ((void *) msg, ntohl (msgbuf->data_len));
+    }
 }
 
 u32