[Mlir-commits] [mlir] [mlir][Vector] Add constant folding for vector.from_elements operation (PR #145849)
Yang Bai
llvmlistbot at llvm.org
Thu Jun 26 03:52:00 PDT 2025
https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/145849
>From 53f8845bf1d12e26da826ae9608b550875eadc98 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 25 Jun 2025 23:52:55 -0700
Subject: [PATCH 1/3] [mlir] fold vector.from_elements to constant when all
elements are constants
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++++++
2 files changed, 37 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..9afb443cebc13 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2459,8 +2459,30 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
return {};
}
+/// Fold vector.from_elements to a constant when all operands are constants.
+/// Example:
+/// %c1 = arith.constant 1 : i32
+/// %c2 = arith.constant 2 : i32
+/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
+/// =>
+/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
+///
+static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
+ ArrayRef<Attribute> elements) {
+ if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
+ return {};
+
+ auto destType = cast<VectorType>(fromElementsOp.getType());
+ return DenseElementsAttr::get(destType, elements);
+}
+
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
- return foldFromElementsToElements(*this);
+ if (auto res = foldFromElementsToElements(*this))
+ return res;
+ if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
+ return res;
+
+ return {};
}
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..d56c64552f9e7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3075,6 +3075,20 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
// -----
+// CHECK-LABEL: func @from_elements_to_constant
+func.func @from_elements_to_constant() -> vector<2x2xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %c2_i32 = arith.constant 2 : i32
+ %c3_i32 = arith.constant 3 : i32
+ // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
+ %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
+ // CHECK: return %[[RES]]
+ return %res : vector<2x2xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert
>From 3e64f97033d4aee2133ab81ddd18740da4ed3009 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 26 Jun 2025 00:54:46 -0700
Subject: [PATCH 2/3] remove cast
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9afb443cebc13..5c5db250585cf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2472,7 +2472,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
return {};
- auto destType = cast<VectorType>(fromElementsOp.getType());
+ auto destType = fromElementsOp.getDest().getType();
return DenseElementsAttr::get(destType, elements);
}
>From 6551987d5d59c0d71f822cdec13bf985ebc25268 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 26 Jun 2025 03:51:43 -0700
Subject: [PATCH 3/3] apply convertIntegerAttr to avoid crash caused by
mismatch between attribute type and value type
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 35 +++++++++++++++---------
1 file changed, 22 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5c5db250585cf..f08afec432ca5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -398,6 +398,19 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
return {};
}
+/// Converts an IntegerAttr to have the specified type if needed.
+/// This handles cases where constant attributes (e.g., from
+/// `llvm.mlir.constant`) have a different type than the target element type. If
+/// the input attribute is not an IntegerAttr or already has the correct type,
+/// returns it unchanged.
+static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
+ if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
+ if (intAttr.getType() != expectedType)
+ return IntegerAttr::get(expectedType, intAttr.getInt());
+ }
+ return attr;
+}
+
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//
@@ -2472,8 +2485,15 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
return {};
- auto destType = fromElementsOp.getDest().getType();
- return DenseElementsAttr::get(destType, elements);
+ auto destVecType = fromElementsOp.getDest().getType();
+ auto destEltType = destVecType.getElementType();
+ // Constants from llvm.mlir.constant can have a different type than the return
+ // type. Convert them before creating the dense elements attribute.
+ auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
+ return convertIntegerAttr(attr, destEltType);
+ });
+
+ return DenseElementsAttr::get(destVecType, convertedElements);
}
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
@@ -3344,17 +3364,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
/// Converts an IntegerAttr to the expected type if there's
/// a mismatch.
- auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
- if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
- if (intAttr.getType() != expectedType)
- return IntegerAttr::get(expectedType, intAttr.getInt());
- }
- return attr;
- };
-
- // The `convertIntegerAttr` method specifically handles the case
- // for `llvm.mlir.constant` which can hold an attribute with a
- // different type than the return type.
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
for (auto value : denseSource.getValues<Attribute>())
insertedValues.push_back(convertIntegerAttr(value, destEltType));
More information about the Mlir-commits mailing list