[Mlir-commits] [mlir] [mlir][Vector] add vector.insert canonicalization pattern for vectors created from ub.poison (PR #142944)
Diego Caballero
llvmlistbot at llvm.org
Tue Jun 24 11:59:31 PDT 2025
================
@@ -3191,6 +3227,109 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
}
};
+// Pattern to optimize a chain of constant insertions into a poison vector.
+//
+// This pattern identifies chains of vector.insert operations that:
+// 1. Start from an ub.poison operation.
+// 2. Insert only constant values at static positions.
+// 3. Completely initialize all elements in the resulting vector.
+//
+// When these conditions are met, the entire chain can be replaced with a
+// single arith.constant operation containing a dense elements attribute.
+//
+// Example transformation:
+// %poison = ub.poison : vector<2xi32>
+// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+// ->
+// %result = arith.constant dense<[1, 2]> : vector<2xi32>
+
+// TODO: Support the case where only some elements of the poison vector are set.
+// Currently, MLIR doesn't support partial poison vectors.
+
+class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorType destTy = op.getDestVectorType();
+ if (destTy.isScalable())
+ return failure();
+ // Check if the result is used as the dest operand of another vector.insert
+ // Only care about the last op in a chain of insertions.
+ for (Operation *user : op.getResult().getUsers())
+ if (auto insertOp = dyn_cast<InsertOp>(user))
+ if (insertOp.getDest() == op.getResult())
+ return failure();
+
+ InsertOp firstInsertOp;
+ InsertOp previousInsertOp = op;
+ SmallVector<InsertOp> chainInsertOps;
+ SmallVector<Attribute> srcAttrs;
+ while (previousInsertOp) {
+ // Dynamic position is not supported.
+ if (previousInsertOp.hasDynamicPosition())
+ return failure();
+
+ // The inserted content must be constant.
+ chainInsertOps.push_back(previousInsertOp);
+ srcAttrs.push_back(Attribute());
+ matchPattern(previousInsertOp.getValueToStore(),
+ m_Constant(&srcAttrs.back()));
+ if (!srcAttrs.back())
+ return failure();
+
+ // An insertion at poison index makes the entire chain poisoned.
+ if (is_contained(previousInsertOp.getStaticPosition(),
+ InsertOp::kPoisonIndex))
+ return failure();
+
+ firstInsertOp = previousInsertOp;
+ previousInsertOp = previousInsertOp.getDest().getDefiningOp<InsertOp>();
+ }
+
+ if (!firstInsertOp.getDest().getDefiningOp<ub::PoisonOp>())
+ return failure();
+
+ // Need to make sure all elements are initialized.
+ int64_t vectorSize = destTy.getNumElements();
+ int64_t initializedCount = 0;
+ SmallVector<bool> initialized(vectorSize, false);
+ SmallVector<Attribute> initValues(vectorSize);
+
+ for (auto [insertOp, srcAttr] : llvm::zip(chainInsertOps, srcAttrs)) {
+ // Calculate the linearized position for inserting elements, as well as
+ // convert the source attribute to the proper type.
+ SmallVector<Attribute> valueToInsert;
+ int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+ destTy, insertOp.getStaticPosition(), srcAttr, valueToInsert);
+ for (auto index :
+ llvm::seq<int64_t>(insertBeginPosition,
+ insertBeginPosition + valueToInsert.size())) {
+ if (initialized[index])
+ continue;
+
+ initialized[index] = true;
+ ++initializedCount;
+ initValues[index] = valueToInsert[index - insertBeginPosition];
+ }
+ // If all elements in the vector have been initialized, we can stop
+ // processing the remaining insert operations in the chain.
+ if (initializedCount == vectorSize)
+ break;
+ }
+
+ // some positions are not initialized.
----------------
dcaballe wrote:
nit: `some` -> `Some`
https://github.com/llvm/llvm-project/pull/142944
More information about the Mlir-commits
mailing list