[Mlir-commits] [mlir] [mlir] Speed up resolving ExtractValueOp source (PR #176478)
llvmlistbot at llvm.org
Fri Jan 16 13:25:12 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
Author: None (neildhar)
<details>
<summary>Changes</summary>
The current `ExtractValueOp::fold` implementation traverses the entire chain of `InsertValueOp`s leading up to it. Because every `ExtractValueOp` repeats that traversal, folding can be extremely slow when a large number of `ExtractValueOp`s use values from the same chain.
This PR significantly speeds up the case where many of the `ExtractValueOp`s read from the same `InsertValueOp`. That is, for patterns like:
```
%i0 = llvm.insertvalue %v0, %undef[0]
%i1 = llvm.insertvalue %v1, %i0[1]
...
%i999 = llvm.insertvalue %v999, %i998[999]
%e0 = llvm.extractvalue %i999[0]
%e1 = llvm.extractvalue %i999[1]
...
%e999 = llvm.extractvalue %i999[999]
```
In such cases, the resolution can be performed much faster using a canonicalisation pattern on `InsertValueOp` that is applied to `%i999`, because we can collect all of the `ExtractValueOp`s that use it, and then do a single traversal of the chain to resolve them.
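The single traversal can be sketched outside of MLIR. The following is a minimal model under stated assumptions, not the code in this PR: the chain is reduced to a vector of top-level indices, and `resolveSources` is a hypothetical helper name. It mirrors the shape of the pattern: group the extracts by the member they read, walk the chain once, and send anything left over to the value the chain starts from.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the IR: insertPos[k] is the top-level index
// written by the k-th insertion in the chain (k = 0 is the oldest).
// All extracts are assumed to read from the last insertion. The result gives,
// for each extract, the index of the insertion it should read from, or -1 for
// the value the chain starts from.
std::vector<int> resolveSources(const std::vector<std::int64_t> &insertPos,
                                const std::vector<std::int64_t> &extractPos) {
  assert(!insertPos.empty() && "need at least one insertion");
  const int last = static_cast<int>(insertPos.size()) - 1;
  std::vector<int> source(extractPos.size(), last);

  // Group the extracts by the member they read. Reads of the member written
  // by the last insertion already have the right source and are skipped.
  std::unordered_map<std::int64_t, std::vector<std::size_t>> pending;
  for (std::size_t e = 0; e < extractPos.size(); ++e)
    if (extractPos[e] != insertPos[last])
      pending[extractPos[e]].push_back(e);

  // One walk up the chain resolves every group whose member gets written.
  for (int cur = last - 1; cur >= 0 && !pending.empty(); --cur) {
    auto it = pending.find(insertPos[cur]);
    if (it == pending.end())
      continue;
    for (std::size_t e : it->second)
      source[e] = cur;
    pending.erase(it);
  }

  // No insertion along the chain wrote these members, so the extracts read
  // from the value the chain starts from.
  for (const auto &entry : pending)
    for (std::size_t e : entry.second)
      source[e] = -1;
  return source;
}
```

For a chain that writes members 0, 1, 2, 1 and extracts that read members 0, 1, 2, 5, the sources come out as 0, 3, 2 and -1: the read of member 1 stays on the last insertion, and member 5 is never written.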
With this change, most of the resolution happens as part of the `InsertValueOp` canonicalisation step, and there is much less work to do when `ExtractValueOp::fold` is run.
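To get a rough sense of the difference, the sketch below counts chain hops in a simplified model of the example above: `n` insertions writing members 0 to n-1, and `n` extracts that all read from the last insertion. It assumes each independent walk stops at the insertion that wrote the requested member. The function names are hypothetical, and the numbers are hop counts in this model, not measured timings.

```cpp
#include <cstdint>

// Insertions visited above the last one when each extract walks on its own:
// extract k starts at insertion n-1 and stops at insertion k, so it visits
// n-1-k insertions beyond its starting point.
std::uint64_t perExtractHops(std::uint64_t n) {
  std::uint64_t total = 0;
  for (std::uint64_t k = 0; k < n; ++k)
    total += n - 1 - k;
  return total;
}

// Insertions visited above the last one by a single shared walk that goes
// from insertion n-2 down to insertion 0.
std::uint64_t sharedWalkHops(std::uint64_t n) { return n == 0 ? 0 : n - 1; }
```

With `n = 1000`, the independent walks add up to 499500 hops, while the shared walk takes 999.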
Note that for now, this leaves the implementation of `ExtractValueOp::fold` as-is, so the order in which patterns are applied affects whether we see the speedup. The speedup requires patterns to be applied in top-down order, which is the default for the canonicaliser pass. I am separately working on simplifying `ExtractValueOp::fold` to do less traversal, but that requires some care to ensure existing cases are not pessimised.
---
Full diff: https://github.com/llvm/llvm-project/pull/176478.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+1)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+75)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 971710fa3ee13..6789ca22c3d5f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op<
}];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
string llvmInstName = "InsertValue";
string llvmBuilder = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..78c5989e2beb0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2027,6 +2027,81 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
// InsertValueOp
//===----------------------------------------------------------------------===//
+namespace {
+/// Update any ExtractValueOps using a given InsertValueOp to instead read from
+/// the closest InsertValueOp in the chain leading up to the current op that
+/// writes to the same member. This traversal could be done entirely in
+/// ExtractValueOp::fold, but doing it here significantly speeds things up
+/// because we can handle several ExtractValueOps with a single traversal.
+struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertValueOp insertOp,
+ PatternRewriter &rewriter) const override {
+ bool changed = false;
+ // Map each position in the top-level struct to the ExtractOps that read
+ // from it.
+ DenseMap<int64_t, SmallVector<ExtractValueOp, 4>> posToExtractOps;
+ auto insertBaseIdx = insertOp.getPosition()[0];
+ for (auto &use : insertOp->getUses()) {
+ if (auto extractOp = dyn_cast<ExtractValueOp>(use.getOwner())) {
+ auto baseIdx = extractOp.getPosition()[0];
+ // We can skip reads of the member that insertOp writes to since they
+ // will not be updated.
+ if (baseIdx == insertBaseIdx)
+ continue;
+ posToExtractOps[baseIdx].push_back(extractOp);
+ }
+ }
+ // Walk up the chain of insertions and try to resolve the remaining
+ // extractions that access the same member.
+ Value nextContainer = insertOp.getContainer();
+ while (!posToExtractOps.empty()) {
+ auto curInsert =
+ dyn_cast_or_null<InsertValueOp>(nextContainer.getDefiningOp());
+ if (!curInsert)
+ break;
+ nextContainer = curInsert.getContainer();
+
+ // Check if any extractions read the member written by this insertion.
+ auto curInsertBaseIdx = curInsert.getPosition()[0];
+ auto it = posToExtractOps.find(curInsertBaseIdx);
+ if (it == posToExtractOps.end())
+ continue;
+
+ // Update the ExtractOps to read from the current insertion.
+ for (auto &extractOp : it->second) {
+ rewriter.modifyOpInPlace(extractOp, [&] {
+ extractOp.getContainerMutable().assign(curInsert);
+ });
+ }
+ // The entry should never be empty if it exists, so if we are at this
+ // point, set changed to true.
+ assert(!it->second.empty());
+ changed |= true;
+ posToExtractOps.erase(it);
+ }
+ // There was no insertion along the chain that wrote the member accessed by
+ // these extracts. So we can update them to use the top of the chain.
+ for (auto &[baseIdx, extracts] : posToExtractOps) {
+ for (auto &extractOp : extracts) {
+ rewriter.modifyOpInPlace(extractOp, [&] {
+ extractOp.getContainerMutable().assign(nextContainer);
+ });
+ }
+ assert(!extracts.empty() && "Empty list in map");
+ changed = true;
+ }
+ return success(changed);
+ }
+};
+} // namespace
+
+void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<ResolveExtractValueSource>(context);
+}
+
/// Infer the value type from the container type and position.
static ParseResult
parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
``````````
</details>
https://github.com/llvm/llvm-project/pull/176478