[Mlir-commits] [mlir] [mlir][Vector] add vector.insert canonicalization pattern for vectors created from ub.poison (PR #142944)

Yang Bai llvmlistbot at llvm.org
Thu Jun 26 08:34:38 PDT 2025


https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/142944

>From c511a4ebea4895b707c6a6827ed7f4a54975fa02 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 5 Jun 2025 03:43:23 -0700
Subject: [PATCH 1/3] add vector.insert canonicalization pattern for vectors
 created from ub.poison

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 174 +++++++++++++++++----
 mlir/test/Dialect/Vector/canonicalize.mlir |  32 ++++
 2 files changed, 177 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..253d148072dc0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3149,6 +3149,42 @@ LogicalResult InsertOp::verify() {
   return success();
 }
 
+// Calculate the linearized position for inserting elements and extract values
+// from the source attribute. Returns the starting position in the destination
+// vector where elements should be inserted.
+static int64_t calculateInsertPositionAndExtractValues(
+    VectorType destTy, const ArrayRef<int64_t> &positions, Attribute srcAttr,
+    SmallVector<Attribute> &valueToInsert) {
+  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+  copy(positions, completePositions.begin());
+  int64_t insertBeginPosition =
+      linearize(completePositions, computeStrides(destTy.getShape()));
+
+  Type destEltType = destTy.getElementType();
+
+  /// Converts the expected type to an IntegerAttr if there's
+  /// a mismatch.
+  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
+    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
+      if (intAttr.getType() != expectedType)
+        return IntegerAttr::get(expectedType, intAttr.getInt());
+    }
+    return attr;
+  };
+
+  // The `convertIntegerAttr` method specifically handles the case
+  // for `llvm.mlir.constant` which can hold an attribute with a
+  // different type than the return type.
+  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+    for (auto value : denseSource.getValues<Attribute>())
+      valueToInsert.push_back(convertIntegerAttr(value, destEltType));
+  } else {
+    valueToInsert.push_back(convertIntegerAttr(srcAttr, destEltType));
+  }
+
+  return insertBeginPosition;
+}
+
 namespace {
 
 // If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3191,6 +3227,109 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
+// Pattern to optimize a chain of constant insertions into a poison vector.
+//
+// This pattern identifies chains of vector.insert operations that:
+// 1. Start from an ub.poison operation.
+// 2. Insert only constant values at static positions.
+// 3. Completely initialize all elements in the resulting vector.
+//
+// When these conditions are met, the entire chain can be replaced with a
+// single arith.constant operation containing a dense elements attribute.
+//
+// Example transformation:
+//   %poison = ub.poison : vector<2xi32>
+//   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+//   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+// ->
+//   %result = arith.constant dense<[1, 2]> : vector<2xi32>
+
+// TODO: Support the case where only some elements of the poison vector are set.
+//       Currently, MLIR doesn't support partial poison vectors.
+
+class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorType destTy = op.getDestVectorType();
+    if (destTy.isScalable())
+      return failure();
+    // Check if the result is used as the dest operand of another vector.insert
+    // Only care about the last op in a chain of insertions.
+    for (Operation *user : op.getResult().getUsers())
+      if (auto insertOp = dyn_cast<InsertOp>(user))
+        if (insertOp.getDest() == op.getResult())
+          return failure();
+
+    InsertOp firstInsertOp;
+    InsertOp previousInsertOp = op;
+    SmallVector<InsertOp> chainInsertOps;
+    SmallVector<Attribute> srcAttrs;
+    while (previousInsertOp) {
+      // Dynamic position is not supported.
+      if (previousInsertOp.hasDynamicPosition())
+        return failure();
+
+      // The inserted content must be constant.
+      chainInsertOps.push_back(previousInsertOp);
+      srcAttrs.push_back(Attribute());
+      matchPattern(previousInsertOp.getValueToStore(),
+                   m_Constant(&srcAttrs.back()));
+      if (!srcAttrs.back())
+        return failure();
+
+      // An insertion at poison index makes the entire chain poisoned.
+      if (is_contained(previousInsertOp.getStaticPosition(),
+                       InsertOp::kPoisonIndex))
+        return failure();
+
+      firstInsertOp = previousInsertOp;
+      previousInsertOp = previousInsertOp.getDest().getDefiningOp<InsertOp>();
+    }
+
+    if (!firstInsertOp.getDest().getDefiningOp<ub::PoisonOp>())
+      return failure();
+
+    // Need to make sure all elements are initialized.
+    int64_t vectorSize = destTy.getNumElements();
+    int64_t initializedCount = 0;
+    SmallVector<bool> initialized(vectorSize, false);
+    SmallVector<Attribute> initValues(vectorSize);
+
+    for (auto [insertOp, srcAttr] : llvm::zip(chainInsertOps, srcAttrs)) {
+      // Calculate the linearized position for inserting elements, as well as
+      // convert the source attribute to the proper type.
+      SmallVector<Attribute> valueToInsert;
+      int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+          destTy, insertOp.getStaticPosition(), srcAttr, valueToInsert);
+      for (auto index :
+           llvm::seq<int64_t>(insertBeginPosition,
+                              insertBeginPosition + valueToInsert.size())) {
+        if (initialized[index])
+          continue;
+
+        initialized[index] = true;
+        ++initializedCount;
+        initValues[index] = valueToInsert[index - insertBeginPosition];
+      }
+      // If all elements in the vector have been initialized, we can stop
+      // processing the remaining insert operations in the chain.
+      if (initializedCount == vectorSize)
+        break;
+    }
+
+    // some positions are not initialized.
+    if (initializedCount != vectorSize)
+      return failure();
+
+    auto newAttr = DenseElementsAttr::get(destTy, initValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, destTy, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 static Attribute
@@ -3217,35 +3356,11 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
       !insertOp->hasOneUse())
     return {};
 
-  // Calculate the linearized position of the continuous chunk of elements to
-  // insert.
-  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-  copy(insertOp.getStaticPosition(), completePositions.begin());
-  int64_t insertBeginPosition =
-      linearize(completePositions, computeStrides(destTy.getShape()));
-
+  // Calculate the linearized position for inserting elements, as well as
+  // convert the source attribute to the proper type.
   SmallVector<Attribute> insertedValues;
-  Type destEltType = destTy.getElementType();
-
-  /// Converts the expected type to an IntegerAttr if there's
-  /// a mismatch.
-  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
-    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
-      if (intAttr.getType() != expectedType)
-        return IntegerAttr::get(expectedType, intAttr.getInt());
-    }
-    return attr;
-  };
-
-  // The `convertIntegerAttr` method specifically handles the case
-  // for `llvm.mlir.constant` which can hold an attribute with a
-  // different type than the return type.
-  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
-    for (auto value : denseSource.getValues<Attribute>())
-      insertedValues.push_back(convertIntegerAttr(value, destEltType));
-  } else {
-    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
-  }
+  int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+      destTy, insertOp.getStaticPosition(), srcAttr, insertedValues);
 
   auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
   copy(insertedValues, allValues.begin() + insertBeginPosition);
@@ -3256,7 +3371,8 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+              InsertConstantToPoison>(context);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..36f3d7196bb93 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2320,6 +2320,38 @@ func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3
 
 // -----
 
+// CHECK-LABEL: func.func @fully_insert_scalar_constant_to_poison_vector
+//       CHECK: %[[VAL0:.+]] = arith.constant dense<[10, 20]> : vector<2xi64>
+//  CHECK-NEXT: return %[[VAL0]]
+func.func @fully_insert_scalar_constant_to_poison_vector() -> vector<2xi64> {
+  %poison = ub.poison : vector<2xi64>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %e0 = arith.constant 10 : i64
+  %e1 = arith.constant 20 : i64
+  %v1 = vector.insert %e0, %poison[%c0] : i64 into vector<2xi64>
+  %v2 = vector.insert %e1, %v1[%c1] : i64 into vector<2xi64>
+  return %v2 : vector<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fully_insert_vector_constant_to_poison_vector
+//       CHECK: %[[VAL0:.+]] = arith.constant dense<{{\[\[1, 2, 3\], \[4, 5, 6\]\]}}> : vector<2x3xi64>
+//  CHECK-NEXT: return %[[VAL0]]
+func.func @fully_insert_vector_constant_to_poison_vector() -> vector<2x3xi64> {
+  %poison = ub.poison : vector<2x3xi64>
+  %cv0 = arith.constant dense<[1, 2, 3]> : vector<3xi64>
+  %cv1 = arith.constant dense<[4, 5, 6]> : vector<3xi64>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %v1 = vector.insert %cv0, %poison[%c0] : vector<3xi64> into vector<2x3xi64>
+  %v2 = vector.insert %cv1, %v1[%c1] : vector<3xi64> into vector<2x3xi64>
+  return %v2 : vector<2x3xi64>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @insert_2d_splat_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>

>From bba3d6c97c5d23862a75ea60d40cfba467c5932e Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 25 Jun 2025 02:08:19 -0700
Subject: [PATCH 2/3] refine comments & add hasOneUse check

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 53 +++++++++++++-----------
 1 file changed, 29 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 253d148072dc0..e744b877f64bf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3153,8 +3153,8 @@ LogicalResult InsertOp::verify() {
 // from the source attribute. Returns the starting position in the destination
 // vector where elements should be inserted.
 static int64_t calculateInsertPositionAndExtractValues(
-    VectorType destTy, const ArrayRef<int64_t> &positions, Attribute srcAttr,
-    SmallVector<Attribute> &valueToInsert) {
+    VectorType destTy, ArrayRef<int64_t> positions, Attribute srcAttr,
+    SmallVectorImpl<Attribute> &valueToInsert) {
   llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
   copy(positions, completePositions.begin());
   int64_t insertBeginPosition =
@@ -3227,26 +3227,25 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
-// Pattern to optimize a chain of constant insertions into a poison vector.
-//
-// This pattern identifies chains of vector.insert operations that:
-// 1. Start from an ub.poison operation.
-// 2. Insert only constant values at static positions.
-// 3. Completely initialize all elements in the resulting vector.
-//
-// When these conditions are met, the entire chain can be replaced with a
-// single arith.constant operation containing a dense elements attribute.
-//
-// Example transformation:
-//   %poison = ub.poison : vector<2xi32>
-//   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
-//   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
-// ->
-//   %result = arith.constant dense<[1, 2]> : vector<2xi32>
-
-// TODO: Support the case where only some elements of the poison vector are set.
-//       Currently, MLIR doesn't support partial poison vectors.
-
+/// Pattern to optimize a chain of constant insertions into a poison vector.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Start from an ub.poison operation.
+/// 2. Insert only constant values at static positions.
+/// 3. Completely initialize all elements in the resulting vector.
+/// 4. All intermediate insert operations have only one use.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single arith.constant operation containing a dense elements attribute.
+///
+/// Example transformation:
+///   %poison = ub.poison : vector<2xi32>
+///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+///   %result = arith.constant dense<[1, 2]> : vector<2xi32>
+/// TODO: Support the case where only some elements of the poison vector are
+/// set. Currently, MLIR doesn't support partial poison vectors.
 class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -3287,12 +3286,18 @@ class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
 
       firstInsertOp = previousInsertOp;
       previousInsertOp = previousInsertOp.getDest().getDefiningOp<InsertOp>();
+
+      // Check that intermediate inserts have only one use to avoid an explosion
+      // of constants.
+      if (previousInsertOp && !previousInsertOp->hasOneUse())
+        return failure();
     }
 
     if (!firstInsertOp.getDest().getDefiningOp<ub::PoisonOp>())
       return failure();
 
-    // Need to make sure all elements are initialized.
+    // Currently, MLIR doesn't support partial poison vectors, so we can only
+    // optimize when the entire vector is completely initialized.
     int64_t vectorSize = destTy.getNumElements();
     int64_t initializedCount = 0;
     SmallVector<bool> initialized(vectorSize, false);
@@ -3320,7 +3325,7 @@ class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
         break;
     }
 
-    // some positions are not initialized.
+    // Some positions are not initialized.
     if (initializedCount != vectorSize)
       return failure();
 

>From 9701de857422398c531290262299212988b2166a Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 26 Jun 2025 08:10:39 -0700
Subject: [PATCH 3/3] use from_elements to replace arith.constant

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 145 +++++++++++++----------
 1 file changed, 84 insertions(+), 61 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e744b877f64bf..cc0812f925036 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3149,39 +3149,16 @@ LogicalResult InsertOp::verify() {
   return success();
 }
 
-// Calculate the linearized position for inserting elements and extract values
-// from the source attribute. Returns the starting position in the destination
-// vector where elements should be inserted.
-static int64_t calculateInsertPositionAndExtractValues(
-    VectorType destTy, ArrayRef<int64_t> positions, Attribute srcAttr,
-    SmallVectorImpl<Attribute> &valueToInsert) {
+// Calculate the linearized position of the continuous chunk of elements to
+// insert, given the destination vector type and the (possibly partial) static
+// positions to insert at. Missing trailing positions are treated as 0.
+static int64_t calculateInsertPosition(VectorType destTy,
+                                       ArrayRef<int64_t> positions) {
   llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
   copy(positions, completePositions.begin());
   int64_t insertBeginPosition =
       linearize(completePositions, computeStrides(destTy.getShape()));
 
-  Type destEltType = destTy.getElementType();
-
-  /// Converts the expected type to an IntegerAttr if there's
-  /// a mismatch.
-  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
-    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
-      if (intAttr.getType() != expectedType)
-        return IntegerAttr::get(expectedType, intAttr.getInt());
-    }
-    return attr;
-  };
-
-  // The `convertIntegerAttr` method specifically handles the case
-  // for `llvm.mlir.constant` which can hold an attribute with a
-  // different type than the return type.
-  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
-    for (auto value : denseSource.getValues<Attribute>())
-      valueToInsert.push_back(convertIntegerAttr(value, destEltType));
-  } else {
-    valueToInsert.push_back(convertIntegerAttr(srcAttr, destEltType));
-  }
-
   return insertBeginPosition;
 }
 
@@ -3227,23 +3204,23 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
-/// Pattern to optimize a chain of constant insertions into a poison vector.
+/// Pattern to rewrite a chain of insertions into a poison vector.
 ///
 /// This pattern identifies chains of vector.insert operations that:
 /// 1. Start from an ub.poison operation.
-/// 2. Insert only constant values at static positions.
+/// 2. Only insert values at static positions.
 /// 3. Completely initialize all elements in the resulting vector.
 /// 4. All intermediate insert operations have only one use.
 ///
 /// When these conditions are met, the entire chain can be replaced with a
-/// single arith.constant operation containing a dense elements attribute.
+/// single vector.from_elements operation.
 ///
 /// Example transformation:
 ///   %poison = ub.poison : vector<2xi32>
 ///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
 ///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
 /// ->
-///   %result = arith.constant dense<[1, 2]> : vector<2xi32>
+///   %result = vector.from_elements %c1, %c2 : vector<2xi32>
 /// TODO: Support the case where only some elements of the poison vector are
 /// set. Currently, MLIR doesn't support partial poison vectors.
 class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
@@ -3265,30 +3242,19 @@ class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
     InsertOp firstInsertOp;
     InsertOp previousInsertOp = op;
     SmallVector<InsertOp> chainInsertOps;
-    SmallVector<Attribute> srcAttrs;
     while (previousInsertOp) {
       // Dynamic position is not supported.
       if (previousInsertOp.hasDynamicPosition())
         return failure();
 
-      // The inserted content must be constant.
+      // Collect the chain, from the last insert back to the first.
       chainInsertOps.push_back(previousInsertOp);
-      srcAttrs.push_back(Attribute());
-      matchPattern(previousInsertOp.getValueToStore(),
-                   m_Constant(&srcAttrs.back()));
-      if (!srcAttrs.back())
-        return failure();
-
-      // An insertion at poison index makes the entire chain poisoned.
-      if (is_contained(previousInsertOp.getStaticPosition(),
-                       InsertOp::kPoisonIndex))
-        return failure();
 
       firstInsertOp = previousInsertOp;
       previousInsertOp = previousInsertOp.getDest().getDefiningOp<InsertOp>();
 
       // Check that intermediate inserts have only one use to avoid an explosion
-      // of constants.
+      // of vectors.
       if (previousInsertOp && !previousInsertOp->hasOneUse())
         return failure();
     }
@@ -3301,23 +3267,42 @@ class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
     int64_t vectorSize = destTy.getNumElements();
     int64_t initializedCount = 0;
     SmallVector<bool> initialized(vectorSize, false);
-    SmallVector<Attribute> initValues(vectorSize);
-
-    for (auto [insertOp, srcAttr] : llvm::zip(chainInsertOps, srcAttrs)) {
-      // Calculate the linearized position for inserting elements, as well as
-      // convert the source attribute to the proper type.
-      SmallVector<Attribute> valueToInsert;
-      int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
-          destTy, insertOp.getStaticPosition(), srcAttr, valueToInsert);
+    // For every element of the result, the insert op that provides it and the
+    // element's linearized offset within that op's valueToStore.
+    SmallVector<InsertOp> elementSourceOps(vectorSize);
+    SmallVector<int64_t> elementSourceOffsets(vectorSize, 0);
+
+    // First pass: work out where every element comes from without creating
+    // any IR, since a pattern must not modify the IR if it returns failure.
+    for (InsertOp insertOp : chainInsertOps) {
+      // The insert op folder will fold an insert at poison index into a
+      // ub.poison, which truncates the insert chain's backward traversal.
+      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
+        return failure();
+
+      // Calculate the linearized position for inserting elements.
+      int64_t insertBeginPosition =
+          calculateInsertPosition(destTy, insertOp.getStaticPosition());
+
+      // The valueToStore operand may be a vector or a scalar. Need to handle
+      // both cases.
+      int64_t elementsToInsertSize = 1;
+      if (auto srcVectorType =
+              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
+        elementsToInsertSize = srcVectorType.getNumElements();
+
       for (auto index :
            llvm::seq<int64_t>(insertBeginPosition,
-                              insertBeginPosition + valueToInsert.size())) {
+                              insertBeginPosition + elementsToInsertSize)) {
+        // Inserts are visited from last to first, so an element that is
+        // already initialized has been overwritten by a later insert.
         if (initialized[index])
           continue;
 
         initialized[index] = true;
         ++initializedCount;
-        initValues[index] = valueToInsert[index - insertBeginPosition];
+        elementSourceOps[index] = insertOp;
+        elementSourceOffsets[index] = index - insertBeginPosition;
       }
       // If all elements in the vector have been initialized, we can stop
       // processing the remaining insert operations in the chain.
@@ -3329,8 +3314,26 @@ class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
     if (initializedCount != vectorSize)
       return failure();
 
-    auto newAttr = DenseElementsAttr::get(destTy, initValues);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, destTy, newAttr);
+    // Second pass: the rewrite can no longer fail, so materialize the
+    // elements. Scalars are used as-is; only the vector elements that survive
+    // in the result are extracted, in row-major order.
+    SmallVector<Value> elements;
+    elements.reserve(vectorSize);
+    for (auto [insertOp, offset] :
+         llvm::zip(elementSourceOps, elementSourceOffsets)) {
+      Value valueToStore = insertOp.getValueToStore();
+      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
+      if (!srcVectorType) {
+        elements.push_back(valueToStore);
+        continue;
+      }
+      SmallVector<int64_t> position =
+          delinearize(offset, computeStrides(srcVectorType.getShape()));
+      elements.push_back(rewriter.create<vector::ExtractOp>(
+          insertOp.getLoc(), valueToStore, position));
+    }
+
+    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
     return success();
   }
 };
@@ -3361,11 +3364,31 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
       !insertOp->hasOneUse())
     return {};
 
-  // Calculate the linearized position for inserting elements, as well as
-  // convert the source attribute to the proper type.
+  // Calculate the linearized position for inserting elements.
+  int64_t insertBeginPosition =
+      calculateInsertPosition(destTy, insertOp.getStaticPosition());
   SmallVector<Attribute> insertedValues;
-  int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
-      destTy, insertOp.getStaticPosition(), srcAttr, insertedValues);
+  Type destEltType = destTy.getElementType();
+
+  /// Converts the expected type to an IntegerAttr if there's
+  /// a mismatch.
+  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
+    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
+      if (intAttr.getType() != expectedType)
+        return IntegerAttr::get(expectedType, intAttr.getInt());
+    }
+    return attr;
+  };
+
+  // The `convertIntegerAttr` method specifically handles the case
+  // for `llvm.mlir.constant` which can hold an attribute with a
+  // different type than the return type.
+  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+    for (auto value : denseSource.getValues<Attribute>())
+      insertedValues.push_back(convertIntegerAttr(value, destEltType));
+  } else {
+    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
+  }
 
   auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
   copy(insertedValues, allValues.begin() + insertBeginPosition);



More information about the Mlir-commits mailing list