summaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorJakub Kuderski <kubak@google.com>2023-03-06 19:28:39 -0500
committerJakub Kuderski <kubak@google.com>2023-03-06 19:28:39 -0500
commitfbe91fe2cc3bd2c907e63f30db719204aaaf3973 (patch)
tree8070b06bcbf8657667abd6ffc63696c5037b3c0e /mlir/lib
parent260bae5ba27cde110590c28941966a6e02df5325 (diff)
downloadllvm-fbe91fe2cc3bd2c907e63f30db719204aaaf3973.tar.gz
[mlir][arith] Canonicalize `addi(x, muli(y, -1))` -> `subi(x, y)`
These propagate all the way down to SPIR-V and result in some fishy code with large constants. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D145423
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td21
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp4
2 files changed, 23 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index abf3db1728dc..7c687142247a 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -49,6 +49,27 @@ def AddISubConstantLHS :
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+def IsScalarOrSplatNegativeOne :
+ Constraint<And<[
+ CPred<"succeeded(getIntOrSplatIntValue($0))">,
+ CPred<"getIntOrSplatIntValue($0)->isAllOnes()">]>>;
+
+// addi(x, muli(y, -1)) -> subi(x, y)
+def AddIMulNegativeOneRhs :
+ Pat<(Arith_AddIOp
+ $x,
+ (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
+ (Arith_SubIOp $x, $y),
+ [(IsScalarOrSplatNegativeOne $c0)]>;
+
+// addi(muli(x, -1), y) -> subi(y, x)
+def AddIMulNegativeOneLhs :
+ Pat<(Arith_AddIOp
+ (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
+ $y),
+ (Arith_SubIOp $y, $x),
+ [(IsScalarOrSplatNegativeOne $c0)]>;
+
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f6308a6b000b..e56f4526291a 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -258,8 +258,8 @@ OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
- context);
+ patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
+ AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
}
//===----------------------------------------------------------------------===//