tls: P256: change logic so that we don't need double-wide vectors everywhere

Change sp_256to512z_mont_{mul,sqr}_8 so that they neither require nor
zero the upper 256 bits. Only one place actually relied on that
(which is why there used to be a zeroing memset of the top half!).
Fix up that place.
As a bonus, the 256x256->512 multiply no longer needs to handle the
"r overlaps a or b" case.

This shrinks the sp_point structure as well, not just the temporaries.
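
The caller-visible difference, as a minimal sketch (hypothetical
locals a, b, v, w; sp_digit is the 32-bit limb type used in this
file):

	/* Before: the result buffer had to be double-wide; the callee
	 * used the high half as scratch and zero-filled it:
	 */
	sp_digit v[2 * 8];
	sp_256to512z_mont_mul_8(v, a, b); /* v[0..7] = Montgomery product, v[8..15] = 0 */

	/* After: the result is a plain 256-bit vector; the 512-bit
	 * intermediate product lives in a temporary inside the callee:
	 */
	sp_digit w[8];
	sp_256_mont_mul_8(w, a, b); /* w[0..7] = Montgomery product */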

function                                             old     new   delta
sp_256to512z_mont_mul_8                              150       -    -150
sp_256_mont_mul_8                                      -     147    +147
sp_256to512z_mont_sqr_8                                7       -      -7
sp_256_mont_sqr_8                                      -       7      +7
sp_256_ecc_mulmod_8                                  494     543     +49
sp_512to256_mont_reduce_8                            243     249      +6
sp_256_point_from_bin2x32                             73      70      -3
sp_256_proj_point_dbl_8                              353     345      -8
sp_256_proj_point_add_8                              544     499     -45
------------------------------------------------------------------------------
(add/remove: 2/2 grow/shrink: 2/3 up/down: 209/-213)           Total: -4 bytes

Signed-off-by: Denys Vlasenko <vda.linux@googlemail.com>
diff --git a/networking/tls_sp_c32.c b/networking/tls_sp_c32.c
index 3291b55..3452b08 100644
--- a/networking/tls_sp_c32.c
+++ b/networking/tls_sp_c32.c
@@ -49,9 +49,9 @@
  */
 
 typedef struct sp_point {
-	sp_digit x[2 * 8];
-	sp_digit y[2 * 8];
-	sp_digit z[2 * 8];
+	sp_digit x[8];
+	sp_digit y[8];
+	sp_digit z[8];
 	int infinity;
 } sp_point;
 
@@ -456,12 +456,11 @@
 #endif
 
 /* Multiply a and b into r. (r = a * b)
- * r should be [16] array (512 bits).
+ * r should be a [16] array (512 bits) and must not coincide with a or b.
  */
 static void sp_256to512_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b)
 {
 #if ALLOW_ASM && defined(__GNUC__) && defined(__i386__)
-	sp_digit rr[15]; /* in case r coincides with a or b */
 	int k;
 	uint32_t accl;
 	uint32_t acch;
@@ -493,16 +492,15 @@
 		        j--;
 			i++;
 		} while (i != 8 && i <= k);
-		rr[k] = accl;
+		r[k] = accl;
 		accl = acch;
 		acch = acc_hi;
 	}
 	r[15] = accl;
-	memcpy(r, rr, sizeof(rr));
 #elif ALLOW_ASM && defined(__GNUC__) && defined(__x86_64__)
 	const uint64_t* aa = (const void*)a;
 	const uint64_t* bb = (const void*)b;
-	uint64_t rr[8];
+	uint64_t* rr = (void*)r; /* writable 64-bit view of r */
 	int k;
 	uint64_t accl;
 	uint64_t acch;
@@ -539,11 +537,8 @@
 		acch = acc_hi;
 	}
 	rr[7] = accl;
-	memcpy(r, rr, sizeof(rr));
 #elif 0
 	//TODO: arm assembly (untested)
-	sp_digit tmp[16];
-
 	asm volatile (
 "\n		mov	r5, #0"
 "\n		mov	r6, #0"
@@ -575,12 +570,10 @@
 "\n		cmp	r5, #56"
 "\n		ble	1b"
 "\n		str	r6, [%[r], r5]"
-		: [r] "r" (tmp), [a] "r" (a), [b] "r" (b)
+		: [r] "r" (r), [a] "r" (a), [b] "r" (b)
 		: "memory", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r12", "r14"
 	);
-	memcpy(r, tmp, sizeof(tmp));
 #else
-	sp_digit rr[15]; /* in case r coincides with a or b */
 	int i, j, k;
 	uint64_t acc;
 
@@ -600,11 +593,10 @@
 		        j--;
 			i++;
 		} while (i != 8 && i <= k);
-		rr[k] = acc;
+		r[k] = acc;
 		acc = (acc >> 32) | ((uint64_t)acc_hi << 32);
 	}
 	r[15] = acc;
-	memcpy(r, rr, sizeof(rr));
 #endif
 }
 
@@ -709,30 +701,11 @@
 }
 
 /* Shift the result in the high 256 bits down to the bottom.
- * High half is cleared to zeros.
  */
-#if BB_UNALIGNED_MEMACCESS_OK && ULONG_MAX > 0xffffffff
-static void sp_512to256_mont_shift_8(sp_digit* rr)
+static void sp_512to256_mont_shift_8(sp_digit* r, sp_digit* a)
 {
-	uint64_t *r = (void*)rr;
-	int i;
-
-	for (i = 0; i < 4; i++) {
-		r[i] = r[i+4];
-		r[i+4] = 0;
-	}
+	memcpy(r, a + 8, sizeof(*r) * 8);
 }
-#else
-static void sp_512to256_mont_shift_8(sp_digit* r)
-{
-	int i;
-
-	for (i = 0; i < 8; i++) {
-		r[i] = r[i+8];
-		r[i+8] = 0;
-	}
-}
-#endif
 
 /* Mul a by scalar b and add into r. (r += a * b)
  * a = p256_mod
@@ -868,11 +841,12 @@
  * Note: the result is NOT guaranteed to be less than p256_mod!
  * (it is only guaranteed to fit into 256 bits).
  *
- * a   Double-wide number to reduce in place.
+ * r   Result.
+ * a   Double-wide number to reduce. Clobbered.
  * m   The single precision number representing the modulus.
  * mp  The digit representing the negative inverse of m mod 2^n.
  */
-static void sp_512to256_mont_reduce_8(sp_digit* a/*, const sp_digit* m, sp_digit mp*/)
+static void sp_512to256_mont_reduce_8(sp_digit* r, sp_digit* a/*, const sp_digit* m, sp_digit mp*/)
 {
 //	const sp_digit* m = p256_mod;
 	sp_digit mp = p256_mp_mod;
@@ -895,10 +869,10 @@
 					goto inc_next_word0;
 			}
 		}
-		sp_512to256_mont_shift_8(a);
+		sp_512to256_mont_shift_8(r, a);
 		if (word16th != 0)
-			sp_256_sub_8_p256_mod(a);
-		sp_256_norm_8(a);
+			sp_256_sub_8_p256_mod(r);
+		sp_256_norm_8(r);
 	}
 	else { /* Same code for explicit mp == 1 (which is always the case for P256) */
 		sp_digit word16th = 0;
@@ -915,10 +889,10 @@
 					goto inc_next_word;
 			}
 		}
-		sp_512to256_mont_shift_8(a);
+		sp_512to256_mont_shift_8(r, a);
 		if (word16th != 0)
-			sp_256_sub_8_p256_mod(a);
-		sp_256_norm_8(a);
+			sp_256_sub_8_p256_mod(r);
+		sp_256_norm_8(r);
 	}
 }
 
@@ -926,35 +900,34 @@
  * (r = a * b mod m)
  *
  * r   Result of multiplication.
- *     Should be [16] array (512 bits), but high half is cleared to zeros (used as scratch pad).
  * a   First number to multiply in Montgomery form.
  * b   Second number to multiply in Montgomery form.
  * m   Modulus (prime).
  * mp  Montgomery multiplier.
  */
-static void sp_256to512z_mont_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b
+static void sp_256_mont_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b
 		/*, const sp_digit* m, sp_digit mp*/)
 {
 	//const sp_digit* m = p256_mod;
 	//sp_digit mp = p256_mp_mod;
-	sp_256to512_mul_8(r, a, b);
-	sp_512to256_mont_reduce_8(r /*, m, mp*/);
+	sp_digit t[2 * 8];
+	sp_256to512_mul_8(t, a, b);
+	sp_512to256_mont_reduce_8(r, t /*, m, mp*/);
 }
 
 /* Square the Montgomery form number. (r = a * a mod m)
  *
  * r   Result of squaring.
- *     Should be [16] array (512 bits), but high half is cleared to zeros (used as scratch pad).
  * a   Number to square in Montgomery form.
  * m   Modulus (prime).
  * mp  Montgomery multiplier.
  */
-static void sp_256to512z_mont_sqr_8(sp_digit* r, const sp_digit* a
+static void sp_256_mont_sqr_8(sp_digit* r, const sp_digit* a
 		/*, const sp_digit* m, sp_digit mp*/)
 {
 	//const sp_digit* m = p256_mod;
 	//sp_digit mp = p256_mp_mod;
-	sp_256to512z_mont_mul_8(r, a, a /*, m, mp*/);
+	sp_256_mont_mul_8(r, a, a /*, m, mp*/);
 }
 
 /* Invert the number, in Montgomery form, modulo the modulus (prime) of the
@@ -964,11 +937,8 @@
  * a   Number to invert.
  */
 #if 0
-/* Mod-2 for the P256 curve. */
-static const uint32_t p256_mod_2[8] = {
-	0xfffffffd,0xffffffff,0xffffffff,0x00000000,
-	0x00000000,0x00000000,0x00000001,0xffffffff,
-};
+//p256_mod - 2:
+//ffffffff 00000001 00000000 00000000 00000000 ffffffff ffffffff ffffffff - 2
 //Bit pattern:
 //2    2         2         2         2         2         2         1...1
 //5    5         4         3         2         1         0         9...0         9...1
@@ -977,15 +947,15 @@
 #endif
 static void sp_256_mont_inv_8(sp_digit* r, sp_digit* a)
 {
-	sp_digit t[2*8];
+	sp_digit t[8];
 	int i;
 
 	memcpy(t, a, sizeof(sp_digit) * 8);
 	for (i = 254; i >= 0; i--) {
-		sp_256to512z_mont_sqr_8(t, t /*, p256_mod, p256_mp_mod*/);
+		sp_256_mont_sqr_8(t, t /*, p256_mod, p256_mp_mod*/);
 		/*if (p256_mod_2[i / 32] & ((sp_digit)1 << (i % 32)))*/
 		if (i >= 224 || i == 192 || (i <= 95 && i != 1))
-			sp_256to512z_mont_mul_8(t, t, a /*, p256_mod, p256_mp_mod*/);
+			sp_256_mont_mul_8(t, t, a /*, p256_mod, p256_mp_mod*/);
 	}
 	memcpy(r, t, sizeof(sp_digit) * 8);
 }
@@ -1056,25 +1026,28 @@
  */
 static void sp_256_map_8(sp_point* r, sp_point* p)
 {
-	sp_digit t1[2*8];
-	sp_digit t2[2*8];
+	sp_digit t1[8];
+	sp_digit t2[8];
+	sp_digit rr[2 * 8];
 
 	sp_256_mont_inv_8(t1, p->z);
 
-	sp_256to512z_mont_sqr_8(t2, t1 /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(t1, t2, t1 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(t2, t1 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t1, t2, t1 /*, p256_mod, p256_mp_mod*/);
 
 	/* x /= z^2 */
-	sp_256to512z_mont_mul_8(r->x, p->x, t2 /*, p256_mod, p256_mp_mod*/);
-	sp_512to256_mont_reduce_8(r->x /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(rr, p->x, t2 /*, p256_mod, p256_mp_mod*/);
+	memset(rr + 8, 0, sizeof(rr) / 2);
+	sp_512to256_mont_reduce_8(r->x, rr /*, p256_mod, p256_mp_mod*/);
 	/* Reduce x to less than modulus */
 	if (sp_256_cmp_8(r->x, p256_mod) >= 0)
 		sp_256_sub_8_p256_mod(r->x);
 	sp_256_norm_8(r->x);
 
 	/* y /= z^3 */
-	sp_256to512z_mont_mul_8(r->y, p->y, t1 /*, p256_mod, p256_mp_mod*/);
-	sp_512to256_mont_reduce_8(r->y /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(rr, p->y, t1 /*, p256_mod, p256_mp_mod*/);
+	memset(rr + 8, 0, sizeof(rr) / 2);
+	sp_512to256_mont_reduce_8(r->y, rr /*, p256_mod, p256_mp_mod*/);
 	/* Reduce y to less than modulus */
 	if (sp_256_cmp_8(r->y, p256_mod) >= 0)
 		sp_256_sub_8_p256_mod(r->y);
@@ -1091,8 +1064,8 @@
  */
 static void sp_256_proj_point_dbl_8(sp_point* r, sp_point* p)
 {
-	sp_digit t1[2*8];
-	sp_digit t2[2*8];
+	sp_digit t1[8];
+	sp_digit t2[8];
 
 	/* Put point to double into result */
 	if (r != p)
@@ -1101,17 +1074,10 @@
 	if (r->infinity)
 		return;
 
-	if (SP_DEBUG) {
-		/* unused part of t2, may result in spurios
-		 * differences in debug output. Clear it.
-		 */
-		memset(t2, 0, sizeof(t2));
-	}
-
 	/* T1 = Z * Z */
-	sp_256to512z_mont_sqr_8(t1, r->z /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(t1, r->z /*, p256_mod, p256_mp_mod*/);
 	/* Z = Y * Z */
-	sp_256to512z_mont_mul_8(r->z, r->y, r->z /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->z, r->y, r->z /*, p256_mod, p256_mp_mod*/);
 	/* Z = 2Z */
 	sp_256_mont_dbl_8(r->z, r->z /*, p256_mod*/);
 	/* T2 = X - T1 */
@@ -1119,21 +1085,21 @@
 	/* T1 = X + T1 */
 	sp_256_mont_add_8(t1, r->x, t1 /*, p256_mod*/);
 	/* T2 = T1 * T2 */
-	sp_256to512z_mont_mul_8(t2, t1, t2 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t2, t1, t2 /*, p256_mod, p256_mp_mod*/);
 	/* T1 = 3T2 */
 	sp_256_mont_tpl_8(t1, t2 /*, p256_mod*/);
 	/* Y = 2Y */
 	sp_256_mont_dbl_8(r->y, r->y /*, p256_mod*/);
 	/* Y = Y * Y */
-	sp_256to512z_mont_sqr_8(r->y, r->y /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(r->y, r->y /*, p256_mod, p256_mp_mod*/);
 	/* T2 = Y * Y */
-	sp_256to512z_mont_sqr_8(t2, r->y /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(t2, r->y /*, p256_mod, p256_mp_mod*/);
 	/* T2 = T2/2 */
 	sp_256_div2_8(t2 /*, p256_mod*/);
 	/* Y = Y * X */
-	sp_256to512z_mont_mul_8(r->y, r->y, r->x /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->y, r->y, r->x /*, p256_mod, p256_mp_mod*/);
 	/* X = T1 * T1 */
-	sp_256to512z_mont_mul_8(r->x, t1, t1 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->x, t1, t1 /*, p256_mod, p256_mp_mod*/);
 	/* X = X - Y */
 	sp_256_mont_sub_8(r->x, r->x, r->y /*, p256_mod*/);
 	/* X = X - Y */
@@ -1141,7 +1107,7 @@
 	/* Y = Y - X */
 	sp_256_mont_sub_8(r->y, r->y, r->x /*, p256_mod*/);
 	/* Y = Y * T1 */
-	sp_256to512z_mont_mul_8(r->y, r->y, t1 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->y, r->y, t1 /*, p256_mod, p256_mp_mod*/);
 	/* Y = Y - T2 */
 	sp_256_mont_sub_8(r->y, r->y, t2 /*, p256_mod*/);
 	dump_512("y2 %s\n", r->y);
@@ -1155,11 +1121,11 @@
  */
 static NOINLINE void sp_256_proj_point_add_8(sp_point* r, sp_point* p, sp_point* q)
 {
-	sp_digit t1[2*8];
-	sp_digit t2[2*8];
-	sp_digit t3[2*8];
-	sp_digit t4[2*8];
-	sp_digit t5[2*8];
+	sp_digit t1[8];
+	sp_digit t2[8];
+	sp_digit t3[8];
+	sp_digit t4[8];
+	sp_digit t5[8];
 
 	/* Ensure only the first point is the same as the result. */
 	if (q == r) {
@@ -1186,36 +1152,36 @@
 	}
 
 	/* U1 = X1*Z2^2 */
-	sp_256to512z_mont_sqr_8(t1, q->z /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(t3, t1, q->z /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(t1, t1, r->x /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(t1, q->z /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t3, t1, q->z /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t1, t1, r->x /*, p256_mod, p256_mp_mod*/);
 	/* U2 = X2*Z1^2 */
-	sp_256to512z_mont_sqr_8(t2, r->z /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(t4, t2, r->z /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(t2, t2, q->x /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(t2, r->z /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t4, t2, r->z /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t2, t2, q->x /*, p256_mod, p256_mp_mod*/);
 	/* S1 = Y1*Z2^3 */
-	sp_256to512z_mont_mul_8(t3, t3, r->y /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t3, t3, r->y /*, p256_mod, p256_mp_mod*/);
 	/* S2 = Y2*Z1^3 */
-	sp_256to512z_mont_mul_8(t4, t4, q->y /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t4, t4, q->y /*, p256_mod, p256_mp_mod*/);
 	/* H = U2 - U1 */
 	sp_256_mont_sub_8(t2, t2, t1 /*, p256_mod*/);
 	/* R = S2 - S1 */
 	sp_256_mont_sub_8(t4, t4, t3 /*, p256_mod*/);
 	/* Z3 = H*Z1*Z2 */
-	sp_256to512z_mont_mul_8(r->z, r->z, q->z /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(r->z, r->z, t2 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->z, r->z, q->z /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->z, r->z, t2 /*, p256_mod, p256_mp_mod*/);
 	/* X3 = R^2 - H^3 - 2*U1*H^2 */
-	sp_256to512z_mont_sqr_8(r->x, t4 /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_sqr_8(t5, t2 /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(r->y, t1, t5 /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(t5, t5, t2 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(r->x, t4 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_sqr_8(t5, t2 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->y, t1, t5 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t5, t5, t2 /*, p256_mod, p256_mp_mod*/);
 	sp_256_mont_sub_8(r->x, r->x, t5 /*, p256_mod*/);
 	sp_256_mont_dbl_8(t1, r->y /*, p256_mod*/);
 	sp_256_mont_sub_8(r->x, r->x, t1 /*, p256_mod*/);
 	/* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
 	sp_256_mont_sub_8(r->y, r->y, r->x /*, p256_mod*/);
-	sp_256to512z_mont_mul_8(r->y, r->y, t4 /*, p256_mod, p256_mp_mod*/);
-	sp_256to512z_mont_mul_8(t5, t5, t3 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(r->y, r->y, t4 /*, p256_mod, p256_mp_mod*/);
+	sp_256_mont_mul_8(t5, t5, t3 /*, p256_mod, p256_mp_mod*/);
 	sp_256_mont_sub_8(r->y, r->y, t5 /*, p256_mod*/);
 }