diff options
author | Jakub Kuderski <kubak@google.com> | 2022-12-12 14:39:57 -0500 |
---|---|---|
committer | Jakub Kuderski <kubak@google.com> | 2022-12-12 14:39:58 -0500 |
commit | f39b47264edda30b90ea9cc435fc1434a2e06edc (patch) | |
tree | f9184cda0854a2c2277c946f5ecc31a80699140c | |
parent | 11b9c7943bad1915e3ba096b597af3d050048d53 (diff) | |
download | llvm-f39b47264edda30b90ea9cc435fc1434a2e06edc.tar.gz |
[mlir][arith][tosa] Use extended mul in 32-bit `tosa.apply_scale`
To not introduce 64-bit types that may be difficult to handle for some
targets.
Reviewed By: rsuderman, antiagainst
Differential Revision: https://reviews.llvm.org/D139777
-rw-r--r-- | mlir/lib/Conversion/TosaToArith/TosaToArith.cpp | 16 | ||||
-rw-r--r-- | mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir | 8 |
2 files changed, 5 insertions, 19 deletions
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index f188b1bbe5dc..fb0cf4f38d79 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -127,7 +127,6 @@ public: Type resultTy = op.getType(); Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); - Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy); Value value = op.getValue(); if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) { @@ -144,20 +143,13 @@ public: Value two32 = getConstantValue(loc, i32Ty, 2, rewriter); Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter); Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter); - Value thirtyTwo64 = getConstantValue(loc, i64Ty, 32, rewriter); // Compute the multiplication in 64-bits then select the high / low parts. - Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32); - Value multiplier64 = - rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32); - Value multiply64 = - rewriter.create<arith::MulIOp>(loc, value64, multiplier64); - // Grab out the high/low of the computation - Value high64 = - rewriter.create<arith::ShRUIOp>(loc, multiply64, thirtyTwo64); - Value high32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, high64); - Value low32 = rewriter.create<arith::MulIOp>(loc, value32, multiplier32); + auto value64 = + rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32); + Value low32 = value64.getLow(); + Value high32 = value64.getHigh(); // Determine the direction and amount to shift the high bits. Value shiftOver32 = rewriter.create<arith::CmpIOp>( diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir index 17a7ac3e76fd..7f99e38d7419 100644 --- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir +++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir @@ -21,15 +21,9 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) { // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32 // CHECK-DAG: %[[C30:.+]] = arith.constant 30 : i32 // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32 - // CHECK-DAG: %[[C32L:.+]] = arith.constant 32 : i64 // Compute the high-low values of the matmul in 64-bits. - // CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i32 to i64 - // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64 - // CHECK-DAG: %[[MUL64:.+]] = arith.muli %[[V64]], %[[M64]] - // CHECK-DAG: %[[HI64:.+]] = arith.shrui %[[MUL64]], %[[C32L]] - // CHECK-DAG: %[[HI:.+]] = arith.trunci %[[HI64]] : i64 to i32 - // CHECK-DAG: %[[LOW:.+]] = arith.muli %arg0, %arg1 + // CHECK-DAG: %[[LOW:.+]], %[[HI:.+]] = arith.mulsi_extended %arg0, %arg1 // Determine whether the high bits need to shift left or right and by how much. // CHECK-DAG: %[[OVER31:.+]] = arith.cmpi sge, %[[S32]], %[[C32]] |