[Mlir-commits] [mlir] [mlir][Vector] Lower `vector.to_elements` to LLVM (PR #145766)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 25 12:05:20 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
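Adds a conversion pattern that lowers `vector.to_elements` to LLVM. Only elements with at least one use are lowered to an `llvm.extractelement` op; unused results produce no instructions.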
---
Full diff: https://github.com/llvm/llvm-project/pull/145766.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+33-1)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+35-1)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d53d11f87efe8..f1543200fb56f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1985,6 +1985,37 @@ struct VectorFromElementsLowering
}
};
+/// Conversion pattern for a `vector.to_elements`.
+///
+/// Each result of the op that has at least one use is lowered to an
+/// `llvm.extractelement` of the converted source vector at a constant index.
+/// Results with no uses get no instructions and are replaced with a null
+/// value.
+///
+/// Only 0-D and 1-D source vectors are handled. 0-D vectors are converted to
+/// 1-element LLVM vectors, so index 0 is valid. n-D vectors (n > 1) are
+/// converted to arrays of 1-D vectors, on which `llvm.extractelement` is not
+/// valid, so the pattern reports a match failure for them.
+struct VectorToElementsLowering
+    : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sourceType = cast<VectorType>(toElementsOp.getSource().getType());
+
+    // `llvm.extractelement` needs a 1-D LLVM vector operand. Multi-dimensional
+    // vectors lower to `!llvm.array<... x vector<...>>` and must be unrolled
+    // (or handled like `vector.extract`) before this pattern can apply.
+    if (sourceType.getRank() > 1)
+      return rewriter.notifyMatchFailure(toElementsOp,
+                                         "rank > 1 vectors are not supported");
+
+    // Every result has the source element type, so convert it once. Bail out
+    // before any IR is created if it has no LLVM equivalent.
+    Type llvmElemType =
+        typeConverter->convertType(sourceType.getElementType());
+    if (!llvmElemType)
+      return rewriter.notifyMatchFailure(toElementsOp,
+                                         "unsupported vector element type");
+
+    Location loc = toElementsOp.getLoc();
+    Type idxType = typeConverter->convertType(rewriter.getIndexType());
+    Value source = adaptor.getSource();
+
+    // Entries for dead results stay null; nothing uses them.
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create an extractelement operation only for results that are not dead.
+      if (element.use_empty())
+        continue;
+
+      auto constIdx = rewriter.create<LLVM::ConstantOp>(
+          loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+      results[idx] = rewriter.create<LLVM::ExtractElementOp>(
+          loc, llvmElemType, source, constIdx);
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
+
/// Conversion pattern for vector.step.
struct VectorScalableStepOpLowering
: public ConvertOpToLLVMPattern<vector::StepOp> {
@@ -2035,7 +2066,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
- VectorScalableStepOpLowering>(converter);
+ VectorToElementsLowering, VectorScalableStepOpLowering>(
+ converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 3df14528bac39..8f73e79d7bfc2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
-// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
@@ -2421,6 +2421,40 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// -----
+// All four results of `vector.to_elements` are used, so each one must be
+// lowered to an index constant followed by an `llvm.extractelement`.
+// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ // CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// Results #0 and #2 are never used. No index constant or extract may be
+// emitted for them, while the used results #1 and #3 are still lowered.
+// CHECK-LABEL: func.func @vector_to_elements_dead_elements
+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+ // CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ // CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ // CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#1, %0#3 : f32, f32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.step
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/145766
More information about the Mlir-commits
mailing list