[Mlir-commits] [mlir] 71703a0 - [mlir][spirv] Check type legality using converter for vectors
Lei Zhang
llvmlistbot at llvm.org
Mon May 15 16:31:38 PDT 2023
Author: Lei Zhang
Date: 2023-05-15T23:29:37Z
New Revision: 71703a097859a24883aa32c3ee258647412c311e
URL: https://github.com/llvm/llvm-project/commit/71703a097859a24883aa32c3ee258647412c311e
DIFF: https://github.com/llvm/llvm-project/commit/71703a097859a24883aa32c3ee258647412c311e.diff
LOG: [mlir][spirv] Check type legality using converter for vectors
This allows `index` vectors to be converted to SPIR-V.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D150616
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 35171b3e077ee..a4f20c610500c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -196,6 +196,12 @@ struct VectorInsertOpConvert final
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (isa<VectorType>(insertOp.getSourceType()))
+ return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
+ if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
+ return rewriter.notifyMatchFailure(insertOp,
+ "unsupported dest vector type");
+
// Special case for inserting scalar values into size-1 vectors.
if (insertOp.getSourceType().isIntOrFloat() &&
insertOp.getDestVectorType().getNumElements() == 1) {
@@ -203,9 +209,6 @@ struct VectorInsertOpConvert final
return success();
}
- if (isa<VectorType>(insertOp.getSourceType()) ||
- !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
- return failure();
int32_t id = getFirstIntValue(insertOp.getPosition());
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.getSource(), adaptor.getDest(), id);
@@ -413,9 +416,10 @@ struct VectorShuffleOpConvert final
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldResultType = shuffleOp.getResultVectorType();
- if (!spirv::CompositeType::isValid(oldResultType))
- return failure();
Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(shuffleOp,
+ "unsupported result vector type");
auto oldSourceType = shuffleOp.getV1VectorType();
if (oldSourceType.getNumElements() > 1) {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 26a2ab6d62436..bedd3d11e6f93 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -183,6 +183,15 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: @insert_index_vector
+// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
+ %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
+ return %1: vector<4xindex>
+}
+
+// -----
+
// CHECK-LABEL: @insert_size1_vector
// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
@@ -402,6 +411,18 @@ func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @shuffle_index_vector
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+// CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {
+ %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
+ return %shuffle : vector<4xindex>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle
// CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32>
// CHECK: spirv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32>
More information about the Mlir-commits mailing list