summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Rong <PeterRong96@gmail.com>2023-01-12 10:58:38 -0800
committerPeter Rong <PeterRong96@gmail.com>2023-01-24 20:22:06 -0800
commit9b70a28e0d767f99bdc778356e81b4d072f59819 (patch)
treef211eef10e0e1de8c615dc4a89cdeefeec93e8ee
parentf9599bbc7a3f831e1793a549d8a7a19265f3e504 (diff)
downloadllvm-9b70a28e0d767f99bdc778356e81b4d072f59819.tar.gz
[Transform] Rewrite LowerSwitch using APInt
This rewrite fixes https://github.com/llvm/llvm-project/issues/59316. Previously LowerSwitch uses int64_t, which will crash on case branches using integers with more than 64 bits. Using APInt fixes this problem. This patch also includes a test Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D140747
-rw-r--r--llvm/lib/Transforms/Utils/LowerSwitch.cpp21
-rw-r--r--llvm/test/Transforms/LowerSwitch/pr59316.ll29
2 files changed, 40 insertions, 10 deletions
diff --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
index 26aebdfff640..227de425ff85 100644
--- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
@@ -370,7 +370,9 @@ void ProcessSwitchInst(SwitchInst *SI,
const unsigned NumSimpleCases = Clusterify(Cases, SI);
IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType());
const unsigned BitWidth = IT->getBitWidth();
- APInt SignedZero(BitWidth, 0);
+ // Explictly use higher precision to prevent unsigned overflow where
+ // `UnsignedMax - 0 + 1 == 0`
+ APInt UnsignedZero(BitWidth + 1, 0);
APInt UnsignedMax = APInt::getMaxValue(BitWidth);
LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
<< ". Total non-default cases: " << NumSimpleCases
@@ -431,7 +433,7 @@ void ProcessSwitchInst(SwitchInst *SI,
if (DefaultIsUnreachableFromSwitch) {
DenseMap<BasicBlock *, APInt> Popularity;
- APInt MaxPop(SignedZero);
+ APInt MaxPop(UnsignedZero);
BasicBlock *PopSucc = nullptr;
APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
@@ -457,11 +459,11 @@ void ProcessSwitchInst(SwitchInst *SI,
}
// Count popularity.
- APInt N = High - Low + 1;
- assert(N.sge(SignedZero) && "Popularity shouldn't be negative.");
+ assert(High.sge(Low) && "Popularity shouldn't be negative.");
+ APInt N = High.sext(BitWidth + 1) - Low.sext(BitWidth + 1) + 1;
// Explict insert to make sure the bitwidth of APInts match
- APInt &Pop = Popularity.insert({I.BB, APInt(SignedZero)}).first->second;
- if ((Pop += N).sgt(MaxPop)) {
+ APInt &Pop = Popularity.insert({I.BB, APInt(UnsignedZero)}).first->second;
+ if ((Pop += N).ugt(MaxPop)) {
MaxPop = Pop;
PopSucc = I.BB;
}
@@ -486,8 +488,6 @@ void ProcessSwitchInst(SwitchInst *SI,
// Use the most popular block as the new default, reducing the number of
// cases.
- assert(MaxPop.sgt(SignedZero) && PopSucc &&
- "Max populartion shouldn't be negative.");
Default = PopSucc;
llvm::erase_if(Cases,
[PopSucc](const CaseRange &R) { return R.BB == PopSucc; });
@@ -498,8 +498,9 @@ void ProcessSwitchInst(SwitchInst *SI,
SI->eraseFromParent();
// As all the cases have been replaced with a single branch, only keep
// one entry in the PHI nodes.
- for (APInt I(SignedZero); I.slt(MaxPop - 1); ++I)
- PopSucc->removePredecessor(OrigBlock);
+ if (!MaxPop.isZero())
+ for (APInt I(UnsignedZero); I.ult(MaxPop - 1); ++I)
+ PopSucc->removePredecessor(OrigBlock);
return;
}
diff --git a/llvm/test/Transforms/LowerSwitch/pr59316.ll b/llvm/test/Transforms/LowerSwitch/pr59316.ll
index 2e4226c71ea7..0616ace67296 100644
--- a/llvm/test/Transforms/LowerSwitch/pr59316.ll
+++ b/llvm/test/Transforms/LowerSwitch/pr59316.ll
@@ -62,3 +62,32 @@ BB:
BB1: ; preds = %BB
unreachable
}
+
+define void @f_i1() {
+entry:
+ switch i1 false, label %sw.bb [
+ i1 false, label %sw.bb12
+ ]
+
+sw.bb: ; preds = %entry
+ unreachable
+
+sw.bb12: ; preds = %entry
+ unreachable
+}
+
+define void @f_i2(i2 %cond) {
+entry:
+ switch i2 %cond, label %sw.bb [
+ i2 0, label %sw.bb12
+ i2 1, label %sw.bb12
+ i2 2, label %sw.bb12
+ i2 3, label %sw.bb12
+ ]
+
+sw.bb: ; preds = %entry
+ unreachable
+
+sw.bb12: ; preds = %entry
+ unreachable
+} \ No newline at end of file