[Mlir-commits] [mlir] [mlir][Vector] Add constant folding for vector.from_elements operation (PR #145849)
llvmlistbot at llvm.org
Thu Jun 26 00:03:06 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Yang Bai (yangtetris)
<details>
<summary>Changes</summary>
### Summary
This PR extends the folder of **vector.from_elements** so that the op folds to an **arith.constant** when all of its input operands are constants.
### Implementation Details
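**Leverages FoldAdaptor capabilities**: Uses `adaptor.getElements()`, which already holds the constant attribute of each operand (or a null attribute for an operand that is not a constant). The folder therefore does not need to match the defining ops of the operands itself.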
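To make the FoldAdaptor behavior concrete, here is a minimal plain-C++ sketch of the fold's control flow. It assumes 32-bit integer elements and uses `std::optional` as a stand-in for `Attribute`, with an empty optional playing the role of the null attribute reported for a non-constant operand. `MaybeConst` and `foldToConstant` are hypothetical names; this models the logic and is not MLIR code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for one entry of adaptor.getElements(): a value means
// the operand is a known constant; std::nullopt plays the role of a null
// Attribute (operand not known to be constant).
using MaybeConst = std::optional<int32_t>;

// Hypothetical model of foldFromElementsToConstant (not MLIR code). Returns
// the flat, row-major values that would back the dense constant, or
// std::nullopt ("do not fold") if any operand is not a known constant.
std::optional<std::vector<int32_t>>
foldToConstant(const std::vector<MaybeConst> &elements, size_t numElements) {
  // vector.from_elements takes exactly one operand per result element.
  assert(elements.size() == numElements && "operand count must match type");

  // Same check as `llvm::any_of(elements, [](Attribute attr) { return !attr; })`.
  if (std::any_of(elements.begin(), elements.end(),
                  [](const MaybeConst &e) { return !e.has_value(); }))
    return std::nullopt;

  // Every operand is constant: collect the values in operand order.
  std::vector<int32_t> values;
  values.reserve(elements.size());
  for (const MaybeConst &e : elements)
    values.push_back(*e);
  return values;
}
```

For example, `foldToConstant({0, 1, 2, 3}, 4)` yields the values `{0, 1, 2, 3}`, while a single `std::nullopt` operand makes it decline, so this folder leaves the op unchanged.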
### Example Transformation
```
Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
```
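In the example, the four scalar operands fill the `vector<2x2xi32>` result in row-major order, which is why `%c2_i32` and `%c3_i32` land in the second row. The sketch below renders a flat element list for a given shape the way the `dense<...>` literal above is written. `denseLiteral` and `renderNested` are hypothetical helpers, not MLIR APIs; they handle only integers and do not model details of the real printer, such as compact printing of splat values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper (not an MLIR API): recursively renders dimension `dim`
// of a row-major element list; `offset` walks through the flat list.
static std::string renderNested(const std::vector<int32_t> &flat,
                                const std::vector<size_t> &shape, size_t dim,
                                size_t &offset) {
  if (dim == shape.size())
    return std::to_string(flat[offset++]); // Innermost level: one scalar.
  std::string out = "[";
  for (size_t i = 0; i < shape[dim]; ++i) {
    if (i != 0)
      out += ", ";
    out += renderNested(flat, shape, dim + 1, offset);
  }
  return out + "]";
}

// Hypothetical helper: wraps the nested rendering in `dense<...>`.
std::string denseLiteral(const std::vector<int32_t> &flat,
                         const std::vector<size_t> &shape) {
  size_t numElements = 1;
  for (size_t d : shape)
    numElements *= d;
  // One operand per result element, as vector.from_elements requires.
  assert(flat.size() == numElements && "element count must match the shape");
  size_t offset = 0;
  return "dense<" + renderNested(flat, shape, 0, offset) + ">";
}
```

With the operands from the example, `denseLiteral({0, 1, 2, 3}, {2, 2})` returns `dense<[[0, 1], [2, 3]]>`, and the 1-D case from the doc comment, `denseLiteral({1, 2}, {2})`, returns `dense<[1, 2]>`.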
---
Full diff: https://github.com/llvm/llvm-project/pull/145849.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+23-1)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..9afb443cebc13 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2459,8 +2459,30 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
   return {};
 }
 
+/// Fold vector.from_elements to a constant when all operands are constants.
+/// Example:
+/// %c1 = arith.constant 1 : i32
+/// %c2 = arith.constant 2 : i32
+/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
+/// =>
+/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
+///
+static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
+                                               ArrayRef<Attribute> elements) {
+  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
+    return {};
+
+  auto destType = cast<VectorType>(fromElementsOp.getType());
+  return DenseElementsAttr::get(destType, elements);
+}
+
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
-  return foldFromElementsToElements(*this);
+  if (auto res = foldFromElementsToElements(*this))
+    return res;
+  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
+    return res;
+
+  return {};
 }
 
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..d56c64552f9e7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3075,6 +3075,20 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @from_elements_to_constant
+func.func @from_elements_to_constant() -> vector<2x2xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c3_i32 = arith.constant 3 : i32
+  // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
+  %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
+  // CHECK: return %[[RES]]
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_insert_const_regression(
 // CHECK: llvm.mlir.undef
 // CHECK: vector.insert
``````````
</details>
https://github.com/llvm/llvm-project/pull/145849