summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--psi/zarith.c70
1 files changed, 64 insertions, 6 deletions
diff --git a/psi/zarith.c b/psi/zarith.c
index b664cd37e..0c94b100d 100644
--- a/psi/zarith.c
+++ b/psi/zarith.c
@@ -184,6 +184,64 @@ zdiv(i_ctx_t *i_ctx_p)
return 0;
}
+/*
+To detect 64bit x 64bit multiplication overflowing, consider
+breaking the numbers down into 32bit chunks.
+
+ abc = (a<<64) + (b<<32) + c
+ (where a is 0 or -1, and b and c are 32bit unsigned.
+
+Similarly:
+
+ def = (d<<64) + (b<<32) + f
+
+Then:
+
+ abc.def = ((a<<64) + (b<<32) + c) * ((d<<64) + (e<<32) + f)
+ = (a<<64).def + (d<<64).abc + (b<<32).(e<<32) +
+ (b<<32).f + (e<<32).c + cf
+ = (a.def + d.abc + b.e)<<64 + (b.f + e.c)<<32 + cf
+
+*/
+
+int mul_64_64_overflowcheck(int64_t abc, int64_t def, int64_t *res)
+{
+ uint32_t b = (abc>>32);
+ uint32_t c = (uint32_t)abc;
+ uint32_t e = (def>>32);
+ uint32_t f = (uint32_t)def;
+ uint64_t low, mid, high, bf, ec;
+
+ /* Low contribution */
+ low = (uint64_t)c * (uint64_t)f;
+ /* Mid contributions */
+ bf = (uint64_t)b * (uint64_t)f;
+ ec = (uint64_t)e * (uint64_t)c;
+ /* Top contribution */
+ high = (uint64_t)b * (uint64_t)e;
+ if (abc < 0)
+ high -= def;
+ if (def < 0)
+ high -= abc;
+ /* How do we check for carries from 64bit unsigned adds?
+ * x + y >= (1<<64) == x >= (1<<64) - y
+ * == x > (1<<64) - y - 1
+ * if we consider just 64bits, this is:
+ * x > NOT y
+ */
+ if (bf > ~ec)
+ high += ((uint64_t)1)<<32;
+ mid = bf + ec;
+ if (low > ~(mid<<32))
+ high += 1;
+ high += (mid>>32);
+ low += (mid<<32);
+
+ *res = low;
+
+ return (int64_t)low < 0 ? high != -1 : high != 0;
+}
+
/* <num1> <num2> mul <product> */
int
zmul(i_ctx_t *i_ctx_p)
@@ -242,13 +300,13 @@ zmul(i_ctx_t *i_ctx_p)
op[-1].value.intval = (ps_int)ab;
}
else {
- double ab = (double)op[-1].value.intval * op->value.intval;
- if (ab > (double)MAX_PS_INT) /* (double)0x7fffffffffffffff */
+ int64_t result;
+ if (mul_64_64_overflowcheck(op[-1].value.intval, op->value.intval, &result)) {
+ double ab = (double)op[-1].value.intval * op->value.intval;
make_real(op - 1, ab);
- else if (ab < (double)MIN_PS_INT) /* (double)(int64_t)0x8000000000000000 */
- make_real(op - 1, ab);
- else
- op[-1].value.intval = (ps_int)ab;
+ } else {
+ op[-1].value.intval = result;
+ }
}
}
}