diff options
Diffstat (limited to 'libclc/clspv/lib/math/fma.cl')
-rw-r--r-- | libclc/clspv/lib/math/fma.cl | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/libclc/clspv/lib/math/fma.cl b/libclc/clspv/lib/math/fma.cl new file mode 100644 index 000000000000..fdc8b8b29687 --- /dev/null +++ b/libclc/clspv/lib/math/fma.cl @@ -0,0 +1,256 @@ +/* + * Copyright (c) 2014 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +// This version is derived from the generic fma software implementation +// (__clc_sw_fma), but avoids the use of ulong in favor of uint2. The logic has +// been updated as appropriate. + +#include <clc/clc.h> +#include "../../../generic/lib/clcmacro.h" +#include "../../../generic/lib/math/math.h" + +struct fp { + uint2 mantissa; + int exponent; + uint sign; +}; + +_CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) { + /* special cases */ + if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) { + return mad(a, b, c); + } + + /* If only c is inf, and both a,b are regular numbers, the result is c*/ + if (isinf(c)) { + return c; + } + + a = __clc_flush_denormal_if_not_supported(a); + b = __clc_flush_denormal_if_not_supported(b); + c = __clc_flush_denormal_if_not_supported(c); + + if (a == 0.0f || b == 0.0f) { + return c; + } + + if (c == 0) { + return a * b; + } + + struct fp st_a, st_b, st_c; + + st_a.exponent = a == .0f ? 0 : ((as_uint(a) & 0x7f800000) >> 23) - 127; + st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127; + st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127; + + st_a.mantissa.lo = a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000; + st_b.mantissa.lo = b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000; + st_c.mantissa.lo = c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000; + st_a.mantissa.hi = 0; + st_b.mantissa.hi = 0; + st_c.mantissa.hi = 0; + + st_a.sign = as_uint(a) & 0x80000000; + st_b.sign = as_uint(b) & 0x80000000; + st_c.sign = as_uint(c) & 0x80000000; + + // Multiplication. + // Move the product to the highest bits to maximize precision + // mantissa is 24 bits => product is 48 bits, 2bits non-fraction. + // Add one bit for future addition overflow, + // add another bit to detect subtraction underflow + struct fp st_mul; + st_mul.sign = st_a.sign ^ st_b.sign; + st_mul.mantissa.hi = mul_hi(st_a.mantissa.lo, st_b.mantissa.lo); + st_mul.mantissa.lo = st_a.mantissa.lo * st_b.mantissa.lo; + uint upper_14bits = (st_mul.mantissa.lo >> 18) & 0x3fff; + st_mul.mantissa.lo <<= 14; + st_mul.mantissa.hi <<= 14; + st_mul.mantissa.hi |= upper_14bits; + st_mul.exponent = (st_mul.mantissa.lo != 0 || st_mul.mantissa.hi != 0) + ? st_a.exponent + st_b.exponent + : 0; + +// Mantissa is 23 fractional bits, shift it the same way as product mantissa +#define C_ADJUST 37ul + + // both exponents are bias adjusted + int exp_diff = st_mul.exponent - st_c.exponent; + + uint abs_exp_diff = abs(exp_diff); + st_c.mantissa.hi = (st_c.mantissa.lo << 5); + st_c.mantissa.lo = 0; + uint2 cutoff_bits = (uint2)(0, 0); + uint2 cutoff_mask = (uint2)(0, 0); + if (abs_exp_diff < 32) { + cutoff_mask.lo = (1u << abs(exp_diff)) - 1u; + } else if (abs_exp_diff < 64) { + cutoff_mask.lo = 0xffffffff; + uint remaining = abs_exp_diff - 32; + cutoff_mask.hi = (1u << remaining) - 1u; + } else { + cutoff_mask = (uint2)(0, 0); + } + uint2 tmp = (exp_diff > 0) ? st_c.mantissa : st_mul.mantissa; + if (abs_exp_diff > 0) { + cutoff_bits = abs_exp_diff >= 64 ? tmp : (tmp & cutoff_mask); + if (abs_exp_diff < 32) { + // shift some of the hi bits into the shifted lo bits. + uint shift_mask = (1u << abs_exp_diff) - 1; + uint upper_saved_bits = tmp.hi & shift_mask; + upper_saved_bits = upper_saved_bits << (32 - abs_exp_diff); + tmp.hi >>= abs_exp_diff; + tmp.lo >>= abs_exp_diff; + tmp.lo |= upper_saved_bits; + } else if (abs_exp_diff < 64) { + tmp.lo = (tmp.hi >> (abs_exp_diff - 32)); + tmp.hi = 0; + } else { + tmp = (uint2)(0, 0); + } + } + if (exp_diff > 0) + st_c.mantissa = tmp; + else + st_mul.mantissa = tmp; + + struct fp st_fma; + st_fma.sign = st_mul.sign; + st_fma.exponent = max(st_mul.exponent, st_c.exponent); + st_fma.mantissa = (uint2)(0, 0); + if (st_c.sign == st_mul.sign) { + uint carry = (hadd(st_mul.mantissa.lo, st_c.mantissa.lo) >> 31) & 0x1; + st_fma.mantissa = st_mul.mantissa + st_c.mantissa; + st_fma.mantissa.hi += carry; + } else { + // cutoff bits borrow one + uint cutoff_borrow = ((cutoff_bits.lo != 0 || cutoff_bits.hi != 0) && + (st_mul.exponent > st_c.exponent)) + ? 1 + : 0; + uint borrow = 0; + if (st_c.mantissa.lo > st_mul.mantissa.lo) { + borrow = 1; + } else if (st_c.mantissa.lo == UINT_MAX && cutoff_borrow == 1) { + borrow = 1; + } else if ((st_c.mantissa.lo + cutoff_borrow) > st_mul.mantissa.lo) { + borrow = 1; + } + + st_fma.mantissa.lo = st_mul.mantissa.lo - st_c.mantissa.lo - cutoff_borrow; + st_fma.mantissa.hi = st_mul.mantissa.hi - st_c.mantissa.hi - borrow; + } + + // underflow: st_c.sign != st_mul.sign, and magnitude switches the sign + if (st_fma.mantissa.hi > INT_MAX) { + st_fma.mantissa = ~st_fma.mantissa; + uint carry = (hadd(st_fma.mantissa.lo, 1u) >> 31) & 0x1; + st_fma.mantissa.lo += 1; + st_fma.mantissa.hi += carry; + + st_fma.sign = st_mul.sign ^ 0x80000000; + } + + // detect overflow/underflow + uint leading_zeroes = clz(st_fma.mantissa.hi); + if (leading_zeroes == 32) { + leading_zeroes += clz(st_fma.mantissa.lo); + } + int overflow_bits = 3 - leading_zeroes; + + // adjust exponent + st_fma.exponent += overflow_bits; + + // handle underflow + if (overflow_bits < 0) { + uint shift = -overflow_bits; + if (shift < 32) { + uint shift_mask = (1u << shift) - 1; + uint saved_lo_bits = (st_fma.mantissa.lo >> (32 - shift)) & shift_mask; + st_fma.mantissa.lo <<= shift; + st_fma.mantissa.hi <<= shift; + st_fma.mantissa.hi |= saved_lo_bits; + } else if (shift < 64) { + st_fma.mantissa.hi = (st_fma.mantissa.lo << (64 - shift)); + st_fma.mantissa.lo = 0; + } else { + st_fma.mantissa = (uint2)(0, 0); + } + + overflow_bits = 0; + } + + // rounding + // overflow_bits is now in the range of [0, 3] making the shift greater than + // 32 bits. + uint2 trunc_mask; + uint trunc_shift = C_ADJUST + overflow_bits - 32; + trunc_mask.hi = (1u << trunc_shift) - 1; + trunc_mask.lo = UINT_MAX; + uint2 trunc_bits = st_fma.mantissa & trunc_mask; + trunc_bits.lo |= (cutoff_bits.hi != 0 || cutoff_bits.lo != 0) ? 1 : 0; + uint2 last_bit; + last_bit.lo = 0; + last_bit.hi = st_fma.mantissa.hi & (1u << trunc_shift); + uint grs_shift = C_ADJUST - 3 + overflow_bits - 32; + uint2 grs_bits; + grs_bits.lo = 0; + grs_bits.hi = 0x4u << grs_shift; + + // round to nearest even + if ((trunc_bits.hi > grs_bits.hi || + (trunc_bits.hi == grs_bits.hi && trunc_bits.lo > grs_bits.lo)) || + (trunc_bits.hi == grs_bits.hi && trunc_bits.lo == grs_bits.lo && + last_bit.hi != 0)) { + uint shift = C_ADJUST + overflow_bits - 32; + st_fma.mantissa.hi += 1u << shift; + } + + // Shift mantissa back to bit 23 + st_fma.mantissa.lo = (st_fma.mantissa.hi >> (C_ADJUST + overflow_bits - 32)); + st_fma.mantissa.hi = 0; + + // Detect rounding overflow + if (st_fma.mantissa.lo > 0xffffff) { + ++st_fma.exponent; + st_fma.mantissa.lo >>= 1; + } + + if (st_fma.mantissa.lo == 0) { + return 0.0f; + } + + // Flating point range limit + if (st_fma.exponent > 127) { + return as_float(as_uint(INFINITY) | st_fma.sign); + } + + // Flush denormals + if (st_fma.exponent <= -127) { + return as_float(st_fma.sign); + } + + return as_float(st_fma.sign | ((st_fma.exponent + 127) << 23) | + ((uint)st_fma.mantissa.lo & 0x7fffff)); +} +_CLC_TERNARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, float, fma, float, float, float) |