summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2023-05-15 22:36:08 +0000
committerLei Zhang <antiagainst@google.com>2023-05-15 23:29:37 +0000
commit71703a097859a24883aa32c3ee258647412c311e (patch)
tree7599097e6ec387ed3f10e22153d08cccea29d3e7 /mlir
parentf649599ea93301bd0d0a2b8e450d1f77425ea92e (diff)
downloadllvm-71703a097859a24883aa32c3ee258647412c311e.tar.gz
[mlir][spirv] Check type legality using converter for vectors
This allows `index` vectors to be converted to SPIR-V. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D150616
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp14
-rw-r--r--mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir21
2 files changed, 30 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 35171b3e077e..a4f20c610500 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -196,6 +196,12 @@ struct VectorInsertOpConvert final
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (isa<VectorType>(insertOp.getSourceType()))
+ return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
+ if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
+ return rewriter.notifyMatchFailure(insertOp,
+ "unsupported dest vector type");
+
// Special case for inserting scalar values into size-1 vectors.
if (insertOp.getSourceType().isIntOrFloat() &&
insertOp.getDestVectorType().getNumElements() == 1) {
@@ -203,9 +209,6 @@ struct VectorInsertOpConvert final
return success();
}
- if (isa<VectorType>(insertOp.getSourceType()) ||
- !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
- return failure();
int32_t id = getFirstIntValue(insertOp.getPosition());
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.getSource(), adaptor.getDest(), id);
@@ -413,9 +416,10 @@ struct VectorShuffleOpConvert final
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldResultType = shuffleOp.getResultVectorType();
- if (!spirv::CompositeType::isValid(oldResultType))
- return failure();
Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(shuffleOp,
+ "unsupported result vector type");
auto oldSourceType = shuffleOp.getV1VectorType();
if (oldSourceType.getNumElements() > 1) {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 26a2ab6d6243..bedd3d11e6f9 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -183,6 +183,15 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: @insert_index_vector
+// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
+ %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
+ return %1: vector<4xindex>
+}
+
+// -----
+
// CHECK-LABEL: @insert_size1_vector
// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
@@ -402,6 +411,18 @@ func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @shuffle_index_vector
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+// CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {
+ %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
+ return %shuffle : vector<4xindex>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle
// CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32>
// CHECK: spirv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32>