[Mlir-commits] [mlir] [mlir][spirv] Support `vector.step` in vector to spirv conversion (PR #100651)
Angel Zhang
llvmlistbot at llvm.org
Thu Jul 25 13:53:23 PDT 2024
https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/100651
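Added a conversion pattern and lit tests for lowering `vector.step` to SPIR-V: a multi-element result becomes a `spirv.CompositeConstruct` of integer `spirv.Constant`s, and a single-element result becomes a single zero `spirv.Constant`. Related issue: #100602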
>From 2b147c8c71e26fed679b978766ea34899a6ba0d4 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 25 Jul 2024 20:49:20 +0000
Subject: [PATCH] [mlir][spirv] Support vector.step in vector to spirv
conversion
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 41 ++++++++++++++++++-
.../VectorToSPIRV/vector-to-spirv.mlir | 28 +++++++++++++
2 files changed, 67 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
}
};
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Lowers `vector.step` (the index sequence [0, 1, ..., N-1]) to SPIR-V.
+  ///
+  /// A multi-element result becomes a `spirv.CompositeConstruct` of N integer
+  /// `spirv.Constant`s; a single-element result, which the type converter maps
+  /// to a scalar, becomes a single zero `spirv.Constant`.
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = cast<VectorType>(stepOp.getType());
+    // The sequence is materialized from the static element count, which for a
+    // scalable vector is only the known minimum (vscale = 1). Lowering it that
+    // way would silently produce wrong values, so refuse explicitly rather
+    // than relying on the type converter to reject the type.
+    if (srcType.isScalable())
+      return rewriter.notifyMatchFailure(stepOp,
+                                         "scalable vectors are not supported");
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return rewriter.notifyMatchFailure(stepOp, "unsupported result type");
+
+    // Take the element type from the converted type so the constants always
+    // agree with the result type the converter actually produced (a scalar
+    // for single-element vectors, the SPIR-V vector element type otherwise).
+    auto dstVecType = dyn_cast<VectorType>(dstType);
+    Type elemType = dstVecType ? dstVecType.getElementType() : dstType;
+
+    Location loc = stepOp.getLoc();
+    int64_t numElements = srcType.getNumElements();
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We just create a constant in this case.
+    if (numElements == 1) {
+      Value zero = spirv::ConstantOp::getZero(elemType, loc, rewriter);
+      rewriter.replaceOp(stepOp, zero);
+      return success();
+    }
+
+    SmallVector<Value> source;
+    source.reserve(numElements);
+    for (int64_t i = 0; i < numElements; ++i) {
+      Attribute intAttr = rewriter.getIntegerAttr(elemType, i);
+      Value constOp =
+          rewriter.create<spirv::ConstantOp>(loc, elemType, intAttr);
+      source.push_back(constOp);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+                                                             source);
+    return success();
+  }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+ VectorStepOpConvert>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
// -----
+// Verifies that a 4-element `vector.step` lowers to four i32 `spirv.Constant`s
+// (0 through 3) combined by a single `spirv.CompositeConstruct`.
+// CHECK-LABEL: @step
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32
+// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32
+// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+// CHECK: return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+ %0 = vector.step : vector<4xindex>
+ return %0 : vector<4xindex>
+}
+
+// -----
+
+// Verifies the single-element case: the one-element vector result is
+// represented as a scalar i32 after conversion, so `vector.step` lowers to a
+// lone zero `spirv.Constant` with no composite construction.
+// CHECK-LABEL: @step_size1
+// CHECK-SAME: ()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+// CHECK: return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+ %0 = vector.step : vector<1xindex>
+ return %0 : vector<1xindex>
+}
+
+// -----
+
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
More information about the Mlir-commits
mailing list