[Mlir-commits] [mlir] ff81cc8 - [mlir][spirv] Improve vector extract/insert element conversion

Lei Zhang llvmlistbot at llvm.org
Wed Nov 30 16:41:06 PST 2022


Author: Lei Zhang
Date: 2022-12-01T00:35:41Z
New Revision: ff81cc824f2867e8b9f8504549fb2592faea4567

URL: https://github.com/llvm/llvm-project/commit/ff81cc824f2867e8b9f8504549fb2592faea4567
DIFF: https://github.com/llvm/llvm-project/commit/ff81cc824f2867e8b9f8504549fb2592faea4567.diff

LOG: [mlir][spirv] Improve vector extract/insert element conversion

* Fix type conversions around positions: use the converted value from
  the adaptor instead of the original operand.
* Lower ops with a constant position to spirv.CompositeExtract /
  spirv.CompositeInsert instead of the dynamic variants.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D139057

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 38696d062fbe5..11608db409f71 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -191,19 +192,29 @@ struct VectorExtractElementOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type vectorType =
-        getTypeConverter()->convertType(adaptor.getVector().getType());
-    if (!vectorType)
+    Type resultType = getTypeConverter()->convertType(extractOp.getType());
+    if (!resultType)
       return failure();
 
-    if (vectorType.isa<spirv::ScalarType>()) {
+    if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
       rewriter.replaceOp(extractOp, adaptor.getVector());
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-        extractOp, extractOp.getType(), adaptor.getVector(),
-        extractOp.getPosition());
+    // A constant, in-bounds position becomes a static spirv.CompositeExtract.
+    // Out-of-bounds constants would fail CompositeExtract verification, so
+    // they take the dynamic path like any other position.
+    int64_t numElements =
+        extractOp.getVector().getType().cast<VectorType>().getNumElements();
+    APInt cstPos;
+    if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)) &&
+        cstPos.getSExtValue() >= 0 && cstPos.getSExtValue() < numElements)
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, resultType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+          extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
     return success();
   }
 };
@@ -224,9 +229,22 @@ struct VectorInsertElementOpConvert final
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-        insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
-        insertOp.getPosition());
+    // A constant, in-bounds position becomes a static spirv.CompositeInsert;
+    // out-of-bounds constants would fail CompositeInsert verification, so they
+    // take the dynamic path like any other position. All operands come from
+    // the adaptor so they carry the converted types.
+    int64_t numElements =
+        insertOp.getDest().getType().cast<VectorType>().getNumElements();
+    APInt cstPos;
+    if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)) &&
+        cstPos.getSExtValue() >= 0 && cstPos.getSExtValue() < numElements)
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(),
+          cstPos.getSExtValue());
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+          insertOp, vectorType, adaptor.getDest(), adaptor.getSource(),
+          adaptor.getPosition());
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a1f05e9a620c9..4bb1835e1d928 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -187,9 +187,20 @@ func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_element_cst
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 {
+  %idx = arith.constant 1 : i32  // Constant position: expect CompositeExtract.
+  %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element_index
 func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
-  // CHECK: vector.extractelement
+  // CHECK: spirv.VectorExtractDynamic
   %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
   return %0: f32
 }
@@ -249,9 +260,20 @@ func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector
 
 // -----
 
+// CHECK-LABEL: @insert_element_cst
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+  %idx = arith.constant 2 : i32  // Constant position: expect CompositeInsert.
+  %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_element_index
 func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
-  // CHECK: vector.insertelement
+  // CHECK: spirv.VectorInsertDynamic
   %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
   return %0: vector<4xf32>
 }


        


More information about the Mlir-commits mailing list