[Mlir-commits] [mlir] [mlir][Vector] Add constant folding for vector.from_elements operation (PR #145849)

Yang Bai llvmlistbot at llvm.org
Thu Jun 26 03:52:00 PDT 2025


https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/145849

>From 53f8845bf1d12e26da826ae9608b550875eadc98 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 25 Jun 2025 23:52:55 -0700
Subject: [PATCH 1/3] [mlir] fold vector.from_elements to constant when all
 elements are constants

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 24 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..9afb443cebc13 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2459,8 +2459,30 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
   return {};
 }
 
+/// Fold vector.from_elements to a DenseElementsAttr constant when every
+/// operand has been folded to a constant attribute. If any element attribute
+/// is null (i.e. the operand is not a constant), no fold happens. Example:
+///   %c1 = arith.constant 1 : i32
+///   %c2 = arith.constant 2 : i32
+///   %v = vector.from_elements %c1, %c2 : vector<2xi32>
+/// =>
+///   %v = arith.constant dense<[1, 2]> : vector<2xi32>
+static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
+                                               ArrayRef<Attribute> elements) {
+  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
+    return {};
+
+  auto destType = cast<VectorType>(fromElementsOp.getType());
+  return DenseElementsAttr::get(destType, elements);
+}
+
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
-  return foldFromElementsToElements(*this);
+  if (auto res = foldFromElementsToElements(*this))
+    return res;
+  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
+    return res;
+
+  return {};
 }
 
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..d56c64552f9e7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3075,6 +3075,35 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @from_elements_to_constant
+func.func @from_elements_to_constant() -> vector<2x2xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c3_i32 = arith.constant 3 : i32
+  // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
+  %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
+  // CHECK: return %[[RES]]
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// `llvm.mlir.constant` may carry an attribute whose type (here `index`)
+// differs from its result type (`i64`). The fold must convert such attributes
+// to the vector element type instead of crashing in DenseElementsAttr::get.
+// CHECK-LABEL: func @from_elements_to_constant_mismatched_attr_type
+func.func @from_elements_to_constant_mismatched_attr_type() -> vector<2xi64> {
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %c1 = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[RES:.*]] = arith.constant dense<[0, 1]> : vector<2xi64>
+  %res = vector.from_elements %c0, %c1 : vector<2xi64>
+  // CHECK: return %[[RES]]
+  return %res : vector<2xi64>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert

>From 3e64f97033d4aee2133ab81ddd18740da4ed3009 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 26 Jun 2025 00:54:46 -0700
Subject: [PATCH 2/3] remove cast

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9afb443cebc13..5c5db250585cf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2472,7 +2472,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
   if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
     return {};
 
-  auto destType = cast<VectorType>(fromElementsOp.getType());
+  auto destType = fromElementsOp.getDest().getType();
   return DenseElementsAttr::get(destType, elements);
 }
 

>From 6551987d5d59c0d71f822cdec13bf985ebc25268 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 26 Jun 2025 03:51:43 -0700
Subject: [PATCH 3/3] apply convertIntegerAttr to avoid crash caused by
 mismatch between attribute type and value type

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 35 +++++++++++++++---------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5c5db250585cf..f08afec432ca5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -398,6 +398,26 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
   return {};
 }
 
+/// Converts an IntegerAttr to the specified integer or index type if needed.
+/// Constant attributes (e.g., from `llvm.mlir.constant`) may have a type that
+/// differs from the target element type (e.g., `index` vs. `i64`); in that
+/// case the value is sign-extended or truncated to the target bit width.
+/// Returns `attr` unchanged if it is not an IntegerAttr, already has
+/// `expectedType`, or if `expectedType` is not an integer or index type.
+static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
+  auto intAttr = dyn_cast<IntegerAttr>(attr);
+  if (!intAttr || intAttr.getType() == expectedType ||
+      !expectedType.isIntOrIndex())
+    return attr;
+
+  // Operate on the APInt directly: IntegerAttr::getInt() asserts for
+  // signed/unsigned integer types and for values wider than 64 bits.
+  unsigned width = expectedType.isIndex()
+                       ? IndexType::kInternalStorageBitWidth
+                       : expectedType.getIntOrFloatBitWidth();
+  return IntegerAttr::get(expectedType, intAttr.getValue().sextOrTrunc(width));
+}
+
 //===----------------------------------------------------------------------===//
 // CombiningKindAttr
 //===----------------------------------------------------------------------===//
@@ -2472,8 +2485,30 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
   if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
     return {};
 
-  auto destType = fromElementsOp.getDest().getType();
-  return DenseElementsAttr::get(destType, elements);
+  // DenseElementsAttr::get only accepts IntegerAttr/FloatAttr elements; other
+  // constants (e.g. the attribute produced by `ub.poison`) cannot be folded.
+  if (!llvm::all_of(elements, [](Attribute attr) {
+        return isa<IntegerAttr, FloatAttr>(attr);
+      }))
+    return {};
+
+  auto destVecType = fromElementsOp.getDest().getType();
+  auto destEltType = destVecType.getElementType();
+  // Constants from llvm.mlir.constant can have a different type than the return
+  // type. Convert them before creating the dense elements attribute.
+  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
+    return convertIntegerAttr(attr, destEltType);
+  });
+
+  // DenseElementsAttr::get asserts that every element has exactly the
+  // destination element type; bail out on any mismatch that could not be
+  // converted above.
+  if (llvm::any_of(convertedElements, [&](Attribute attr) {
+        return cast<TypedAttr>(attr).getType() != destEltType;
+      }))
+    return {};
+
+  return DenseElementsAttr::get(destVecType, convertedElements);
 }
 
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
@@ -3344,17 +3364,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
 
- /// Converts the expected type to an IntegerAttr if there's
- /// a mismatch.
-  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
-    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
-      if (intAttr.getType() != expectedType)
-        return IntegerAttr::get(expectedType, intAttr.getInt());
-    }
-    return attr;
-  };
-
-  // The `convertIntegerAttr` method specifically handles the case
-  // for `llvm.mlir.constant` which can hold an attribute with a
-  // different type than the return type.
+  // Constants from `llvm.mlir.constant` may carry an attribute whose type
+  // differs from `destEltType`; `convertIntegerAttr` normalizes them.
   if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
     for (auto value : denseSource.getValues<Attribute>())
       insertedValues.push_back(convertIntegerAttr(value, destEltType));



More information about the Mlir-commits mailing list