summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNoah Goldstein <goldstein.w.n@gmail.com>2023-04-18 16:34:02 -0500
committerTom Stellard <tstellar@redhat.com>2023-05-01 21:26:28 -0700
commit93e555fbd794755fe113ef5351852728d33b6084 (patch)
tree0a16a56a4c834a362629235bcefa71c04a2ec9d3
parentff9dc9c4fb113de9619fad6a77f4888277579718 (diff)
downloadllvm-93e555fbd794755fe113ef5351852728d33b6084.tar.gz
[InstCombine] Fix buggy `(mul X, Y)` -> `(shl X, Log2(Y))` transform PR62175
Bug was because we recognized patterns like `(shl 4, Z)` as a power of 2 we could take Log2 of (`2 + Z`), but doing `(shl X, (2 + Z))` can cause a poison shift. https://alive2.llvm.org/ce/z/yuJm_k The fix is to verify that `Log2(Y)` will be a non-poisonous shift amount. We can do this with: `nsw` flag: - https://alive2.llvm.org/ce/z/yyyJBr - https://alive2.llvm.org/ce/z/YgubD_ `nuw` flag: - https://alive2.llvm.org/ce/z/-4mpyV - https://alive2.llvm.org/ce/z/a6ik6r Prove `Y != 0`: - https://alive2.llvm.org/ce/z/ced4su - https://alive2.llvm.org/ce/z/X-JJHb Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D148609
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp41
1 files changed, 27 insertions, 14 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 97f129e200de..dd1e8da2eb48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1121,7 +1121,7 @@ static const unsigned MaxDepth = 6;
// actual instructions, otherwise return a non-null dummy value. Return nullptr
// on failure.
static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
- bool DoFold) {
+ bool AssumeNonZero, bool DoFold) {
auto IfFold = [DoFold](function_ref<Value *()> Fn) {
if (!DoFold)
return reinterpret_cast<Value *>(-1);
@@ -1147,14 +1147,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// FIXME: Require one use?
Value *X, *Y;
if (match(Op, m_ZExt(m_Value(X))))
- if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
+ if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
// log2(X << Y) -> log2(X) + Y
// FIXME: Require one use unless X is 1?
- if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
- if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
- return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+ if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
+ auto *BO = cast<OverflowingBinaryOperator>(Op);
+ // nuw will be set if the `shl` is trivially non-zero.
+ if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
+ if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+ }
// log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
// FIXME: missed optimization: if one of the hands of select is/contains
@@ -1162,8 +1166,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// FIXME: can both hands contain undef?
// FIXME: Require one use?
if (SelectInst *SI = dyn_cast<SelectInst>(Op))
- if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
- if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
+ if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
+ AssumeNonZero, DoFold))
+ if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
+ AssumeNonZero, DoFold))
return IfFold([&]() {
return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
});
@@ -1171,13 +1177,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// log2(umin(X, Y)) -> umin(log2(X), log2(Y))
// log2(umax(X, Y)) -> umax(log2(X), log2(Y))
auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
- if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned())
- if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold))
- if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold))
+ if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
+ // Use AssumeNonZero as false here. Otherwise we can hit case where
+ // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
+ if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
+ if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
return IfFold([&]() {
- return Builder.CreateBinaryIntrinsic(
- MinMax->getIntrinsicID(), LogX, LogY);
+ return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
+ LogY);
});
+ }
return nullptr;
}
@@ -1297,8 +1308,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
}
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
- if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
- Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
+ if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
+ /*DoFold*/ false)) {
+ Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
+ /*AssumeNonZero*/ true, /*DoFold*/ true);
return replaceInstUsesWith(
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
}