[Mlir-commits] [mlir] [mlir][Vector] add vector.insert canonicalization pattern for vectors created from ub.poison (PR #142944)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Aug 1 13:15:51 PDT 2025
================
@@ -3250,6 +3262,128 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
return success();
}
};
+
+/// Pattern to optimize a chain of insertions.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Only insert values at static positions.
+/// 2. Completely initialize all elements in the resulting vector.
+/// 3. All intermediate insert operations have only one use.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single vector.from_elements operation.
+///
+/// Example transformation:
+/// %poison = ub.poison : vector<2xi32>
+/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
+class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorType destTy = op.getDestVectorType();
+ if (destTy.isScalable())
+ return failure();
+ // Check if the result is used as the dest operand of another vector.insert
+ // Only care about the last op in a chain of insertions.
+ for (Operation *user : op.getResult().getUsers())
+ if (auto insertOp = dyn_cast<InsertOp>(user))
+ if (insertOp.getDest() == op.getResult())
+ return failure();
+
+ InsertOp currentOp = op;
+ SmallVector<InsertOp> chainInsertOps;
+ while (currentOp) {
+ // Dynamic position is not supported.
+ if (currentOp.hasDynamicPosition())
+ return failure();
+
+ chainInsertOps.push_back(currentOp);
+ currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
+ // Check that intermediate inserts have only one use to avoid an explosion
+ // of vectors.
+ if (currentOp && !currentOp->hasOneUse())
+ return failure();
+ }
+
+ int64_t vectorSize = destTy.getNumElements();
+ int64_t initializedCount = 0;
+ SmallVector<bool> initialized(vectorSize, false);
+ SmallVector<int64_t> pendingInsertPos;
+ SmallVector<int64_t> pendingInsertSize;
+ SmallVector<Value> pendingInsertValues;
+
+ for (auto insertOp : chainInsertOps) {
+ // The insert op folder will fold an insert at poison index into a
+ // ub.poison, which truncates the insert chain's backward traversal.
+ if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
+ return failure();
+
+ // Calculate the linearized position for inserting elements.
+ int64_t insertBeginPosition =
+ calculateInsertPosition(destTy, insertOp.getStaticPosition());
+
+ // The valueToStore operand may be a vector or a scalar. Need to handle
+ // both cases.
+ int64_t insertSize = 1;
+ if (auto srcVectorType =
+ llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
+ insertSize = srcVectorType.getNumElements();
+
+ assert(insertBeginPosition + insertSize <= vectorSize &&
+ "insert would overflow the vector");
+
+ for (auto index : llvm::seq<int64_t>(insertBeginPosition,
+ insertBeginPosition + insertSize)) {
+ if (initialized[index])
+ continue;
+ initialized[index] = true;
+ ++initializedCount;
+ }
+
+ // Defer the creation of ops before we can make sure the pattern can
+ // succeed.
+ pendingInsertPos.push_back(insertBeginPosition);
+ pendingInsertSize.push_back(insertSize);
+ pendingInsertValues.push_back(insertOp.getValueToStore());
+
+ if (initializedCount == vectorSize)
+ break;
+ }
+
+ // Final check: all positions must be initialized
+ if (initializedCount != vectorSize)
+ return failure();
+
+ SmallVector<Value> elements(vectorSize);
+ for (auto [insertBeginPosition, insertSize, valueToStore] :
+ llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
+ pendingInsertValues))) {
+ if (auto srcVectorType =
+ llvm::dyn_cast<VectorType>(valueToStore.getType())) {
+ SmallVector<Type> elementToInsertTypes(insertSize,
+ srcVectorType.getElementType());
+ auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
+ op.getLoc(), elementToInsertTypes, valueToStore);
+ // Get all elements from the vector in row-major order.
+ for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
+ elements[insertBeginPosition + linearIdx] =
+ elementsToInsert.getResult(linearIdx);
+ }
+ } else {
+ elements[insertBeginPosition] = valueToStore;
+ }
----------------
banach-space wrote:
```suggestion
auto srcVectorType =
llvm::dyn_cast<VectorType>(valueToStore.getType());
if (!srcVectorType) {
elements[insertBeginPosition] = valueToStore;
continue;
}
SmallVector<Type> elementToInsertTypes(insertSize,
srcVectorType.getElementType());
auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
op.getLoc(), elementToInsertTypes, valueToStore);
// Get all elements from the vector in row-major order.
for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
elements[insertBeginPosition + linearIdx] =
elementsToInsert.getResult(linearIdx);
}
```
Prefer early exits: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code
https://github.com/llvm/llvm-project/pull/142944
More information about the Mlir-commits
mailing list