[Mlir-commits] [mlir] 71703a0 - [mlir][spirv] Check type legality using converter for vectors

Lei Zhang llvmlistbot at llvm.org
Mon May 15 16:31:38 PDT 2023


Author: Lei Zhang
Date: 2023-05-15T23:29:37Z
New Revision: 71703a097859a24883aa32c3ee258647412c311e

URL: https://github.com/llvm/llvm-project/commit/71703a097859a24883aa32c3ee258647412c311e
DIFF: https://github.com/llvm/llvm-project/commit/71703a097859a24883aa32c3ee258647412c311e.diff

LOG: [mlir][spirv] Check type legality using converter for vectors

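This allows `vector.insert` and `vector.shuffle` ops on `index` vectors to be converted to SPIR-V, by checking type legality with the type converter instead of `spirv::CompositeType::isValid`.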

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D150616

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 35171b3e077ee..a4f20c610500c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -196,6 +196,12 @@ struct VectorInsertOpConvert final
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (isa<VectorType>(insertOp.getSourceType()))
+      return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
+    if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "unsupported dest vector type");
+
     // Special case for inserting scalar values into size-1 vectors.
     if (insertOp.getSourceType().isIntOrFloat() &&
         insertOp.getDestVectorType().getNumElements() == 1) {
@@ -203,9 +209,6 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    if (isa<VectorType>(insertOp.getSourceType()) ||
-        !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
-      return failure();
     int32_t id = getFirstIntValue(insertOp.getPosition());
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
         insertOp, adaptor.getSource(), adaptor.getDest(), id);
@@ -413,9 +416,10 @@ struct VectorShuffleOpConvert final
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto oldResultType = shuffleOp.getResultVectorType();
-    if (!spirv::CompositeType::isValid(oldResultType))
-      return failure();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(shuffleOp,
+                                         "unsupported result vector type");
 
     auto oldSourceType = shuffleOp.getV1VectorType();
     if (oldSourceType.getNumElements() > 1) {

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 26a2ab6d62436..bedd3d11e6f93 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -183,6 +183,15 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: @insert_index_vector
+//       CHECK:   spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
+  %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
+  return %1: vector<4xindex>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_size1_vector
 //  CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
 //       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
@@ -402,6 +411,18 @@ func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL:  func @shuffle_index_vector
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+//       CHECK:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+//       CHECK:    spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
+  return %shuffle : vector<4xindex>
+}
+
+// -----
+
 // CHECK-LABEL:  func @shuffle
 //  CHECK-SAME:  %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32>
 //       CHECK:    spirv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32>


        


More information about the Mlir-commits mailing list