shell/math: deconvolute and explain ?: handling. Give better error message

function                                             old     new   delta
arith_apply                                         1271    1283     +12

Signed-off-by: Denys Vlasenko <dvlasenk@redhat.com>
diff --git a/shell/math.c b/shell/math.c
index 871c06c..9d3b912 100644
--- a/shell/math.c
+++ b/shell/math.c
@@ -1,5 +1,5 @@
 /*
- * arithmetic code ripped out of ash shell for code sharing
+ * Arithmetic code ripped out of ash shell for code sharing.
  *
  * This code is derived from software contributed to Berkeley by
  * Kenneth Almquist.
@@ -154,7 +154,7 @@
 
 #define fix_assignment_prec(prec) do { if (prec == 3) prec = 2; } while (0)
 
-/* ternary conditional operator is right associative too */
+/* Ternary conditional operator is right associative too */
 #define TOK_CONDITIONAL         tok_decl(4,0)
 #define TOK_CONDITIONAL_SEP     tok_decl(4,1)
 
@@ -186,10 +186,10 @@
 #define TOK_DIV                 tok_decl(14,1)
 #define TOK_REM                 tok_decl(14,2)
 
-/* exponent is right associative */
+/* Exponent is right associative */
 #define TOK_EXPONENT            tok_decl(15,1)
 
-/* unary operators */
+/* Unary operators */
 #define UNARYPREC               16
 #define TOK_BNOT                tok_decl(UNARYPREC,0)
 #define TOK_NOT                 tok_decl(UNARYPREC,1)
@@ -213,30 +213,37 @@
 #define TOK_RPAREN              tok_decl(SPEC_PREC, 1)
 
 static int
-tok_have_assign(operator op)
+is_assign_op(operator op)
 {
 	operator prec = PREC(op);
-
 	fix_assignment_prec(prec);
-	return (prec == PREC(TOK_ASSIGN) ||
-			prec == PREC_PRE || prec == PREC_POST);
+	return prec == PREC(TOK_ASSIGN)
+	|| prec == PREC_PRE
+	|| prec == PREC_POST;
 }
 
 static int
 is_right_associative(operator prec)
 {
-	return (prec == PREC(TOK_ASSIGN) || prec == PREC(TOK_EXPONENT)
-	        || prec == PREC(TOK_CONDITIONAL));
+	return prec == PREC(TOK_ASSIGN)
+	|| prec == PREC(TOK_EXPONENT)
+	|| prec == PREC(TOK_CONDITIONAL);
 }
 
 
 typedef struct {
 	arith_t val;
-	arith_t contidional_second_val;
-	char contidional_second_val_initialized;
-	char *var;      /* if NULL then is regular number,
-			   else is variable name */
-} v_n_t;
+	/* We acquire second_val only when "expr1 : expr2" part
+	 * of ternary ?: op is evaluated.
+	 * We treat ?: as two binary ops: (expr ? (expr1 : expr2)).
+	 * ':' produces a new value which has two parts, val and second_val;
+	 * then '?' selects one of them based on its left side.
+	 */
+	arith_t second_val;
+	char second_val_present;
+	/* If NULL then it's just a number, else it's a named variable */
+	char *var;
+} var_or_num_t;
 
 typedef struct remembered_name {
 	struct remembered_name *next;
@@ -248,7 +255,7 @@
 evaluate_string(arith_state_t *math_state, const char *expr);
 
 static const char*
-arith_lookup_val(arith_state_t *math_state, v_n_t *t)
+arith_lookup_val(arith_state_t *math_state, var_or_num_t *t)
 {
 	if (t->var) {
 		const char *p = lookupvar(t->var);
@@ -290,27 +297,28 @@
  * stack. For an unary operator it will only change the top element, but a
  * binary operator will pop two arguments and push the result */
 static NOINLINE const char*
-arith_apply(arith_state_t *math_state, operator op, v_n_t *numstack, v_n_t **numstackptr)
+arith_apply(arith_state_t *math_state, operator op, var_or_num_t *numstack, var_or_num_t **numstackptr)
 {
 #define NUMPTR (*numstackptr)
 
-	v_n_t *numptr_m1;
-	arith_t numptr_val, rez;
+	var_or_num_t *top_of_stack;
+	arith_t rez;
 	const char *err;
 
 	/* There is no operator that can work without arguments */
 	if (NUMPTR == numstack)
 		goto err;
-	numptr_m1 = NUMPTR - 1;
 
-	/* Check operand is var with noninteger value */
-	err = arith_lookup_val(math_state, numptr_m1);
+	top_of_stack = NUMPTR - 1;
+
+	/* Resolve name to value, if needed */
+	err = arith_lookup_val(math_state, top_of_stack);
 	if (err)
 		return err;
 
-	rez = numptr_m1->val;
+	rez = top_of_stack->val;
 	if (op == TOK_UMINUS)
-		rez *= -1;
+		rez = -rez;
 	else if (op == TOK_NOT)
 		rez = !rez;
 	else if (op == TOK_BNOT)
@@ -321,112 +329,119 @@
 		rez--;
 	else if (op != TOK_UPLUS) {
 		/* Binary operators */
+		arith_t right_side_val;
+		char bad_second_val;
 
-		/* check and binary operators need two arguments */
-		if (numptr_m1 == numstack) goto err;
-
-		/* ... and they pop one */
-		--NUMPTR;
-		numptr_val = rez;
-		if (op == TOK_CONDITIONAL) {
-			if (!numptr_m1->contidional_second_val_initialized) {
-				/* protect $((expr1 ? expr2)) without ": expr" */
-				goto err;
-			}
-			rez = numptr_m1->contidional_second_val;
-		} else if (numptr_m1->contidional_second_val_initialized) {
-			/* protect $((expr1 : expr2)) without "expr ? " */
+		/* Binary operators need two arguments */
+		if (top_of_stack == numstack)
 			goto err;
+		/* ...and they pop one */
+		NUMPTR = top_of_stack; /* this decrements NUMPTR */
+
+		bad_second_val = top_of_stack->second_val_present;
+		if (op == TOK_CONDITIONAL) { /* ? operation */
+			/* Make next if (...) protect against
+			 * $((expr1 ? expr2)) - that is, missing ": expr" */
+			bad_second_val = !bad_second_val;
 		}
-		numptr_m1 = NUMPTR - 1;
+		if (bad_second_val) {
+			/* Protect against $((expr <not_?_op> expr1 : expr2)) */
+			return "malformed ?: operator";
+		}
+
+		top_of_stack--; /* now points to left side */
+
 		if (op != TOK_ASSIGN) {
-			/* check operand is var with noninteger value for not '=' */
-			err = arith_lookup_val(math_state, numptr_m1);
+			/* Resolve left side value (unless the op is '=') */
+			err = arith_lookup_val(math_state, top_of_stack);
 			if (err)
 				return err;
 		}
-		if (op == TOK_CONDITIONAL) {
-			numptr_m1->contidional_second_val = rez;
-		}
-		rez = numptr_m1->val;
-		if (op == TOK_BOR || op == TOK_OR_ASSIGN)
-			rez |= numptr_val;
-		else if (op == TOK_OR)
-			rez = numptr_val || rez;
-		else if (op == TOK_BAND || op == TOK_AND_ASSIGN)
-			rez &= numptr_val;
-		else if (op == TOK_BXOR || op == TOK_XOR_ASSIGN)
-			rez ^= numptr_val;
-		else if (op == TOK_AND)
-			rez = rez && numptr_val;
-		else if (op == TOK_EQ)
-			rez = (rez == numptr_val);
-		else if (op == TOK_NE)
-			rez = (rez != numptr_val);
-		else if (op == TOK_GE)
-			rez = (rez >= numptr_val);
-		else if (op == TOK_RSHIFT || op == TOK_RSHIFT_ASSIGN)
-			rez >>= numptr_val;
-		else if (op == TOK_LSHIFT || op == TOK_LSHIFT_ASSIGN)
-			rez <<= numptr_val;
-		else if (op == TOK_GT)
-			rez = (rez > numptr_val);
-		else if (op == TOK_LT)
-			rez = (rez < numptr_val);
-		else if (op == TOK_LE)
-			rez = (rez <= numptr_val);
-		else if (op == TOK_MUL || op == TOK_MUL_ASSIGN)
-			rez *= numptr_val;
-		else if (op == TOK_ADD || op == TOK_PLUS_ASSIGN)
-			rez += numptr_val;
-		else if (op == TOK_SUB || op == TOK_MINUS_ASSIGN)
-			rez -= numptr_val;
-		else if (op == TOK_ASSIGN || op == TOK_COMMA)
-			rez = numptr_val;
-		else if (op == TOK_CONDITIONAL_SEP) {
-			if (numptr_m1 == numstack) {
-				/* protect $((expr : expr)) without "expr ? " */
-				goto err;
+
+		right_side_val = rez;
+		rez = top_of_stack->val;
+		if (op == TOK_CONDITIONAL) /* ? operation */
+			rez = (rez ? right_side_val : top_of_stack[1].second_val);
+		else if (op == TOK_CONDITIONAL_SEP) { /* : operation */
+			if (top_of_stack == numstack) {
+				/* Protect against $((expr : expr)) */
+				return "malformed ?: operator";
 			}
-			numptr_m1->contidional_second_val_initialized = op;
-			numptr_m1->contidional_second_val = numptr_val;
-		} else if (op == TOK_CONDITIONAL) {
-			rez = rez ?
-				numptr_val : numptr_m1->contidional_second_val;
-		} else if (op == TOK_EXPONENT) {
+			top_of_stack->second_val_present = op;
+			top_of_stack->second_val = right_side_val;
+		}
+		else if (op == TOK_BOR || op == TOK_OR_ASSIGN)
+			rez |= right_side_val;
+		else if (op == TOK_OR)
+			rez = right_side_val || rez;
+		else if (op == TOK_BAND || op == TOK_AND_ASSIGN)
+			rez &= right_side_val;
+		else if (op == TOK_BXOR || op == TOK_XOR_ASSIGN)
+			rez ^= right_side_val;
+		else if (op == TOK_AND)
+			rez = rez && right_side_val;
+		else if (op == TOK_EQ)
+			rez = (rez == right_side_val);
+		else if (op == TOK_NE)
+			rez = (rez != right_side_val);
+		else if (op == TOK_GE)
+			rez = (rez >= right_side_val);
+		else if (op == TOK_RSHIFT || op == TOK_RSHIFT_ASSIGN)
+			rez >>= right_side_val;
+		else if (op == TOK_LSHIFT || op == TOK_LSHIFT_ASSIGN)
+			rez <<= right_side_val;
+		else if (op == TOK_GT)
+			rez = (rez > right_side_val);
+		else if (op == TOK_LT)
+			rez = (rez < right_side_val);
+		else if (op == TOK_LE)
+			rez = (rez <= right_side_val);
+		else if (op == TOK_MUL || op == TOK_MUL_ASSIGN)
+			rez *= right_side_val;
+		else if (op == TOK_ADD || op == TOK_PLUS_ASSIGN)
+			rez += right_side_val;
+		else if (op == TOK_SUB || op == TOK_MINUS_ASSIGN)
+			rez -= right_side_val;
+		else if (op == TOK_ASSIGN || op == TOK_COMMA)
+			rez = right_side_val;
+		else if (op == TOK_EXPONENT) {
 			arith_t c;
-			if (numptr_val < 0)
+			if (right_side_val < 0)
 				return "exponent less than 0";
 			c = 1;
-			while (--numptr_val >= 0)
+			while (--right_side_val >= 0)
 			    c *= rez;
 			rez = c;
-		} else if (numptr_val == 0)
+		}
+		else if (right_side_val == 0)
 			return "divide by zero";
 		else if (op == TOK_DIV || op == TOK_DIV_ASSIGN)
-			rez /= numptr_val;
+			rez /= right_side_val;
 		else if (op == TOK_REM || op == TOK_REM_ASSIGN)
-			rez %= numptr_val;
+			rez %= right_side_val;
 	}
-	if (tok_have_assign(op)) {
+
+	if (is_assign_op(op)) {
 		char buf[sizeof(arith_t)*3 + 2];
 
-		if (numptr_m1->var == NULL) {
+		if (top_of_stack->var == NULL) {
 			/* Hmm, 1=2 ? */
+//TODO: actually, bash allows ++7 but for some reason it evals to 7, not 8
 			goto err;
 		}
-		/* save to shell variable */
-		sprintf(buf, arith_t_fmt, rez);
-		setvar(numptr_m1->var, buf);
-		/* after saving, make previous value for v++ or v-- */
+		/* Save to shell variable */
+		sprintf(buf, ARITH_FMT, rez);
+		setvar(top_of_stack->var, buf);
+		/* After saving, undo the change in rez so that v++ or v-- yields the old value */
 		if (op == TOK_POST_INC)
 			rez--;
 		else if (op == TOK_POST_DEC)
 			rez++;
 	}
-	numptr_m1->val = rez;
-	/* erase var name, it is just a number now */
-	numptr_m1->var = NULL;
+
+	top_of_stack->val = rez;
+	/* Erase var name, it is just a number now */
+	top_of_stack->var = NULL;
 	return NULL;
  err:
 	return "arithmetic syntax error";
@@ -499,16 +514,17 @@
 	const char *start_expr = expr = skip_whitespace(expr);
 	unsigned expr_len = strlen(expr) + 2;
 	/* Stack of integers */
-	/* The proof that there can be no more than strlen(startbuf)/2+1 integers
-	 * in any given correct or incorrect expression is left as an exercise to
-	 * the reader. */
-	v_n_t *const numstack = alloca((expr_len / 2) * sizeof(numstack[0]));
-	v_n_t *numstackptr = numstack;
+	/* The proof that there can be no more than strlen(expr)/2+1
+	 * integers in any given correct or incorrect expression
+	 * is left as an exercise to the reader. */
+	var_or_num_t *const numstack = alloca((expr_len / 2) * sizeof(numstack[0]));
+	var_or_num_t *numstackptr = numstack;
 	/* Stack of operator tokens */
 	operator *const stack = alloca(expr_len * sizeof(stack[0]));
 	operator *stackptr = stack;
 
-	*stackptr++ = lasttok = TOK_LPAREN;     /* start off with a left paren */
+	/* Start with a left paren */
+	*stackptr++ = lasttok = TOK_LPAREN;
 	errmsg = NULL;
 
 	while (1) {
@@ -521,7 +537,7 @@
 		arithval = *expr;
 		if (arithval == '\0') {
 			if (expr == start_expr) {
-				/* Null expression. */
+				/* Null expression */
 				numstack->val = 0;
 				goto ret;
 			}
@@ -558,7 +574,7 @@
 			safe_strncpy(numstackptr->var, expr, var_name_size);
 			expr = p;
  num:
-			numstackptr->contidional_second_val_initialized = 0;
+			numstackptr->second_val_present = 0;
 			numstackptr++;
 			lasttok = TOK_NUM;
 			continue;
@@ -577,21 +593,32 @@
 		/* Should be an operator */
 		p = op_tokens;
 		while (1) {
-			const char *e = expr;
+// TODO: bash allows 7+++v, treats it as 7 + ++v
+// we treat it as 7++ + v and reject
 			/* Compare expr to current op_tokens[] element */
-			while (*p && *e == *p)
-				p++, e++;
-			if (*p == '\0') { /* match: operator is found */
-				expr = e;
-				break;
+			const char *e = expr;
+			while (1) {
+				if (*p == '\0') {
+					/* Match: operator is found */
+					expr = e;
+					goto tok_found;
+				}
+				if (*p != *e)
+					break;
+				p++;
+				e++;
 			}
-			/* Go to next element of op_tokens[] */
+			/* No match, go to next element of op_tokens[] */
 			while (*p)
 				p++;
 			p += 2; /* skip NUL and TOK_foo bytes */
-			if (*p == '\0') /* no next element, operator not found */
+			if (*p == '\0') {
+				/* No next element, operator not found */
+				//math_state->syntax_error_at = expr;
 				goto err;
+			}
 		}
+ tok_found:
 		op = p[1]; /* fetch TOK_foo value */
 		/* NB: expr now points past the operator */
 
@@ -662,21 +689,21 @@
 				}
 				errmsg = arith_apply(math_state, prev_op, numstack, &numstackptr);
 				if (errmsg)
-					goto ret;
+					goto err_with_custom_msg;
 			}
-			if (op == TOK_RPAREN) {
+			if (op == TOK_RPAREN)
 				goto err;
-			}
 		}
 
-		/* Push this operator to the stack and remember it. */
+		/* Push this operator to the stack and remember it */
 		*stackptr++ = lasttok = op;
  next: ;
 	} /* while (1) */
 
  err:
-	numstack->val = -1;
 	errmsg = "arithmetic syntax error";
+ err_with_custom_msg:
+	numstack->val = -1;
  ret:
 	math_state->errmsg = errmsg;
 	return numstack->val;