[Mlir-commits] [mlir] [mlir][Vector] Lower `vector.to_elements` to LLVM (PR #145766)

Diego Caballero llvmlistbot at llvm.org
Thu Jun 26 10:25:51 PDT 2025


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/145766

>From 8ee9261d10ea50c1b0606d6ba7ba0232cc525dc3 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 25 Jun 2025 18:57:15 +0000
Subject: [PATCH 1/2] [mlir][Vector] Add `vector.to_elements` lowering to LLVM

Only elements with at least one use are lowered to an `llvm.extractelement` op.
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 34 +++++++++++++++++-
 .../vector-to-llvm-interface.mlir             | 36 ++++++++++++++++++-
 2 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d53d11f87efe8..f1543200fb56f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1985,6 +1985,48 @@ struct VectorFromElementsLowering
   }
 };
 
+/// Conversion pattern for a `vector.to_elements`.
+///
+/// Each result that has at least one use is lowered to an
+/// `llvm.extractelement` at a constant index; dead results are not
+/// materialized. Only 0-D and 1-D source vectors are supported.
+struct VectorToElementsLowering
+    : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = toElementsOp.getLoc();
+    auto idxType = typeConverter->convertType(rewriter.getIndexType());
+    Value source = adaptor.getSource();
+
+    // The LLVM type converter turns an n-D vector (n > 1) into an
+    // `!llvm.array` of 1-D vectors, which `llvm.extractelement` cannot index.
+    auto vecType = cast<VectorType>(toElementsOp.getSource().getType());
+    if (vecType.getRank() > 1)
+      return rewriter.notifyMatchFailure(toElementsOp,
+                                         "expected 0-D or 1-D source vector");
+
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create an extractelement operation only for results that are not dead.
+      if (!element.use_empty()) {
+        auto constIdx = rewriter.create<LLVM::ConstantOp>(
+            loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+        auto llvmType = typeConverter->convertType(element.getType());
+
+        Value result = rewriter.create<LLVM::ExtractElementOp>(
+            loc, llvmType, source, constIdx);
+        results[idx] = result;
+      }
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
 /// Conversion pattern for vector.step.
 struct VectorScalableStepOpLowering
     : public ConvertOpToLLVMPattern<vector::StepOp> {
@@ -2035,7 +2066,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
-               VectorScalableStepOpLowering>(converter);
+               VectorToElementsLowering, VectorScalableStepOpLowering>(
+      converter);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 3df14528bac39..8f73e79d7bfc2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
 // CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
-// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
 // CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
 // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
@@ -2421,6 +2421,41 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
+// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
+ // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+ //      CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ //      CHECK:   %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ //      CHECK:   %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ //      CHECK:   %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ //      CHECK:   return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_to_elements_dead_elements
+ // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+ //  CHECK-NOT:   llvm.mlir.constant(0 : i64) : i64
+ //      CHECK:   %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ //      CHECK:   %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ // Results #0 and #2 are unused, so no index constant (and no extractelement) is emitted for them.
+ //  CHECK-NOT:   llvm.mlir.constant(2 : i64) : i64
+ //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ //      CHECK:   return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#1, %0#3 : f32, f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.step
 //===----------------------------------------------------------------------===//

>From cc50307de44b2e700436471857c11c132e7c89d4 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Thu, 26 Jun 2025 17:25:32 +0000
Subject: [PATCH 2/2] Address review feedback

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 19 ++++++++++---------
 .../vector-to-llvm-interface.mlir             | 12 ++++++++----
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f1543200fb56f..501d98862672d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2000,15 +2000,16 @@ struct VectorToElementsLowering
     SmallVector<Value> results(toElementsOp->getNumResults());
     for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
       // Create an extractelement operation only for results that are not dead.
-      if (!element.use_empty()) {
-        auto constIdx = rewriter.create<LLVM::ConstantOp>(
-            loc, idxType, rewriter.getIntegerAttr(idxType, idx));
-        auto llvmType = typeConverter->convertType(element.getType());
-
-        Value result = rewriter.create<LLVM::ExtractElementOp>(
-            loc, llvmType, source, constIdx);
-        results[idx] = result;
-      }
+      if (element.use_empty())
+        continue;
+
+      auto constIdx = rewriter.create<LLVM::ConstantOp>(
+          loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+      auto llvmType = typeConverter->convertType(element.getType());
+
+      Value result = rewriter.create<LLVM::ExtractElementOp>(loc, llvmType,
+                                                             source, constIdx);
+      results[idx] = result;
     }
 
     rewriter.replaceOp(toElementsOp, results);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 8f73e79d7bfc2..c03d67fdc33fa 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2421,7 +2421,11 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
-// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @to_elements_no_dead_elements
  // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
  //      CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
  //      CHECK:   %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
@@ -2432,14 +2436,14 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
  //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
  //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
  //      CHECK:   return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
-func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
   %0:4 = vector.to_elements %a : vector<4xf32>
   return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
 }
 
 // -----
 
-// CHECK-LABEL: func.func @vector_to_elements_dead_elements
+// CHECK-LABEL: func.func @to_elements_dead_elements
  // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
  //  CHECK-NOT:   llvm.mlir.constant(0 : i64) : i64
  //      CHECK:   %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
@@ -2448,7 +2452,7 @@ func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32,
  //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
  //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
  //      CHECK:   return %[[ELEM1]], %[[ELEM3]] : f32, f32
-func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
   %0:4 = vector.to_elements %a : vector<4xf32>
   return %0#1, %0#3 : f32, f32
 }



More information about the Mlir-commits mailing list