[Mlir-commits] [mlir] [mlir][Vector] Add constant folding for vector.from_elements operation (PR #145849)
Yang Bai
llvmlistbot at llvm.org
Thu Jun 26 00:02:33 PDT 2025
https://github.com/yangtetris created https://github.com/llvm/llvm-project/pull/145849
### Summary
This PR extends the fold hook of **vector.from_elements** so that the op folds to a constant, materialized as **arith.constant**, when all of its input operands are constants.
### Implementation Details
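**Leverages FoldAdaptor capabilities**: Uses adaptor.getElements(), which already holds the constant attribute of each operand (or a null attribute when an operand is not a constant), so the fold does not have to pattern-match the defining ops of the operands itself.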
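The "all operands are constants" check can be modeled outside MLIR. The sketch below is a hypothetical, standalone C++ model, not MLIR code: `std::optional<int32_t>` stands in for the possibly-null `Attribute` the fold adaptor reports for each operand, and `foldToConstant` is an illustrative name. Instead of building a `DenseElementsAttr`, it returns the flat list of values the dense constant would hold, or `std::nullopt` ("no fold") as soon as one operand is not a constant.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for mlir::Attribute: std::nullopt plays the role of the null
// attribute that a FoldAdaptor reports for a non-constant operand.
using MaybeConst = std::optional<int32_t>;

// Hypothetical model of the constant fold: if any element is not a known
// constant, give up (std::nullopt means "no fold"); otherwise return the
// flat, row-major element list that the dense constant would hold.
std::optional<std::vector<int32_t>>
foldToConstant(const std::vector<MaybeConst> &elements) {
  std::vector<int32_t> values;
  values.reserve(elements.size());
  for (const MaybeConst &e : elements) {
    if (!e)
      return std::nullopt; // Mirrors the any_of null-attribute check.
    values.push_back(*e);
  }
  return values;
}
```

Because a single non-constant operand aborts the fold, a partially constant from_elements is left unchanged by this fold.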
### Example Transformation
```
Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
```
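In the 2x2 example, the four operands fill the result in row-major order: operand `r * 2 + c` becomes element `(r, c)`. The standalone C++ sketch below makes that layout explicit by printing a flat element list in the same nested `dense<...>` form. `denseLiteral2D` is a hypothetical helper that only handles 2-D i32 shapes, and it is not MLIR's attribute printer (for instance, MLIR prints a splat constant in a shortened form such as `dense<1>`).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Illustrative only: print a flat, row-major list of i32 values the way the
// dense literal of a 2-D vector constant is written, e.g.
// dense<[[0, 1], [2, 3]]>. `rows` and `cols` stand in for the shape of
// vector<rows x cols x i32>.
std::string denseLiteral2D(const std::vector<int32_t> &flat, size_t rows,
                           size_t cols) {
  assert(flat.size() == rows * cols && "element count must match the shape");
  std::string out = "dense<[";
  for (size_t r = 0; r < rows; ++r) {
    if (r != 0)
      out += ", ";
    out += "[";
    for (size_t c = 0; c < cols; ++c) {
      if (c != 0)
        out += ", ";
      // Row-major: flat position r * cols + c lands at (r, c).
      out += std::to_string(flat[r * cols + c]);
    }
    out += "]";
  }
  out += "]>";
  return out;
}
```

For the operands `0, 1, 2, 3` and shape 2x2 this yields `dense<[[0, 1], [2, 3]]>`, matching the folded constant shown above.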
From 53f8845bf1d12e26da826ae9608b550875eadc98 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 25 Jun 2025 23:52:55 -0700
Subject: [PATCH] [mlir] fold vector.from_elements to constant when all
elements are constants
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++++++
2 files changed, 37 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..9afb443cebc13 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2459,8 +2459,30 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
   return {};
 }
 
+/// Fold vector.from_elements to a constant when all operands are constants.
+/// Example:
+/// %c1 = arith.constant 1 : i32
+/// %c2 = arith.constant 2 : i32
+/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
+/// =>
+/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
+///
+static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
+                                               ArrayRef<Attribute> elements) {
+  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
+    return {};
+
+  auto destType = cast<VectorType>(fromElementsOp.getType());
+  return DenseElementsAttr::get(destType, elements);
+}
+
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
-  return foldFromElementsToElements(*this);
+  if (auto res = foldFromElementsToElements(*this))
+    return res;
+  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
+    return res;
+
+  return {};
 }
 
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..d56c64552f9e7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3075,6 +3075,20 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @from_elements_to_constant
+func.func @from_elements_to_constant() -> vector<2x2xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c3_i32 = arith.constant 3 : i32
+  // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
+  %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
+  // CHECK: return %[[RES]]
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_insert_const_regression(
 // CHECK: llvm.mlir.undef
 // CHECK: vector.insert