[Mlir-commits] [mlir] [mlir][spirv][vector] Support converting vector.from_elements to SPIR-V (PR #118540)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 3 12:31:08 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv
Author: Andrea Faulds (andfau-amd)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/118540.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+31-2)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+19)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 656b1cb3e99a1d..a2dbbab34c1db7 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -220,6 +220,34 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
}
};
+struct VectorFromElementsOpConvert final
+    : public OpConversionPattern<vector::FromElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Lowers vector.from_elements to SPIR-V, based on the converted result type:
+  ///  * scalar result (the single-element case, e.g. a 0-D vector<f32>):
+  ///    the op is replaced by its sole converted element;
+  ///  * 1-D vector result: the op becomes spirv.CompositeConstruct over the
+  ///    converted elements.
+  /// Fails to match if the result type cannot be converted or converts to
+  /// anything else.
+  LogicalResult
+  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType)
+      return failure();
+
+    // Take the operands from the adaptor: they carry the already-converted
+    // element types. The original op operands may still have source types
+    // (e.g. index) that the SPIR-V result type does not match.
+    ValueRange elements = adaptor.getElements();
+
+    if (isa<spirv::ScalarType>(resultType)) {
+      // Single scalar operand / single-element result: pass the scalar
+      // through unchanged.
+      rewriter.replaceOp(op, elements.front());
+      return success();
+    }
+
+    // SPIRVTypeConverter rejects vectors with rank > 1, so only a flat 1-D
+    // vector is expected here. dyn_cast makes an unexpected converted type
+    // fail the match instead of asserting.
+    auto vectorType = dyn_cast<VectorType>(resultType);
+    if (!vectorType || vectorType.getRank() != 1)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, vectorType,
+                                                             elements);
+    return success();
+  }
+};
+
struct VectorInsertOpConvert final
: public OpConversionPattern<vector::InsertOp> {
using OpConversionPattern::OpConversionPattern;
@@ -952,8 +980,9 @@ void mlir::populateVectorToSPIRVPatterns(
VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+ VectorInsertElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8796f153c4911b..f9dbe527af2c56 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -217,6 +217,25 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
// -----
+// A 0-D vector result converts to a scalar, so the element is passed through
+// and only a cast back to the vector type remains at the return.
+// CHECK-LABEL: @from_elements_0d
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: return %[[RETVAL]]
+func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
+  %0 = vector.from_elements %arg0 : vector<f32>
+  return %0: vector<f32>
+}
+
+// A single-element 1-D vector is also a single-element result, so it must
+// take the same scalar pass-through path rather than CompositeConstruct.
+// CHECK-LABEL: @from_elements_1x
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: return %[[RETVAL]]
+func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
+  %0 = vector.from_elements %arg0 : vector<1xf32>
+  return %0: vector<1xf32>
+}
+
+// A multi-element 1-D vector is built with spirv.CompositeConstruct.
+// CHECK-LABEL: @from_elements_1d
+// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
+// CHECK: spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
+func.func @from_elements_1d(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/118540
More information about the Mlir-commits mailing list