[Mlir-commits] [mlir] [mlir][Vector] Add constant folding for vector.from_elements operation (PR #145849)

Yang Bai llvmlistbot at llvm.org
Thu Jun 26 00:02:33 PDT 2025


https://github.com/yangtetris created https://github.com/llvm/llvm-project/pull/145849

### Summary

This PR adds a new folder for **vector.from_elements**: when all input operands are constants, the op is folded into a single dense constant, which shows up as an **arith.constant** after canonicalization.

### Implementation Details

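**Leverages FoldAdaptor capabilities**: Uses adaptor.getElements() to access the **pre-computed** constant attributes of the operands, avoiding redundant pattern matching on them. If any operand is not a constant (its attribute is null), the fold does not apply and the op is left unchanged.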

### Example Transformation
```
Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>

After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
```


>From 53f8845bf1d12e26da826ae9608b550875eadc98 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 25 Jun 2025 23:52:55 -0700
Subject: [PATCH] [mlir] fold vector.from_elements to constant when all
 elements are constants

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 24 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..9afb443cebc13 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2459,8 +2459,30 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
   return {};
 }
 
+/// Fold vector.from_elements to a constant when all operands are constants.
+/// Example:
+///   %c1 = arith.constant 1 : i32
+///   %c2 = arith.constant 2 : i32
+///   %v = vector.from_elements %c1, %c2 : vector<2xi32>
+/// =>
+///   %v = arith.constant dense<[1, 2]> : vector<2xi32>
+///
+static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
+                                               ArrayRef<Attribute> elements) {
+  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
+    return {};
+
+  auto destType = cast<VectorType>(fromElementsOp.getType());
+  return DenseElementsAttr::get(destType, elements);
+}
+
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
-  return foldFromElementsToElements(*this);
+  if (auto res = foldFromElementsToElements(*this))
+    return res;
+  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
+    return res;
+
+  return {};
 }
 
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..d56c64552f9e7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3075,6 +3075,20 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @from_elements_to_constant
+func.func @from_elements_to_constant() -> vector<2x2xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c3_i32 = arith.constant 3 : i32
+  // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
+  %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
+  // CHECK: return %[[RES]]
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert



More information about the Mlir-commits mailing list