[Mlir-commits] [mlir] [mlir][Vector] add vector.insert canonicalization pattern for vectors created from ub.poison (PR #142944)

Yang Bai llvmlistbot at llvm.org
Wed Jun 25 02:30:52 PDT 2025


================
@@ -3191,6 +3227,114 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
+/// Pattern to optimize a chain of constant insertions into a poison vector.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Start from a ub.poison operation.
+/// 2. Insert only constant values at static positions.
+/// 3. Completely initialize all elements in the resulting vector.
+/// 4. Have exactly one use for each intermediate insert operation.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single arith.constant operation containing a dense elements attribute.
+///
+/// Example transformation:
+///   %poison = ub.poison : vector<2xi32>
+///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+///   %result = arith.constant dense<[1, 2]> : vector<2xi32>
+/// TODO: Support the case where only some elements of the poison vector are
+/// set. Currently, MLIR doesn't support partial poison vectors.
+class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorType destTy = op.getDestVectorType();
+    if (destTy.isScalable())
+      return failure();
+    // Check if the result is used as the dest operand of another
+    // vector.insert. Only care about the last op in a chain of insertions.
+    for (Operation *user : op.getResult().getUsers())
+      if (auto insertOp = dyn_cast<InsertOp>(user))
+        if (insertOp.getDest() == op.getResult())
+          return failure();
+
+    InsertOp firstInsertOp;
+    InsertOp previousInsertOp = op;
+    SmallVector<InsertOp> chainInsertOps;
+    SmallVector<Attribute> srcAttrs;
+    while (previousInsertOp) {
+      // Dynamic position is not supported.
+      if (previousInsertOp.hasDynamicPosition())
+        return failure();
+
+      // The inserted content must be constant.
+      chainInsertOps.push_back(previousInsertOp);
+      srcAttrs.push_back(Attribute());
+      matchPattern(previousInsertOp.getValueToStore(),
+                   m_Constant(&srcAttrs.back()));
+      if (!srcAttrs.back())
+        return failure();
+
+      // An insertion at poison index makes the entire chain poisoned.
+      if (is_contained(previousInsertOp.getStaticPosition(),
+                       InsertOp::kPoisonIndex))
+        return failure();
+
+      firstInsertOp = previousInsertOp;
+      previousInsertOp = previousInsertOp.getDest().getDefiningOp<InsertOp>();
+
+      // Check that intermediate inserts have only one use to avoid an explosion
+      // of constants.
+      if (previousInsertOp && !previousInsertOp->hasOneUse())
+        return failure();
+    }
+
+    if (!firstInsertOp.getDest().getDefiningOp<ub::PoisonOp>())
+      return failure();
+
+    // Currently, MLIR doesn't support partial poison vectors, so we can only
+    // optimize when the entire vector is completely initialized.
+    int64_t vectorSize = destTy.getNumElements();
+    int64_t initializedCount = 0;
+    SmallVector<bool> initialized(vectorSize, false);
+    SmallVector<Attribute> initValues(vectorSize);
+
+    for (auto [insertOp, srcAttr] : llvm::zip(chainInsertOps, srcAttrs)) {
+      // Calculate the linearized position for inserting elements, as well as
+      // convert the source attribute to the proper type.
+      SmallVector<Attribute> valueToInsert;
+      int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+          destTy, insertOp.getStaticPosition(), srcAttr, valueToInsert);
+      for (auto index :
+           llvm::seq<int64_t>(insertBeginPosition,
+                              insertBeginPosition + valueToInsert.size())) {
+        if (initialized[index])
+          continue;
+
+        initialized[index] = true;
+        ++initializedCount;
+        initValues[index] = valueToInsert[index - insertBeginPosition];
+      }
+      // If all elements in the vector have been initialized, we can stop
+      // processing the remaining insert operations in the chain.
+      if (initializedCount == vectorSize)
+        break;
+    }
+
+    // Some positions are not initialized.
+    if (initializedCount != vectorSize)
+      return failure();
+
+    auto newAttr = DenseElementsAttr::get(destTy, initValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, destTy, newAttr);
----------------
yangtetris wrote:

Great idea! That would indeed remove the constant-only restriction. Let me try implementing it. Thanks!

https://github.com/llvm/llvm-project/pull/142944


More information about the Mlir-commits mailing list