[Mlir-commits] [mlir] [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (PR #146618)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 1 18:53:48 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Eric Feng (efric)
<details>
<summary>Changes</summary>
Implement `vector.to_elements` lowering to SPIR-V.
Addresses issue [#145929](https://github.com/llvm/llvm-project/issues/145929).
---
Full diff: https://github.com/llvm/llvm-project/pull/146618.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+46-1)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+35)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index de2af69eba9ec..475fd76c667e6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
}
};
+/// Lowers `vector.to_elements` to one `spirv.CompositeExtract` per live
+/// result. Single-element source vectors are turned into scalars by the type
+/// converter; in that case the converted (scalar) source is forwarded directly
+/// as the sole result.
+struct VectorToElementOpConvert final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The adaptor already holds the type-converted source operand, so no
+    // separate conversion of the source vector type is needed here.
+    Value source = adaptor.getSource();
+    SmallVector<Value> results(toElementsOp->getNumResults());
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // `spirv::CompositeExtractOp` cannot index into a scalar, so the single
+    // result is just the scalar itself.
+    if (isa<spirv::ScalarType>(source.getType())) {
+      results[0] = source;
+      rewriter.replaceOp(toElementsOp, results);
+      return success();
+    }
+
+    // Every result has the source vector's element type, so convert it once
+    // rather than once per extracted element.
+    Type srcElementType =
+        cast<VectorType>(toElementsOp.getSource().getType()).getElementType();
+    Type elementType = getTypeConverter()->convertType(srcElementType);
+    if (!elementType)
+      return rewriter.notifyMatchFailure(
+          toElementsOp, "failed to convert vector element type to SPIR-V");
+
+    Location loc = toElementsOp.getLoc();
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create a CompositeExtract operation only for results that are not
+      // dead; dead results are left as null values in `results`.
+      if (element.use_empty())
+        continue;
+
+      Value extracted = rewriter.create<spirv::CompositeExtractOp>(
+          loc, elementType, source,
+          rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
+      results[idx] = extracted;
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -1038,7 +1083,7 @@ void mlir::populateVectorToSPIRVPatterns(
VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, VectorToElementOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4701ac5d96009..99ab0e1dc4eef 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -246,6 +246,41 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
// -----
+// A single-element vector is converted to a scalar, so no extraction is
+// emitted: the result is the cast source itself.
+// CHECK-LABEL: func.func @to_elements_one_element
+//  CHECK-SAME:   %[[VEC:.*]]: vector<1xf32>)
+//       CHECK:   %[[SCALAR:.*]] = builtin.unrealized_conversion_cast %[[VEC]] : vector<1xf32> to f32
+//       CHECK:   return %[[SCALAR]] : f32
+func.func @to_elements_one_element(%vec: vector<1xf32>) -> f32 {
+  %elem = vector.to_elements %vec : vector<1xf32>
+  return %elem : f32
+}
+
+// Every result is used, so one CompositeExtract is expected per lane.
+// CHECK-LABEL: func.func @to_elements_no_dead_elements
+//  CHECK-SAME:   %[[VEC:.*]]: vector<4xf32>)
+//       CHECK:   %[[E0:.*]] = spirv.CompositeExtract %[[VEC]][0 : i32] : vector<4xf32>
+//       CHECK:   %[[E1:.*]] = spirv.CompositeExtract %[[VEC]][1 : i32] : vector<4xf32>
+//       CHECK:   %[[E2:.*]] = spirv.CompositeExtract %[[VEC]][2 : i32] : vector<4xf32>
+//       CHECK:   %[[E3:.*]] = spirv.CompositeExtract %[[VEC]][3 : i32] : vector<4xf32>
+//       CHECK:   return %[[E0]], %[[E1]], %[[E2]], %[[E3]] : f32, f32, f32, f32
+func.func @to_elements_no_dead_elements(%vec: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %elems:4 = vector.to_elements %vec : vector<4xf32>
+  return %elems#0, %elems#1, %elems#2, %elems#3 : f32, f32, f32, f32
+}
+
+// Only lanes 1 and 3 are used; lanes 0 and 2 must not be extracted.
+// CHECK-LABEL: func.func @to_elements_dead_elements
+//  CHECK-SAME:   %[[VEC:.*]]: vector<4xf32>)
+//   CHECK-NOT:   spirv.CompositeExtract %[[VEC]][0 : i32]
+//       CHECK:   %[[E1:.*]] = spirv.CompositeExtract %[[VEC]][1 : i32] : vector<4xf32>
+//   CHECK-NOT:   spirv.CompositeExtract %[[VEC]][2 : i32]
+//       CHECK:   %[[E3:.*]] = spirv.CompositeExtract %[[VEC]][3 : i32] : vector<4xf32>
+//       CHECK:   return %[[E1]], %[[E3]] : f32, f32
+func.func @to_elements_dead_elements(%vec: vector<4xf32>) -> (f32, f32) {
+  %elems:4 = vector.to_elements %vec : vector<4xf32>
+  return %elems#1, %elems#3 : f32, f32
+}
+
+// -----
+
// CHECK-LABEL: @from_elements_0d
// CHECK-SAME: %[[ARG0:.+]]: f32
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/146618
More information about the Mlir-commits mailing list