[Mlir-commits] [mlir] [mlir][Vector] Add simple folders for `vector.from_elements`/`vector.to_elements` (PR #144444)

Diego Caballero llvmlistbot at llvm.org
Wed Jun 18 13:53:11 PDT 2025


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/144444

>From 4f5ccf9560855438a68341b461917834ab60ecc0 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 11 Jun 2025 19:42:47 +0000
Subject: [PATCH 1/2] [mlir][Vector] Add simple folders for
 `vector.from_elements`/`vector.to_elements`

This PR adds simple folders that remove no-op round trips between
`vector.from_elements` and `vector.to_elements`.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  2 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 90 +++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 52 +++++++++++
 3 files changed, 144 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 125cd4645ccc2..7c44cfbde0367 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -836,6 +836,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
   let arguments = (ins AnyVectorOfAnyRank:$source);
   let results = (outs Variadic<AnyType>:$elements);
   let assemblyFormat = "$source attr-dict `:` type($source)";
+  let hasFolder = 1;
 }
 
 def Vector_FromElementsOp : Vector_Op<"from_elements", [
@@ -873,6 +874,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
   let arguments = (ins Variadic<AnyType>:$elements);
   let results = (outs AnyFixedVectorOfAnyRank:$dest);
   let assemblyFormat = "$elements attr-dict `:` type($dest)";
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e576eeac23656..7482b6a22c400 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2373,10 +2373,100 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+//===----------------------------------------------------------------------===//
+// ToElementsOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `operands` is non-empty and all of them are defined by
+/// `defOp`. An empty range yields false.
+static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
+  if (operands.empty())
+    return false;
+
+  return llvm::all_of(operands, [&](Value operand) {
+    Operation *currentDef = operand.getDefiningOp();
+    return currentDef == defOp;
+  });
+}
+
+/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
+/// (%e0, %e1, ...). Fails if the source is not a vector.from_elements.
+/// For example:
+///   %0 = vector.from_elements %a, %b, %c : vector<3xf32>
+///   %1:3 = vector.to_elements %0 : vector<3xf32>
+///   user_op %1#0, %1#1, %1#2
+///
+/// becomes:
+///
+///   user_op %a, %b, %c
+///
+static LogicalResult
+foldToElementsFromElements(ToElementsOp toElementsOp,
+                           SmallVectorImpl<OpFoldResult> &results) {
+  auto fromElementsOp = toElementsOp.getSource().getDefiningOp<FromElementsOp>();
+  if (!fromElementsOp)
+    return failure();
+
+  results.append(fromElementsOp.getElements().begin(),
+                 fromElementsOp.getElements().end());
+  return success();
+}
+
+LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
+                                 SmallVectorImpl<OpFoldResult> &results) {
+  if (succeeded(foldToElementsFromElements(*this, results)))
+    return success();
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//
 
+/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
+/// Requires all operands to be results of a single vector.to_elements op.
+///
+/// Case #1: operands are its results in order and the vector types match:
+///   %0:3 = vector.to_elements %a : vector<3xf32>
+///   %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
+///   user_op %1
+///
+/// becomes:
+///
+///   user_op %a
+///
+static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
+  auto fromElemsOperands = fromElementsOp.getElements();
+
+  if (fromElemsOperands.empty())
+    return {};
+
+  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
+  if (!toElementsOp)
+    return {};
+
+  if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
+    return {};
+
+  // Case #1: Input and output vectors are the same. Forward the input vector.
+  Value toElementsInput = toElementsOp.getSource();
+  if (fromElementsOp.getType() == toElementsInput.getType() &&
+      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
+    return toElementsInput;
+  }
+
+  // TODO: Support cases with different input and output shapes and different
+  // number of elements.
+
+  return {};
+}
+
+OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
+  if (auto result = foldFromElementsToElements(*this))
+    return result;
+  return {};
+}
+
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
 /// same SSA value. E.g.:
 ///
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6691cb52acdc0..65b73375831da 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3023,6 +3023,58 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
 
 // -----
 
+// CHECK-LABEL: func @to_elements_from_elements_no_op(
+// CHECK-SAME:     %[[A:.*]]: f32, %[[B:.*]]: f32
+func.func @to_elements_from_elements_no_op(%a: f32, %b: f32) -> (f32, f32) {
+  // CHECK-NOT: vector.from_elements
+  // CHECK-NOT: vector.to_elements
+  %0 = vector.from_elements %b, %a : vector<2xf32>
+  %1:2 = vector.to_elements %0 : vector<2xf32>
+  // CHECK: return %[[B]], %[[A]]
+  return %1#0, %1#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_elements_no_op(
+// CHECK-SAME:     %[[A:.*]]: vector<4x2xf32>
+func.func @from_elements_to_elements_no_op(%a: vector<4x2xf32>) -> vector<4x2xf32> {
+  // CHECK-NOT: vector.from_elements
+  // CHECK-NOT: vector.to_elements
+  %0:8 = vector.to_elements %a : vector<4x2xf32>
+  %1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<4x2xf32>
+  // CHECK: return %[[A]]
+  return %1 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_elements_dup_elems(
+// CHECK-SAME:     %[[A:.*]]: vector<4xf32>
+func.func @from_elements_to_elements_dup_elems(%a: vector<4xf32>) -> vector<4x2xf32> {
+  // CHECK: %[[TO_EL:.*]]:4 = vector.to_elements %[[A]]
+  // CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#0, %[[TO_EL]]#1, %[[TO_EL]]#2
+  %0:4 = vector.to_elements %a : vector<4xf32> // 4 elements, each used twice
+  %1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#0, %0#1, %0#2, %0#3 : vector<4x2xf32>
+  // CHECK: return %[[FROM_EL]]
+  return %1 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_elements_shuffle(
+// CHECK-SAME:     %[[A:.*]]: vector<4x2xf32>
+func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2xf32> {
+  // CHECK: %[[TO_EL:.*]]:8 = vector.to_elements %[[A]]
+  // CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#7, %[[TO_EL]]#0, %[[TO_EL]]#6
+  %0:8 = vector.to_elements %a : vector<4x2xf32>
+  %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<4x2xf32>
+  // CHECK: return %[[FROM_EL]]
+  return %1 : vector<4x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert

>From 7fc28b25b3a73b331004f55a451bfbf7d0775965 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 18 Jun 2025 20:45:54 +0000
Subject: [PATCH 2/2] Address review feedback

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7482b6a22c400..fab49e27562ac 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2407,16 +2407,13 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
   if (!fromElementsOp)
     return failure();
 
-  results.append(fromElementsOp.getElements().begin(),
-                 fromElementsOp.getElements().end());
+  llvm::append_range(results, fromElementsOp.getElements());
   return success();
 }
 
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
-  if (succeeded(foldToElementsFromElements(*this, results)))
-    return success();
-  return failure();
+  return foldToElementsFromElements(*this, results);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2462,9 +2459,7 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
 }
 
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
-  if (auto result = foldFromElementsToElements(*this))
-    return result;
-  return {};
+  return foldFromElementsToElements(*this);
 }
 
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the



More information about the Mlir-commits mailing list