From 71703a097859a24883aa32c3ee258647412c311e Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 15 May 2023 22:36:08 +0000 Subject: [mlir][spirv] Check type legality using converter for vectors This allows `index` vectors to be converted to SPIR-V. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D150616 --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 14 +++++++++----- .../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) (limited to 'mlir') diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 35171b3e077e..a4f20c610500 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -196,6 +196,12 @@ struct VectorInsertOpConvert final LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (isa(insertOp.getSourceType())) + return rewriter.notifyMatchFailure(insertOp, "unsupported vector source"); + if (!getTypeConverter()->convertType(insertOp.getDestVectorType())) + return rewriter.notifyMatchFailure(insertOp, + "unsupported dest vector type"); + // Special case for inserting scalar values into size-1 vectors. if (insertOp.getSourceType().isIntOrFloat() && insertOp.getDestVectorType().getNumElements() == 1) { @@ -203,9 +209,6 @@ struct VectorInsertOpConvert final return success(); } - if (isa(insertOp.getSourceType()) || - !spirv::CompositeType::isValid(insertOp.getDestVectorType())) - return failure(); int32_t id = getFirstIntValue(insertOp.getPosition()); rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id); @@ -413,9 +416,10 @@ struct VectorShuffleOpConvert final matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldResultType = shuffleOp.getResultVectorType(); - if (!spirv::CompositeType::isValid(oldResultType)) - return failure(); Type newResultType = getTypeConverter()->convertType(oldResultType); + if (!newResultType) + return rewriter.notifyMatchFailure(shuffleOp, + "unsupported result vector type"); auto oldSourceType = shuffleOp.getV1VectorType(); if (oldSourceType.getNumElements() > 1) { diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 26a2ab6d6243..bedd3d11e6f9 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -183,6 +183,15 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { // ----- +// CHECK-LABEL: @insert_index_vector +// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32> +func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> { + %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex> + return %1: vector<4xindex> +} + +// ----- + // CHECK-LABEL: @insert_size1_vector // CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32 // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]] @@ -402,6 +411,18 @@ func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> { // ----- +// CHECK-LABEL: func @shuffle_index_vector +// CHECK-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex> +// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] +// CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32> +func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> { + %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex> + return %shuffle : vector<4xindex> +} + +// ----- + // CHECK-LABEL: func @shuffle // CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32> // CHECK: spirv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32> -- cgit v1.2.1