summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Raoux <thomasraoux@google.com>2021-11-18 16:09:49 -0800
committerThomas Raoux <thomasraoux@google.com>2021-11-19 10:25:21 -0800
commit06dbb2856967a5497c6ddfad3d3fdfea20849f7e (patch)
treeb9dd59a8503a2e3bce8a599bb3bb631c3e8a135d
parentffdace4892bd1f43121d365c22eb9c3fe79aeb6c (diff)
downloadllvm-06dbb2856967a5497c6ddfad3d3fdfea20849f7e.tar.gz
[mlir][vector] Remove usage of shapecast to remove unit dim
Instead of using shape_cast op in the pattern removing leading unit dimensions we use extract/broadcast ops. This is part of the effort to restrict ShapeCastOp further in the future and only allow it to convert to or from 1-D vectors. This also adds extra canonicalization to fill the gaps in simplifying broadcast/extract ops. Differential Revision: https://reviews.llvm.org/D114205
-rw-r--r--mlir/lib/Dialect/Vector/VectorOps.cpp83
-rw-r--r--mlir/lib/Dialect/Vector/VectorTransforms.cpp142
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir39
-rw-r--r--mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir23
-rw-r--r--mlir/test/Dialect/Vector/vector-transforms.mlir66
5 files changed, 161 insertions, 192 deletions
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 3ca8fa0dcf0b..4b67b39b2fdb 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1125,9 +1125,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
b.getI64ArrayAttr(extractPos));
return extractOp.getResult();
}
- // TODO: In case the rank of the broadcast source is greater than the rank of
- // the extract result this can be combined into a new broadcast op. This needs
- // to be added a canonicalization pattern if needed.
return Value();
}
@@ -1208,12 +1205,63 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
namespace {
+// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
+class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ Operation *defOp = extractOp.vector().getDefiningOp();
+ if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+ return failure();
+ Value source = defOp->getOperand(0);
+ if (extractOp.getType() == source.getType())
+ return failure();
+ auto getRank = [](Type type) {
+ return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+ };
+ unsigned broadcasrSrcRank = getRank(source.getType());
+ unsigned extractResultRank = getRank(extractOp.getType());
+ // We only consider the case where the rank of the source is smaller than
+ // the rank of the extract dst. The other cases are handled in the folding
+ // patterns.
+ if (extractResultRank <= broadcasrSrcRank)
+ return failure();
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ extractOp, extractOp.getType(), source);
+ return success();
+ }
+};
+
+// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
+class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+    // Return if 'extractOp's operand is not defined by a splat
+    // ConstantOp.
+ auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
+ if (!constantOp)
+ return failure();
+ auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+ if (!dense)
+ return failure();
+ Attribute newAttr = dense.getSplatValue<Attribute>();
+ if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
+ newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // ExtractToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
+ results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -1555,10 +1603,31 @@ static LogicalResult verify(InsertOp op) {
return success();
}
+namespace {
+
+// If insertOp is only inserting unit dimensions it can be transformed to a
+// broadcast.
+class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+ if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
+ srcVecType.getNumElements())
+ return failure();
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ insertOp, insertOp.getDestVectorType(), insertOp.source());
+ return success();
+ }
+};
+
+} // namespace
+
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // InsertToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
+ results.add<InsertToBroadcast, BroadcastFolder>(context);
}
// Eliminates insert operations that produce values identical to their source
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index cf40e4f27260..8a77f171c61b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2943,6 +2943,11 @@ static VectorType trimLeadingOneDims(VectorType oldType) {
return VectorType::get(newShape, oldType.getElementType());
}
+/// Return a smallVector of size `rank` containing all zeros.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+ return SmallVector<int64_t>(rank, 0);
+}
+
// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.shape_cast.
struct CastAwayExtractStridedSliceLeadingOneDim
@@ -2969,8 +2974,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
Location loc = extractOp.getLoc();
- Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
- loc, newSrcType, extractOp.vector());
+ Value newSrcVector = rewriter.create<vector::ExtractOp>(
+ loc, extractOp.vector(), splatZero(dropCount));
// The offsets/sizes/strides attribute can have a less number of elements
// than the input vector's rank: it is meant for the leading dimensions.
@@ -2984,7 +2989,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
newExtractOp);
return success();
@@ -3004,17 +3009,18 @@ struct CastAwayInsertStridedSliceLeadingOneDim
VectorType oldDstType = insertOp.getDestVectorType();
VectorType newDstType = trimLeadingOneDims(oldDstType);
- if (newSrcType.getRank() == oldSrcType.getRank() &&
- newDstType.getRank() == oldDstType.getRank())
+ int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
+ int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+ if (srcDropCount == 0 && dstDropCount == 0)
return failure();
// Trim leading one dimensions from both operands.
Location loc = insertOp.getLoc();
- Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
- loc, newSrcType, insertOp.source());
- Value newDstVector =
- rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
+ Value newSrcVector = rewriter.create<vector::ExtractOp>(
+ loc, insertOp.source(), splatZero(srcDropCount));
+ Value newDstVector = rewriter.create<vector::ExtractOp>(
+ loc, insertOp.dest(), splatZero(dstDropCount));
auto newOffsets = rewriter.getArrayAttr(
insertOp.offsets().getValue().take_back(newDstType.getRank()));
@@ -3024,7 +3030,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
return success();
@@ -3068,7 +3074,7 @@ struct CastAwayTransferReadLeadingOneDim
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), newType, read.source(), read.indices(), newMap,
read.padding(), inBounds);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
return success();
}
@@ -3092,9 +3098,9 @@ struct CastAwayTransferWriteLeadingOneDim
VectorType oldType = write.getVectorType();
VectorType newType = trimLeadingOneDims(oldType);
-
if (newType == oldType)
return failure();
+ int64_t dropDim = oldType.getRank() - newType.getRank();
AffineMap oldMap = write.permutation_map();
ArrayRef<AffineExpr> newResults =
@@ -3108,8 +3114,8 @@ struct CastAwayTransferWriteLeadingOneDim
inBounds = rewriter.getArrayAttr(
write.in_boundsAttr().getValue().take_back(newType.getRank()));
- auto newVector = rewriter.create<vector::ShapeCastOp>(
- write.getLoc(), newType, write.vector());
+ auto newVector = rewriter.create<vector::ExtractOp>(
+ write.getLoc(), write.vector(), splatZero(dropDim));
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.source(), write.indices(), newMap, inBounds);
@@ -3117,35 +3123,6 @@ struct CastAwayTransferWriteLeadingOneDim
}
};
-template <typename BroadCastType>
-struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern<BroadCastType> {
- using OpRewritePattern<BroadCastType>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BroadCastType broadcastOp,
- PatternRewriter &rewriter) const override {
- VectorType dstType =
- broadcastOp.getResult().getType().template dyn_cast<VectorType>();
- if (!dstType)
- return failure();
- VectorType newDstType = trimLeadingOneDims(dstType);
- if (newDstType == dstType)
- return failure();
- Location loc = broadcastOp.getLoc();
- Value source = broadcastOp->getOperand(0);
- VectorType srcVecType = source.getType().template dyn_cast<VectorType>();
- if (srcVecType)
- srcVecType = trimLeadingOneDims(srcVecType);
- if (srcVecType && srcVecType != source.getType()) {
- source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
- }
- Value newBroadcastOp =
- rewriter.create<BroadCastType>(loc, newDstType, source);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, dstType,
- newBroadcastOp);
- return success();
- }
-};
-
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
CastAwayElementwiseLeadingOneDim(MLIRContext *context)
@@ -3161,14 +3138,12 @@ public:
VectorType newVecType = trimLeadingOneDims(vecType);
if (newVecType == vecType)
return failure();
-
+ int64_t dropDim = vecType.getRank() - newVecType.getRank();
SmallVector<Value, 4> newOperands;
for (Value operand : op->getOperands()) {
if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
- auto newType =
- VectorType::get(newVecType.getShape(), opVecType.getElementType());
- newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
- op->getLoc(), newType, operand));
+ newOperands.push_back(rewriter.create<vector::ExtractOp>(
+ op->getLoc(), operand, splatZero(dropDim)));
} else {
newOperands.push_back(operand);
}
@@ -3178,69 +3153,12 @@ public:
state.addOperands(newOperands);
state.addTypes(newVecType);
Operation *newOp = rewriter.createOperation(state);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
newOp->getResult(0));
return success();
}
};
-// If extractOp is only removing unit dimensions it can be transformed to a
-// shapecast.
-class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern<ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
- if (!dstVecType || extractOp.getVectorType().getNumElements() !=
- dstVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
- extractOp.vector());
- return success();
- }
-};
-
-// If insertOp is only inserting unit dimensions it can be transformed to a
-// shapecast.
-class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
-public:
- using OpRewritePattern<InsertOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertOp insertOp,
- PatternRewriter &rewriter) const override {
- auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
- if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
- srcVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(
- insertOp, insertOp.getDestVectorType(), insertOp.source());
- return success();
- }
-};
-
-// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
-// the degenerated case where the broadcast only adds dimensions of size 1 it
-// can be replaced by a ShapeCastOp. This canonicalization checks if the total
-// number of elements is the same before and after the broadcast to detect if
-// the only change in the vector type are new dimensions of size 1.
-class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
-public:
- using OpRewritePattern<BroadcastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
- PatternRewriter &rewriter) const override {
- auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
- if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
- srcVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(
- broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
- return success();
- }
-};
-
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
@@ -3722,13 +3640,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
RewritePatternSet &patterns) {
- patterns.add<
- BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
- CastAwayInsertStridedSliceLeadingOneDim,
- CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
- CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
- CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
- ExtractToShapeCast, InsertToShapeCast>(patterns.getContext());
+ patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
+ CastAwayInsertStridedSliceLeadingOneDim,
+ CastAwayTransferReadLeadingOneDim,
+ CastAwayTransferWriteLeadingOneDim,
+ CastAwayElementwiseLeadingOneDim>(patterns.getContext());
populateShapeCastFoldingPatterns(patterns);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3d60745b8ccd..3557cebae0af 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -496,13 +496,10 @@ func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
// -----
-// Negative test for extract_op folding when the type of broadcast source
-// doesn't match the type of vector.extract.
-// CHECK-LABEL: fold_extract_broadcast_negative
-// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32>
-// CHECK: return %[[R]] : vector<4xf32>
-func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
+// CHECK-LABEL: fold_extract_broadcast
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[0, 1] : vector<1x2x4xf32>
return %r : vector<4xf32>
@@ -1058,3 +1055,31 @@ func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
vector<16x4xf16> to vector<2x4xf16>
return %1 : vector<2x4xf16>
}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1x4xf32>
+// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+ %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+ %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+ %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+ return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_constant
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func @extract_constant() -> (vector<7xf32>, i32) {
+ %cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
+ %cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
+ %0 = vector.extract %cst[2] : vector<29x7xf32>
+ %1 = vector.extract %cst_1[1, 4, 5] : vector<4x37x9xi32>
+ return %0, %1 : vector<7xf32>, i32
+}
diff --git a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
deleted file mode 100644
index 5dd44d38ccb7..000000000000
--- a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
-
-// CHECK-LABEL: broadcast_to_shapecast
-// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
-// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
-func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
- %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
- return %0 : vector<1x4x4xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_extract_to_shapecast
-// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
-// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
-// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
- %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
- %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
- %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
- return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 50ed05a4b956..96c35210d8fd 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -421,21 +421,21 @@ func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x1x8xf16>
}
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
- // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16>
+ // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x8x8xf16>
}
@@ -443,9 +443,10 @@ func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %a
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
// CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] : vector<1x1xf16> to vector<1x1x1xf16>
+ // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16>
+ // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
- // CHECK: return %[[CAST]]
+ // CHECK: return %[[B]]
return %0: vector<1x1x1xf16>
}
@@ -456,7 +457,7 @@ func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> v
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
%f0 = arith.constant 0. : f16
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<1x4xf16>
@@ -466,7 +467,7 @@ func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> v
func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0. : f16
- // CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
+ // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
return %0: vector<1x1xf16>
}
@@ -475,7 +476,7 @@ func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1
func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
@@ -485,54 +486,35 @@ func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %ar
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = arith.constant 0 : index
- // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+ // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
return
}
-// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
-func @cast_away_broadcast_leading_one_dims(
- %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
- (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
- // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
- %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
- // CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
- %1 = vector.broadcast %arg1 : f32 to vector<1x1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32>
- %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
- // CHECK: splat %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
- %3 = splat %arg1 : vector<1x1x4xf32>
- return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>
-}
-
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
func @cast_away_elementwise_leading_one_dims(
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
%arg3: vector<1x4xf32>, %arg4: i1) ->
(vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
- // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+ // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
+ // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
%0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
%1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
%2 = select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: select %arg4, %12, %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
%3 = select %arg4, %arg3, %arg2 : vector<1x4xf32>
return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}