[Mlir-commits] [mlir] [mlir][Vector] add vector.insert canonicalization pattern for vectors created from ub.poison (PR #142944)

Yang Bai llvmlistbot at llvm.org
Fri Jul 25 08:49:29 PDT 2025


================
@@ -3191,6 +3203,126 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
+/// Pattern to optimize a chain of insertions.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Only insert values at static positions.
+/// 2. Completely initialize all elements in the resulting vector.
+/// 3. All intermediate insert operations have only one use.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single vector.from_elements operation.
+///
+/// Example transformation:
+///   %poison = ub.poison : vector<2xi32>
+///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+///   %result = vector.from_elements %c1, %c2 : vector<2xi32>
+class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorType destTy = op.getDestVectorType();
+    if (destTy.isScalable())
+      return failure();
+    // Check if the result is used as the dest operand of another vector.insert
+    // Only care about the last op in a chain of insertions.
+    for (Operation *user : op.getResult().getUsers())
+      if (auto insertOp = dyn_cast<InsertOp>(user))
+        if (insertOp.getDest() == op.getResult())
+          return failure();
+
+    InsertOp currentOp = op;
+    SmallVector<InsertOp> chainInsertOps;
+    while (currentOp) {
+      // Dynamic position is not supported.
+      if (currentOp.hasDynamicPosition())
+        return failure();
+
+      chainInsertOps.push_back(currentOp);
+      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
+      // Check that intermediate inserts have only one use to avoid an explosion
+      // of vectors.
+      if (currentOp && !currentOp->hasOneUse())
+        return failure();
+    }
+
+    int64_t vectorSize = destTy.getNumElements();
+    int64_t initializedCount = 0;
+    SmallVector<bool> initialized(vectorSize, false);
+    SmallVector<int64_t> pendingInsertPos;
+    SmallVector<int64_t> pendingInsertSize;
+    SmallVector<Value> pendingInsertValues;
+
+    for (auto insertOp : chainInsertOps) {
+      // The insert op folder will fold an insert at poison index into a
+      // ub.poison, which truncates the insert chain's backward traversal.
+      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
+        return failure();
+
+      // Calculate the linearized position for inserting elements.
+      int64_t insertBeginPosition =
+          calculateInsertPosition(destTy, insertOp.getStaticPosition());
+
+      // The valueToStore operand may be a vector or a scalar. Need to handle
+      // both cases.
+      int64_t insertSize = 1;
+      if (auto srcVectorType =
+              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
+        insertSize = srcVectorType.getNumElements();
+
+      assert(insertBeginPosition + insertSize <= vectorSize &&
+             "insert would overflow the vector");
+
+      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
+                                           insertBeginPosition + insertSize)) {
+        if (initialized[index])
+          continue;
+        initialized[index] = true;
+        ++initializedCount;
+      }
+
+      // Defer the creation of ops before we can make sure the pattern can
+      // succeed.
+      pendingInsertPos.push_back(insertBeginPosition);
+      pendingInsertSize.push_back(insertSize);
+      pendingInsertValues.push_back(insertOp.getValueToStore());
+
+      if (initializedCount == vectorSize)
+        break;
+    }
+
+    // Final check: all positions must be initialized
+    if (initializedCount != vectorSize)
+      return failure();
+
+    SmallVector<Value> elements(vectorSize);
+    for (auto [insertBeginPosition, insertSize, valueToStore] :
+         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
+                                 pendingInsertValues))) {
+      if (auto srcVectorType =
+              llvm::dyn_cast<VectorType>(valueToStore.getType())) {
+        SmallVector<int64_t> strides = computeStrides(srcVectorType.getShape());
+        // Get all elements from the vector in row-major order.
+        for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
+          SmallVector<int64_t> position = delinearize(linearIdx, strides);
+          Value extractedElement = rewriter.create<vector::ExtractOp>(
----------------
yangtetris wrote:

I think you mean `vector::ToElementsOp`. That's a good point.

https://github.com/llvm/llvm-project/pull/142944


More information about the Mlir-commits mailing list