[Mlir-commits] [mlir] [mlir][Vector] add vector.insert canonicalization pattern for vectors created from ub.poison (PR #142944)
Yang Bai
llvmlistbot at llvm.org
Thu Jun 5 03:54:34 PDT 2025
https://github.com/yangtetris created https://github.com/llvm/llvm-project/pull/142944
## Description
This change introduces a new canonicalization pattern for the MLIR Vector dialect that optimizes chains of constant insertions into vectors initialized with `ub.poison`. The optimization identifies when a vector is **completely** initialized through a series of `vector.insert` operations and replaces the entire chain with a single `arith.constant` operation.
Please be aware that the new pattern **doesn't** work for poison vectors where only **some** elements are set, as MLIR doesn't support partial poison vectors for now.
**New Pattern: InsertConstantToPoison**
* Detects chains of vector.insert operations that start from a ub.poison operation.
* Validates that all insertions use constant values at static positions.
* Ensures the entire vector is **completely** initialized.
* Replaces the entire chain with a single arith.constant operation containing a DenseElementsAttr.
**Refactored Helper Function**
* Extracted `calculateInsertPositionAndExtractValues` from `foldDenseElementsAttrDestInsertOp` to avoid code duplication.
## Example
```
// Before:
%poison = ub.poison : vector<2xi64>
%v1 = vector.insert %c10, %poison[0] : i64 into vector<2xi64>
%v2 = vector.insert %c20, %v1[1] : i64 into vector<2xi64>
// After:
%result = arith.constant dense<[10, 20]> : vector<2xi64>
```
It also works for multidimensional vectors.
```
// Before:
%poison = ub.poison : vector<2x3xi64>
%cv0 = arith.constant dense<[1, 2, 3]> : vector<3xi64>
%cv1 = arith.constant dense<[4, 5, 6]> : vector<3xi64>
%v1 = vector.insert %cv0, %poison[0] : vector<3xi64> into vector<2x3xi64>
%v2 = vector.insert %cv1, %v1[1] : vector<3xi64> into vector<2x3xi64>
// After:
%result = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : vector<2x3xi64>
```
>From c511a4ebea4895b707c6a6827ed7f4a54975fa02 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Thu, 5 Jun 2025 03:43:23 -0700
Subject: [PATCH] add vector.insert canonicalization pattern for vectors
created from ub.poison
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 174 +++++++++++++++++----
mlir/test/Dialect/Vector/canonicalize.mlir | 32 ++++
2 files changed, 177 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..253d148072dc0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3149,6 +3149,42 @@ LogicalResult InsertOp::verify() {
return success();
}
+// Appends the element attributes of `srcAttr` to `valueToInsert`, retyping
+// integer attributes to the element type of `destTy`. Returns the linearized
+// position in `destTy` at which `positions` (zero-padded to full rank) begins.
+static int64_t calculateInsertPositionAndExtractValues(
+ VectorType destTy, const ArrayRef<int64_t> &positions, Attribute srcAttr,
+ SmallVector<Attribute> &valueToInsert) {
+ llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+ copy(positions, completePositions.begin());
+ int64_t insertBeginPosition =
+ linearize(completePositions, computeStrides(destTy.getShape()));
+
+ Type destEltType = destTy.getElementType();
+
+ /// Rebuilds an IntegerAttr with the expected type if its own type
+ /// differs; any other attribute is returned unchanged.
+ auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
+ if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
+ if (intAttr.getType() != expectedType)
+ return IntegerAttr::get(expectedType, intAttr.getInt());
+ }
+ return attr;
+ };
+
+ // The `convertIntegerAttr` helper specifically handles the case
+ // for `llvm.mlir.constant` which can hold an attribute with a
+ // different type than the return type.
+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+ for (auto value : denseSource.getValues<Attribute>())
+ valueToInsert.push_back(convertIntegerAttr(value, destEltType));
+ } else {
+ valueToInsert.push_back(convertIntegerAttr(srcAttr, destEltType));
+ }
+
+ return insertBeginPosition;
+}
+
namespace {
// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3191,6 +3227,109 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
}
};
+// Pattern to optimize a chain of constant insertions into a poison vector.
+//
+// This pattern identifies chains of vector.insert operations that:
+// 1. Start from a ub.poison operation.
+// 2. Insert only constant values at static positions.
+// 3. Completely initialize all elements in the resulting vector.
+//
+// When these conditions are met, the entire chain can be replaced with a
+// single arith.constant operation containing a dense elements attribute.
+//
+// Example transformation:
+// %poison = ub.poison : vector<2xi32>
+// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+// ->
+// %result = arith.constant dense<[1, 2]> : vector<2xi32>
+
+// TODO: Support the case where only some elements of the poison vector are set.
+// Currently, MLIR doesn't support partial poison vectors.
+
+class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorType destTy = op.getDestVectorType();
+ if (destTy.isScalable())
+ return failure();
+ // Bail out if the result is the dest operand of another vector.insert: the
+ // pattern only fires on the last op in a chain of insertions.
+ for (Operation *user : op.getResult().getUsers())
+ if (auto insertOp = dyn_cast<InsertOp>(user))
+ if (insertOp.getDest() == op.getResult())
+ return failure();
+
+ InsertOp firstInsertOp;
+ InsertOp previousInsertOp = op;
+ SmallVector<InsertOp> chainInsertOps;
+ SmallVector<Attribute> srcAttrs;
+ while (previousInsertOp) {
+ // Dynamic position is not supported.
+ if (previousInsertOp.hasDynamicPosition())
+ return failure();
+
+ // The inserted content must be constant.
+ chainInsertOps.push_back(previousInsertOp);
+ srcAttrs.push_back(Attribute());
+ matchPattern(previousInsertOp.getValueToStore(),
+ m_Constant(&srcAttrs.back()));
+ if (!srcAttrs.back())
+ return failure();
+
+ // An insertion at poison index makes the entire chain poisoned.
+ if (is_contained(previousInsertOp.getStaticPosition(),
+ InsertOp::kPoisonIndex))
+ return failure();
+
+ firstInsertOp = previousInsertOp;
+ previousInsertOp = previousInsertOp.getDest().getDefiningOp<InsertOp>();
+ }
+
+ if (!firstInsertOp.getDest().getDefiningOp<ub::PoisonOp>())
+ return failure();
+
+ // Need to make sure all elements are initialized.
+ int64_t vectorSize = destTy.getNumElements();
+ int64_t initializedCount = 0;
+ SmallVector<bool> initialized(vectorSize, false);
+ SmallVector<Attribute> initValues(vectorSize);
+
+ for (auto [insertOp, srcAttr] : llvm::zip(chainInsertOps, srcAttrs)) {
+ // chainInsertOps is ordered from the last insert to the first, so a slot
+ // that is already initialized holds the most recent value and is kept.
+ SmallVector<Attribute> valueToInsert;
+ int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+ destTy, insertOp.getStaticPosition(), srcAttr, valueToInsert);
+ for (auto index :
+ llvm::seq<int64_t>(insertBeginPosition,
+ insertBeginPosition + valueToInsert.size())) {
+ if (initialized[index])
+ continue;
+
+ initialized[index] = true;
+ ++initializedCount;
+ initValues[index] = valueToInsert[index - insertBeginPosition];
+ }
+ // If all elements in the vector have been initialized, we can stop
+ // processing the remaining insert operations in the chain.
+ if (initializedCount == vectorSize)
+ break;
+ }
+
+ // Some positions are not initialized.
+ if (initializedCount != vectorSize)
+ return failure();
+
+ auto newAttr = DenseElementsAttr::get(destTy, initValues);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, destTy, newAttr);
+ return success();
+ }
+};
+
} // namespace
static Attribute
@@ -3217,35 +3356,11 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
!insertOp->hasOneUse())
return {};
- // Calculate the linearized position of the continuous chunk of elements to
- // insert.
- llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
- copy(insertOp.getStaticPosition(), completePositions.begin());
- int64_t insertBeginPosition =
- linearize(completePositions, computeStrides(destTy.getShape()));
-
+ // Calculate the linearized position for inserting elements, as well as
+ // convert the source attribute to the proper type.
SmallVector<Attribute> insertedValues;
- Type destEltType = destTy.getElementType();
-
- /// Converts the expected type to an IntegerAttr if there's
- /// a mismatch.
- auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
- if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
- if (intAttr.getType() != expectedType)
- return IntegerAttr::get(expectedType, intAttr.getInt());
- }
- return attr;
- };
-
- // The `convertIntegerAttr` method specifically handles the case
- // for `llvm.mlir.constant` which can hold an attribute with a
- // different type than the return type.
- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
- for (auto value : denseSource.getValues<Attribute>())
- insertedValues.push_back(convertIntegerAttr(value, destEltType));
- } else {
- insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
- }
+ int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+ destTy, insertOp.getStaticPosition(), srcAttr, insertedValues);
auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
copy(insertedValues, allValues.begin() + insertBeginPosition);
@@ -3256,7 +3371,8 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+ InsertConstantToPoison>(context);
}
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..36f3d7196bb93 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2320,6 +2320,38 @@ func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3
// -----
+// CHECK-LABEL: func.func @fully_insert_scalar_constant_to_poison_vector
+// CHECK: %[[VAL0:.+]] = arith.constant dense<[10, 20]> : vector<2xi64>
+// CHECK-NEXT: return %[[VAL0]]
+func.func @fully_insert_scalar_constant_to_poison_vector() -> vector<2xi64> {
+ %poison = ub.poison : vector<2xi64>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %e0 = arith.constant 10 : i64
+ %e1 = arith.constant 20 : i64
+ %v1 = vector.insert %e0, %poison[%c0] : i64 into vector<2xi64>
+ %v2 = vector.insert %e1, %v1[%c1] : i64 into vector<2xi64>
+ return %v2 : vector<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fully_insert_vector_constant_to_poison_vector
+// CHECK: %[[VAL0:.+]] = arith.constant dense<{{\[\[1, 2, 3\], \[4, 5, 6\]\]}}> : vector<2x3xi64>
+// CHECK-NEXT: return %[[VAL0]]
+func.func @fully_insert_vector_constant_to_poison_vector() -> vector<2x3xi64> {
+ %poison = ub.poison : vector<2x3xi64>
+ %cv0 = arith.constant dense<[1, 2, 3]> : vector<3xi64>
+ %cv1 = arith.constant dense<[4, 5, 6]> : vector<3xi64>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %v1 = vector.insert %cv0, %poison[%c0] : vector<3xi64> into vector<2x3xi64>
+ %v2 = vector.insert %cv1, %v1[%c1] : vector<3xi64> into vector<2x3xi64>
+ return %v2 : vector<2x3xi64>
+}
+
+// -----
+
// CHECK-LABEL: func.func @insert_2d_splat_constant
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
More information about the Mlir-commits
mailing list