summary | refs | log | tree | commit | diff
path: root/mlir
diff options
context:
space:
mode:
author: Andrzej Warzynski <andrzej.warzynski@arm.com> 2023-05-12 17:35:53 +0100
committer: Andrzej Warzynski <andrzej.warzynski@gmail.com> 2023-05-16 08:45:38 +0100
commit: e4fd46b511b4a2729e5e1f792e741a88b4c5b935 (patch)
tree: 58c1a79348ad84856f671ff57ee90b7481296c60 /mlir
parent: 9d73a8bdc66496b673c11e991fd9cf0cba0a1bff (diff)
download: llvm-e4fd46b511b4a2729e5e1f792e741a88b4c5b935.tar.gz
[mlir][linalg] Add a test for linalg.matmul --> vector.outerproduct
Representing matmuls as a sum of outer products is central to various matrix extensions (e.g. Arm's SME). This test demonstrates how to use Linalg's vectoriser and Vector's lowerings to represent `linalg.matmul` as a chain of `vector.outerproduct` Ops.

Differential Revision: https://reviews.llvm.org/D150457
Diffstat (limited to 'mlir')
-rw-r--r-- mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir 36
1 file changed, 36 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
new file mode 100644
index 000000000000..910f019f1a58
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) {
+ linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>)
+ outs(%C: memref<3x3xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @outerproduct_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32>
+// CHECK: %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3x3xf32>
+// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3x3xf32>
+// CHECK: %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
+// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3x3xf32>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3x3xf32>
+// CHECK: %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
+// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3x3xf32>
+// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3x3xf32>
+// CHECK: %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
+// CHECK: vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32>
+// CHECK: return
+// CHECK: }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+ %2 = transform.structured.vectorize %1
+ transform.vector.lower_contraction %2 lowering_strategy = "outerproduct" : (!pdl.operation) -> !pdl.operation
+}