[Mlir-commits] [mlir] 274152c - [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (#146618)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 2 12:56:22 PDT 2025
Author: Eric Feng
Date: 2025-07-02T15:56:18-04:00
New Revision: 274152c5fa9f642d5ce6317ca24c0f2f27a53576
URL: https://github.com/llvm/llvm-project/commit/274152c5fa9f642d5ce6317ca24c0f2f27a53576
DIFF: https://github.com/llvm/llvm-project/commit/274152c5fa9f642d5ce6317ca24c0f2f27a53576.diff
LOG: [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (#146618)
Implement `vector.to_elements` lowering to SPIR-V.
Fixes: https://github.com/llvm/llvm-project/issues/145929
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index de2af69eba9ec..21d8e1d9f1156 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
}
};
+/// Lowers `vector.to_elements` to one `spirv.CompositeExtract` per live
+/// result. Results that have no uses get no extract op; their replacement
+/// value is left null. When the type converter has already turned the source
+/// into a SPIR-V scalar (e.g. a 1-element vector), that scalar is forwarded
+/// directly as the single result.
+struct VectorToElementOpConvert final
+    : OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value source = adaptor.getSource();
+    // One slot per op result; slots of dead results stay null.
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    Location loc = toElementsOp.getLoc();
+
+    // Input vectors of size 1 are converted to scalars by the type converter,
+    // and `spirv.CompositeExtract` cannot index into a scalar. In that case
+    // the scalar itself is the (only) result.
+    if (isa<spirv::ScalarType>(source.getType())) {
+      results[0] = source;
+      rewriter.replaceOp(toElementsOp, results);
+      return success();
+    }
+
+    // All results share the source vector's element type.
+    Type srcElementType = toElementsOp.getElements().getType().front();
+    Type elementType = getTypeConverter()->convertType(srcElementType);
+    if (!elementType)
+      return rewriter.notifyMatchFailure(
+          toElementsOp,
+          llvm::formatv("failed to convert element type '{0}' to SPIR-V",
+                        srcElementType));
+
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create a CompositeExtract operation only for results that are not
+      // dead; dead results keep their null replacement value.
+      if (element.use_empty())
+        continue;
+
+      Value extracted = rewriter.create<spirv::CompositeExtractOp>(
+          loc, elementType, source,
+          rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
+      results[idx] = extracted;
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
} // namespace
// Integer max/min ops of the SPIR-V CL op family (the UMax/UMin and SMax/SMin
// variants), expanded as the op list of VectorReductionPattern<...> in the
// pattern registration below.
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -1039,8 +1084,8 @@ void mlir::populateVectorToSPIRVPatterns(
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4701ac5d96009..99ab0e1dc4eef 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -246,6 +246,41 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
// -----
+// A 1-element vector is turned into a scalar by the type converter, so the
+// lone result is just the (cast) source value and no spirv.CompositeExtract
+// is created.
+// CHECK-LABEL: func.func @to_elements_one_element
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>)
+// CHECK: %[[ELEM0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1xf32> to f32
+// CHECK: return %[[ELEM0]] : f32
+func.func @to_elements_one_element(%a: vector<1xf32>) -> (f32) {
+ %0:1 = vector.to_elements %a : vector<1xf32>
+ return %0#0 : f32
+}
+
+// Every element is returned, so one spirv.CompositeExtract is expected per
+// element, in index order.
+// CHECK-LABEL: func.func @to_elements_no_dead_elements
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+// CHECK: %[[ELEM0:.*]] = spirv.CompositeExtract %[[A]][0 : i32] : vector<4xf32>
+// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+// CHECK: %[[ELEM2:.*]] = spirv.CompositeExtract %[[A]][2 : i32] : vector<4xf32>
+// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// Only elements 1 and 3 are used; no extract may be emitted for the unused
+// elements 0 and 2.
+// CHECK-LABEL: func.func @to_elements_dead_elements
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+// CHECK-NOT: spirv.CompositeExtract %[[A]][0 : i32]
+// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+// CHECK-NOT: spirv.CompositeExtract %[[A]][2 : i32]
+// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#1, %0#3 : f32, f32
+}
+
+// -----
+
// CHECK-LABEL: @from_elements_0d
// CHECK-SAME: %[[ARG0:.+]]: f32
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
More information about the Mlir-commits
mailing list