[Mlir-commits] [mlir] [mlir][spirv] Support `vector.step` in vector to spirv conversion (PR #100651)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 25 13:53:55 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

<details>
<summary>Changes</summary>

Added a conversion pattern and LIT tests for lowering `vector.step` to SPIR-V. Related issue: #100602


---
Full diff: https://github.com/llvm/llvm-project/pull/100651.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+39-2) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
   }
 };
 
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = typeConverter.convertType(stepOp.getType());
+    if (!dstType)
+      return failure();
+
+    Location loc = stepOp.getLoc();
+    int64_t numElements = stepOp.getType().getNumElements();
+    auto intType =
+        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We just create a constant in this case.
+    if (numElements == 1) {
+      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+      rewriter.replaceOp(stepOp, zero);
+      return success();
+    }
+
+    SmallVector<Value> source;
+    for (int64_t i = 0; i < numElements; ++i) {
+      Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+      Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+      source.push_back(constOp);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorStepOpConvert>(typeConverter, patterns.getContext(),
+                           PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // -----
 
+// CHECK-LABEL: @step
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST1:.*]] = spirv.Constant 1 : i32
+//       CHECK:   %[[CST2:.*]] = spirv.Constant 2 : i32
+//       CHECK:   %[[CST3:.*]] = spirv.Constant 3 : i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+//       CHECK:   return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+  %0 = vector.step : vector<1xindex>
+  return %0 : vector<1xindex>
+}
+
+// -----
+
 module attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/100651


More information about the Mlir-commits mailing list