[Mlir-commits] [mlir] 599a91a - [mlir][spirv] Support `vector.step` in vector to spirv conversion (#100651)
llvmlistbot at llvm.org
Fri Jul 26 10:39:48 PDT 2024
Author: Angel Zhang
Date: 2024-07-26T13:39:44-04:00
New Revision: 599a91a7df6b75f93b91507e0caedd8dd1996641
URL: https://github.com/llvm/llvm-project/commit/599a91a7df6b75f93b91507e0caedd8dd1996641
DIFF: https://github.com/llvm/llvm-project/commit/599a91a7df6b75f93b91507e0caedd8dd1996641.diff
LOG: [mlir][spirv] Support `vector.step` in vector to spirv conversion (#100651)
Added a conversion pattern and LIT tests for lowering `vector.step` to
SPIR-V.
Fixes: #100602
---------
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..890706bf1bb2e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,43 @@ struct VectorReductionToFPDotProd final
}
};
+/// Lowers `vector.step` to SPIR-V. For an N-element result, this emits N
+/// scalar `spirv.Constant` ops holding 0, 1, ..., N-1 (in an integer type of
+/// the converter's index bitwidth) and combines them with
+/// `spirv.CompositeConstruct`. A single-element result, which the type
+/// converter maps to a scalar, is replaced by a lone zero constant.
+/// The pattern fails to match if the result type cannot be converted.
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ // Bail out when the result vector type has no SPIR-V equivalent.
+ Type dstType = typeConverter.convertType(stepOp.getType());
+ if (!dstType)
+ return failure();
+
+ Location loc = stepOp.getLoc();
+ // NOTE(review): assumes a fixed-size result vector; confirm that scalable
+ // vectors are rejected by the type converter before reaching this point.
+ int64_t numElements = stepOp.getType().getNumElements();
+ // Element constants use the converter's index bitwidth so they match the
+ // converted `index` element type (i32 in the accompanying tests).
+ auto intType =
+ rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+ // Result vectors of size 1 are converted to scalars by the type converter.
+ // We just create a zero constant in this case.
+ if (numElements == 1) {
+ Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+ rewriter.replaceOp(stepOp, zero);
+ return success();
+ }
+
+ // Materialize one scalar constant per lane: 0, 1, ..., numElements - 1.
+ SmallVector<Value> source;
+ source.reserve(numElements);
+ for (int64_t i = 0; i < numElements; ++i) {
+ Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+ Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+ source.push_back(constOp);
+ }
+ // Assemble the lane constants into the converted vector type.
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+ source);
+ return success();
+ }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +966,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+ VectorStepOpConvert>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..dd0ed77470a25 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,32 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
// -----
+// CHECK-LABEL: @step()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32
+// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32
+// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+// CHECK: return %[[CAST]] : vector<4xindex>
+// A 4-lane step: expects constants 0..3 assembled by a composite construct,
+// then cast back to the original index vector type.
+func.func @step() -> vector<4xindex> {
+ %0 = vector.step : vector<4xindex>
+ return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1()
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+// CHECK: return %[[CAST]] : vector<1xindex>
+// A single-lane step: the result type converts to a scalar, so the lowering
+// should produce just a zero constant.
+func.func @step_size1() -> vector<1xindex> {
+ %0 = vector.step : vector<1xindex>
+ return %0 : vector<1xindex>
+}
+
+// -----
+
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
More information about the Mlir-commits mailing list