diff options
author | Nicolas Vasilache <ntv@google.com> | 2020-07-10 09:31:02 -0400 |
---|---|---|
committer | Nicolas Vasilache <ntv@google.com> | 2020-07-10 09:32:02 -0400 |
commit | 22c8a08fd8a1487159564f74f24561964f6a6c97 (patch) | |
tree | 30e29ae1543dbc0f87e33dfffdec9106dbb08408 | |
parent | 0555db0a5df4d669ce4c2125668ec7a8a42fcd9d (diff) | |
download | llvm-22c8a08fd8a1487159564f74f24561964f6a6c97.tar.gz |
[mlir][Vector] Fold chains of ExtractOp
This revision adds folding to ExtractOp by simply concatenating the position attributes.
-rw-r--r-- | mlir/lib/Dialect/Vector/VectorOps.cpp | 43 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 40 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 15 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/vector-contract-transforms.mlir | 155 |
4 files changed, 137 insertions, 116 deletions
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 019b73a0e00b..0aae97f24d20 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -571,16 +571,43 @@ static LogicalResult verify(vector::ExtractOp op) { return success(); } -static SmallVector<unsigned, 4> extractUnsignedVector(ArrayAttr arrayAttr) { +template <typename IntType> +static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>(llvm::map_range( arrayAttr.getAsRange<IntegerAttr>(), - [](IntegerAttr attr) { return static_cast<unsigned>(attr.getInt()); })); + [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); } -static Value foldExtractOp(ExtractOp extractOp) { +/// Fold the result of chains of ExtractOp in place by simply concatenating the +/// positions. +static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { + if (!extractOp.vector().getDefiningOp<ExtractOp>()) + return failure(); + + SmallVector<int64_t, 4> globalPosition; + ExtractOp currentOp = extractOp; + auto extractedPos = extractVector<int64_t>(currentOp.position()); + globalPosition.append(extractedPos.rbegin(), extractedPos.rend()); + while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) { + currentOp = nextOp; + auto extractedPos = extractVector<int64_t>(currentOp.position()); + globalPosition.append(extractedPos.rbegin(), extractedPos.rend()); + } + extractOp.setOperand(currentOp.vector()); + // OpBuilder is only used as a helper to build an I64ArrayAttr. + OpBuilder b(extractOp.getContext()); + std::reverse(globalPosition.begin(), globalPosition.end()); + extractOp.setAttr(ExtractOp::getPositionAttrName(), + b.getI64ArrayAttr(globalPosition)); + return success(); +} + +/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The +/// result is always the input to some InsertOp. +static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) { MLIRContext *context = extractOp.getContext(); AffineMap permutationMap; - auto extractedPos = extractUnsignedVector(extractOp.position()); + auto extractedPos = extractVector<unsigned>(extractOp.position()); // Walk back a chain of InsertOp/TransposeOp until we hit a match. // Compose TransposeOp permutations as we walk back. auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>(); @@ -588,7 +615,7 @@ static Value foldExtractOp(ExtractOp extractOp) { while (insertOp || transposeOp) { if (transposeOp) { // If it is transposed, compose the map and iterate. - auto permutation = extractUnsignedVector(transposeOp.transp()); + auto permutation = extractVector<unsigned>(transposeOp.transp()); AffineMap newMap = AffineMap::getPermutationMap(permutation, context); if (!permutationMap) permutationMap = newMap; @@ -610,7 +637,7 @@ static Value foldExtractOp(ExtractOp extractOp) { // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector` // produces a new vector with 1 modified value/slice in exactly the static // position we need to match. - auto insertedPos = extractUnsignedVector(insertOp.position()); + auto insertedPos = extractVector<unsigned>(insertOp.position()); // Trivial permutations are solved with position equality checks. if (!permutationMap || permutationMap.isIdentity()) { if (extractedPos == insertedPos) @@ -660,7 +687,9 @@ static Value foldExtractOp(ExtractOp extractOp) { } OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) { - if (auto val = foldExtractOp(*this)) + if (succeeded(foldExtractOpFromExtractChain(*this))) + return getResult(); + if (auto val = foldExtractOpFromInsertChainAndTranspose(*this)) return val; return OpFoldResult(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 2b2adf0cca64..09162aa0236b 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -172,29 +172,25 @@ func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> // CHECK-SAME: %[[A:.*]]: !llvm<"[4 x [1 x <2 x float>]]">) // CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]"> // CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm<"[1 x <2 x float>]"> -// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][1] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T3]], %[[T5]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][0, 0] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T2]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T2]], %[[T4]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T2]], %[[T5]][2] : !llvm<"[3 x <2 x float>]"> // CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T0]][0] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: %[[T8:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: %[[T9:.*]] = llvm.extractvalue %[[T8]][0] : !llvm<"[1 x <2 x float>]"> -// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T9]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][1] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T12:.*]] = llvm.insertvalue %[[T9]], %[[T11]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T8:.*]] = llvm.extractvalue %[[A]][1, 0] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T8]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T8]], %[[T10]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T12:.*]] = llvm.insertvalue %[[T8]], %[[T11]][2] : !llvm<"[3 x <2 x float>]"> // CHECK: %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T7]][1] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: %[[T14:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: %[[T15:.*]] = llvm.extractvalue %[[T14]][0] : !llvm<"[1 x <2 x float>]"> -// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T17:.*]] = llvm.insertvalue %[[T15]], %[[T16]][1] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T15]], %[[T17]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T14:.*]] = llvm.extractvalue %[[A]][2, 0] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T14]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T17:.*]] = llvm.insertvalue %[[T14]], %[[T16]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T14]], %[[T17]][2] : !llvm<"[3 x <2 x float>]"> // CHECK: %[[T19:.*]] = llvm.insertvalue %[[T18]], %[[T13]][2] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: %[[T20:.*]] = llvm.extractvalue %[[A]][3] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: %[[T21:.*]] = llvm.extractvalue %[[T20]][0] : !llvm<"[1 x <2 x float>]"> -// CHECK: %[[T22:.*]] = llvm.insertvalue %[[T21]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T23:.*]] = llvm.insertvalue %[[T21]], %[[T22]][1] : !llvm<"[3 x <2 x float>]"> -// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T21]], %[[T23]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T20:.*]] = llvm.extractvalue %[[A]][3, 0] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T22:.*]] = llvm.insertvalue %[[T20]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T23:.*]] = llvm.insertvalue %[[T20]], %[[T22]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T20]], %[[T23]][2] : !llvm<"[3 x <2 x float>]"> // CHECK: %[[T25:.*]] = llvm.insertvalue %[[T24]], %[[T19]][3] : !llvm<"[4 x [3 x <2 x float>]]"> // CHECK: llvm.return %[[T25]] : !llvm<"[4 x [3 x <2 x float>]]"> @@ -630,7 +626,7 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) - // CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]">) // CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]"> // CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]"> -// CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]"> +// CHECK: %[[s2:.*]] = llvm.extractvalue %[[B]][0, 0] : !llvm<"[16 x [4 x <8 x float>]]"> // CHECK: %[[s3:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK: %[[s4:.*]] = llvm.extractelement %[[s1]][%[[s3]] : !llvm.i64] : !llvm<"<4 x float>"> // CHECK: %[[s5:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 @@ -649,7 +645,7 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) - // CHECK: %[[s18:.*]] = llvm.insertelement %[[s16]], %[[s14]][%[[s17]] : !llvm.i64] : !llvm<"<8 x float>"> // CHECK: %[[s19:.*]] = llvm.insertvalue %[[s18]], %[[s0]][0] : !llvm<"[4 x <8 x float>]"> // CHECK: %[[s20:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[2 x <4 x float>]"> -// CHECK: %[[s21:.*]] = llvm.extractvalue %[[s0]][1] : !llvm<"[4 x <8 x float>]"> +// CHECK: %[[s21:.*]] = llvm.extractvalue %[[B]][0, 1] : !llvm<"[16 x [4 x <8 x float>]]"> // CHECK: %[[s22:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK: %[[s23:.*]] = llvm.extractelement %[[s20]][%[[s22]] : !llvm.i64] : !llvm<"<4 x float>"> // CHECK: %[[s24:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 7ba79528aee6..94f3f627e777 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -295,3 +295,18 @@ func @insert_extract_transpose_3d_2d( // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> } + +// ----- + +// CHECK-LABEL: fold_extracts +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32> +// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32> +// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32> +// CHECK-NEXT: return +func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) { + %b = vector.extract %a[0] : vector<3x4x5x6xf32> + %c = vector.extract %b[1, 2] : vector<4x5x6xf32> + %d = vector.extract %c[3] : vector<6xf32> + %e = vector.extract %a[0] : vector<3x4x5x6xf32> + return %d, %e : f32, vector<4x5x6xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir index 6a933d5e24b5..f6f215a50616 100644 --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -113,52 +113,46 @@ func @extract_contract3(%arg0: vector<3xf32>, // CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32> // CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> -// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T5]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[C]][0, 0] : vector<2x2xf32> // CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32> // CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32 // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> -// CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32> -// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> -// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32> -// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32> -// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32> +// +// CHECK: %[[T12:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32> +// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T15:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32> +// CHECK: %[[T17:.*]] = vector.insert %[[T15]], %[[T14]] [1] : f32 into vector<2xf32> +// CHECK: %[[T18:.*]] = vector.extract %[[C]][0, 1] : vector<2x2xf32> // CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32> // CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32 // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32> // CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32> +// // CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> -// CHECK: %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32> -// CHECK: %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> -// CHECK: %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32> -// CHECK: %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> -// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32> -// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32> -// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32> -// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T30]] : vector<2xf32> -// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32 +// CHECK: %[[T22b:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32> +// CHECK: %[[T24:.*]] = vector.insert %[[T22b]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T25:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32> +// CHECK: %[[T27:.*]] = vector.insert %[[T25]], %[[T24]] [1] : f32 into vector<2xf32> +// CHECK: %[[T28:.*]] = vector.extract %[[C]][1, 0] : vector<2x2xf32> +// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T27]] : vector<2xf32> +// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T28]] : vector<2xf32> into f32 // CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> -// CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32> -// CHECK: %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> -// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32> -// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32> -// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32> -// CHECK: %[[T42:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32> -// CHECK: %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32 -// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32> -// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32> -// CHECK: return %[[T45]] : vector<2x2xf32> +// +// CHECK: %[[T42:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32> +// CHECK: %[[T44:.*]] = vector.insert %[[T42]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T45:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32> +// CHECK: %[[T47:.*]] = vector.insert %[[T45]], %[[T44]] [1] : f32 into vector<2xf32> +// CHECK: %[[T48:.*]] = vector.extract %[[C]][1, 1] : vector<2x2xf32> +// CHECK: %[[T49:.*]] = mulf %[[T23]], %[[T47]] : vector<2xf32> +// CHECK: %[[T50:.*]] = vector.reduction "add", %[[T49]], %[[T48]] : vector<2xf32> into f32 +// +// CHECK: %[[T51:.*]] = vector.insert %[[T50]], %[[T34]] [1] : f32 into vector<2xf32> +// CHECK: %[[T52:.*]] = vector.insert %[[T51]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32> +// CHECK: return %[[T52]] : vector<2x2xf32> func @extract_contract4(%arg0: vector<2x2xf32>, %arg1: vector<2x2xf32>, @@ -216,27 +210,22 @@ func @full_contract1(%arg0: vector<2x3xf32>, // CHECK-SAME: %[[C:.*2]]: f32 // CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3x2xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[B]][1] : vector<3x2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : f32 into vector<3xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32> +// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> // CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32> // CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32 +// // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32> -// CHECK: %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32> -// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[Z]] [0] : f32 into vector<3xf32> -// CHECK: %[[T16:.*]] = vector.extract %[[B]][1] : vector<3x2xf32> -// CHECK: %[[T17:.*]] = vector.extract %[[T16]][1] : vector<2xf32> -// CHECK: %[[T18:.*]] = vector.insert %[[T17]], %[[T15]] [1] : f32 into vector<3xf32> -// CHECK: %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32> -// CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32> -// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32> +// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf +// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32> +// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32> +// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32> +// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> +// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> // CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32> // CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32 // CHECK: return %[[T23]] : f32 @@ -657,21 +646,17 @@ func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { // CHECK-LABEL: func @broadcast_stretch_at_end // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> // CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1xf32> -// CHECK: %[[T2:.*]] = splat %[[T1]] : vector<3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32> +// CHECK: %[[T2:.*]] = splat %[[T0]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<4x1xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<1xf32> -// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32> +// CHECK: %[[T6:.*]] = splat %[[T4]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][2] : vector<4x1xf32> -// CHECK: %[[T9:.*]] = vector.extract %[[T8]][0] : vector<1xf32> -// CHECK: %[[T10:.*]] = splat %[[T9]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32> +// CHECK: %[[T10:.*]] = splat %[[T8]] : vector<3xf32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T12:.*]] = vector.extract %[[A]][3] : vector<4x1xf32> -// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1xf32> -// CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32> +// CHECK: %[[T14:.*]] = splat %[[T12]] : vector<3xf32> // CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> // CHECK: return %[[T15]] : vector<4x3xf32> @@ -684,29 +669,25 @@ func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { // CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> // CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32> // CHECK: %[[C1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1x2xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[T1]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T1]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> // CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<4x1x2xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[T6]][0] : vector<1x2xf32> -// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T10:.*]] = vector.insert %[[T7]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32> +// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T12:.*]] = vector.extract %[[A]][2] : vector<4x1x2xf32> -// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1x2xf32> -// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T16:.*]] = vector.insert %[[T13]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32> +// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> // CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T18:.*]] = vector.extract %[[A]][3] : vector<4x1x2xf32> -// CHECK: %[[T19:.*]] = vector.extract %[[T18]][0] : vector<1x2xf32> -// CHECK: %[[T20:.*]] = vector.insert %[[T19]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T22:.*]] = vector.insert %[[T19]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32> +// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> // CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> // CHECK: return %[[T23]] : vector<4x3x2xf32> |