[Mlir-commits] [mlir] 405d6ca - [mlir] Add pattern to handle trivial shape_cast in SPIR-V

Jerry Wu llvmlistbot at llvm.org
Mon Jun 26 16:10:51 PDT 2023


Author: Jerry Wu
Date: 2023-06-26T23:10:43Z
New Revision: 405d6cada4286bf58c2a2ee6858108a34e623b97

URL: https://github.com/llvm/llvm-project/commit/405d6cada4286bf58c2a2ee6858108a34e623b97
DIFF: https://github.com/llvm/llvm-project/commit/405d6cada4286bf58c2a2ee6858108a34e623b97.diff

LOG: [mlir] Add pattern to handle trivial shape_cast in SPIR-V

Handle the trivial cases of vector.shape_cast: casts whose converted source
and result types are identical, and casts of size-1 vectors.

Differential Revision: https://reviews.llvm.org/D153719

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2bc1e36855f1f..903441943f200 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -50,6 +50,29 @@ static int getNumBits(Type type) {
 
 namespace {
 
+/// Lowers vector.shape_cast ops that are no-ops once types are converted.
+struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
+    if (!dstType)
+      return failure();
+
+    // shape_cast preserves the element count, so a same-type cast or a cast of
+    // a single-element vector is a no-op: forward the converted source value.
+    if (dstType == adaptor.getSource().getType() ||
+        shapeCastOp.getResultVectorType().getNumElements() == 1) {
+      rewriter.replaceOp(shapeCastOp, adaptor.getSource());
+      return success();
+    }
+    return rewriter.notifyMatchFailure(
+        shapeCastOp, "size-n vectors with n > 1 are not supported");
+  }
+};
+
 struct VectorBitcastConvert final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -551,15 +574,17 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<
-      VectorBitcastConvert, VectorBroadcastConvert,
-      VectorExtractElementOpConvert, VectorExtractOpConvert,
-      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
-      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
-      VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
-                                                  patterns.getContext());
+  // VectorShapeCast only lowers same-type and single-element shape_casts.
+  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
+               VectorExtractElementOpConvert, VectorExtractOpConvert,
+               VectorExtractStridedSliceOpConvert,
+               VectorFmaOpConvert<spirv::GLFmaOp>,
+               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+               VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+               VectorReductionPattern<CL_MAX_MIN_OPS>,
+               VectorInsertStridedSliceOpConvert, VectorShapeCast,
+               VectorShuffleOpConvert, VectorSplatPattern>(
+      typeConverter, patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index bedd3d11e6f93..535da1328d34c 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -568,3 +568,28 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
   %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
   return %reduce : i32
 }
+
+// -----
+
+// A shape_cast between identical types is replaced by its operand. Return the
+// cast's result so the CHECK below actually observes that replacement.
+// CHECK-LABEL: @shape_cast_same_type
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>)
+//       CHECK:   return %[[ARG0]]
+func.func @shape_cast_same_type(%arg0 : vector<2xf32>) -> vector<2xf32> {
+  %1 = vector.shape_cast %arg0 : vector<2xf32> to vector<2xf32>
+  return %1 : vector<2xf32>
+}
+
+// -----
+
+// A single-element shape_cast is forwarded; only the scalar casts remain.
+// CHECK-LABEL: @shape_cast_size1_vector
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<f32>)
+//       CHECK:   %[[R0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<f32> to f32
+//       CHECK:   %[[R1:.+]] = builtin.unrealized_conversion_cast %[[R0]] : f32 to vector<1xf32>
+//       CHECK:   return %[[R1]]
+func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
+  %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+  return %1 : vector<1xf32>
+}


        


More information about the Mlir-commits mailing list