[Mlir-commits] [mlir] [mlir][spirv] Support `vector.step` in vector to spirv conversion (PR #100651)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 25 13:53:55 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Angel Zhang (angelz913)
<details>
<summary>Changes</summary>
Added a conversion pattern and LIT tests for lowering `vector.step` to SPIR-V. Related issue: #100602
---
Full diff: https://github.com/llvm/llvm-project/pull/100651.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+39-2)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+28)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
}
};
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Type dstType = typeConverter.convertType(stepOp.getType());
+ if (!dstType)
+ return failure();
+
+ Location loc = stepOp.getLoc();
+ int64_t numElements = stepOp.getType().getNumElements();
+ auto intType =
+ rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+ // Input vectors of size 1 are converted to scalars by the type converter.
+ // We just create a constant in this case.
+ if (numElements == 1) {
+ Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+ rewriter.replaceOp(stepOp, zero);
+ return success();
+ }
+
+ SmallVector<Value> source;
+ for (int64_t i = 0; i < numElements; ++i) {
+ Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+ Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+ source.push_back(constOp);
+ }
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+ source);
+ return success();
+ }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+ VectorStepOpConvert>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
// -----
+// CHECK-LABEL: @step
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32
+// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32
+// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+// CHECK: return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+ %0 = vector.step : vector<4xindex>
+ return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+// CHECK: return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+ %0 = vector.step : vector<1xindex>
+ return %0 : vector<1xindex>
+}
+
+// -----
+
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
``````````
</details>
https://github.com/llvm/llvm-project/pull/100651
More information about the Mlir-commits mailing list