[Mlir-commits] [mlir] [mlir][spirv] Support `vector.step` in vector to spirv conversion (PR #100651)

Angel Zhang llvmlistbot at llvm.org
Thu Jul 25 13:53:23 PDT 2024


https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/100651

Added a conversion pattern and LIT tests for lowering `vector.step` to SPIR-V. Related issue: #100602 


>From 2b147c8c71e26fed679b978766ea34899a6ba0d4 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 25 Jul 2024 20:49:20 +0000
Subject: [PATCH] [mlir][spirv] Support vector.step in vector to spirv
 conversion

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 41 ++++++++++++++++++-
 .../VectorToSPIRV/vector-to-spirv.mlir        | 28 +++++++++++++
 2 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
   }
 };
 
+/// Lowers `vector.step` to `spirv.Constant` ops holding 0, 1, ..., N-1, which
+/// are combined by one `spirv.CompositeConstruct` when N > 1.
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = typeConverter.convertType(stepOp.getType());
+    if (!dstType)
+      return failure();
+
+    Location loc = stepOp.getLoc();
+    int64_t numElements = stepOp.getType().getNumElements();
+    auto intType =
+        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+    // The type converter turns size-1 vectors into scalars: emit a lone zero.
+    if (numElements == 1) {
+      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+      rewriter.replaceOp(stepOp, zero);
+      return success();
+    }
+
+    SmallVector<Value> source;
+    source.reserve(numElements); // Final size is known; skip regrowth.
+    for (int64_t i = 0; i < numElements; ++i)
+      source.push_back(rewriter.create<spirv::ConstantOp>(
+          loc, intType, rewriter.getIntegerAttr(intType, i)));
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorStepOpConvert>(typeConverter, patterns.getContext(),
+                           PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // -----
 
+// CHECK-LABEL: @step
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST1:.*]] = spirv.Constant 1 : i32
+//       CHECK:   %[[CST2:.*]] = spirv.Constant 2 : i32
+//       CHECK:   %[[CST3:.*]] = spirv.Constant 3 : i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex> // Checked above: lanes are 0, 1, 2, 3.
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+//       CHECK:   return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+  %0 = vector.step : vector<1xindex> // Checked above: one scalar i32 zero.
+  return %0 : vector<1xindex>
+}
+
+// -----
+
 module attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>



More information about the Mlir-commits mailing list