diff options
-rw-r--r-- | psi/zarith.c | 70 |
1 files changed, 64 insertions, 6 deletions
diff --git a/psi/zarith.c b/psi/zarith.c index b664cd37e..0c94b100d 100644 --- a/psi/zarith.c +++ b/psi/zarith.c @@ -184,6 +184,64 @@ zdiv(i_ctx_t *i_ctx_p) return 0; } +/* +To detect 64bit x 64bit multiplication overflowing, consider +breaking the numbers down into 32bit chunks. + + abc = (a<<64) + (b<<32) + c + (where a is 0 or -1, and b and c are 32bit unsigned. + +Similarly: + + def = (d<<64) + (b<<32) + f + +Then: + + abc.def = ((a<<64) + (b<<32) + c) * ((d<<64) + (e<<32) + f) + = (a<<64).def + (d<<64).abc + (b<<32).(e<<32) + + (b<<32).f + (e<<32).c + cf + = (a.def + d.abc + b.e)<<64 + (b.f + e.c)<<32 + cf + +*/ + +int mul_64_64_overflowcheck(int64_t abc, int64_t def, int64_t *res) +{ + uint32_t b = (abc>>32); + uint32_t c = (uint32_t)abc; + uint32_t e = (def>>32); + uint32_t f = (uint32_t)def; + uint64_t low, mid, high, bf, ec; + + /* Low contribution */ + low = (uint64_t)c * (uint64_t)f; + /* Mid contributions */ + bf = (uint64_t)b * (uint64_t)f; + ec = (uint64_t)e * (uint64_t)c; + /* Top contribution */ + high = (uint64_t)b * (uint64_t)e; + if (abc < 0) + high -= def; + if (def < 0) + high -= abc; + /* How do we check for carries from 64bit unsigned adds? + * x + y >= (1<<64) == x >= (1<<64) - y + * == x > (1<<64) - y - 1 + * if we consider just 64bits, this is: + * x > NOT y + */ + if (bf > ~ec) + high += ((uint64_t)1)<<32; + mid = bf + ec; + if (low > ~(mid<<32)) + high += 1; + high += (mid>>32); + low += (mid<<32); + + *res = low; + + return (int64_t)low < 0 ? high != -1 : high != 0; +} + /* <num1> <num2> mul <product> */ int zmul(i_ctx_t *i_ctx_p) @@ -242,13 +300,13 @@ zmul(i_ctx_t *i_ctx_p) op[-1].value.intval = (ps_int)ab; } else { - double ab = (double)op[-1].value.intval * op->value.intval; - if (ab > (double)MAX_PS_INT) /* (double)0x7fffffffffffffff */ + int64_t result; + if (mul_64_64_overflowcheck(op[-1].value.intval, op->value.intval, &result)) { + double ab = (double)op[-1].value.intval * op->value.intval; make_real(op - 1, ab); - else if (ab < (double)MIN_PS_INT) /* (double)(int64_t)0x8000000000000000 */ - make_real(op - 1, ab); - else - op[-1].value.intval = (ps_int)ab; + } else { + op[-1].value.intval = result; + } } } } |