[Mlir-commits] [mlir] [mlir] Speed up resolving ExtractValueOp source (PR #176478)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 16 13:25:12 PST 2026


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: None (neildhar)

<details>
<summary>Changes</summary>

The current `ExtractValueOp::fold` implementation traverses the entire chain of `InsertValueOp`s leading up to it. Each `ExtractValueOp` repeats this traversal independently, so it can be extremely slow when a huge number of `ExtractValueOp`s use values from the same chain.

This PR improves this significantly in cases where a large number of the `ExtractValueOp`s are actually reading from the same `InsertValueOp`. That is, for patterns like:

```
%i0 = llvm.insertvalue %v0, %undef[0]
%i1 = llvm.insertvalue %v1, %i0[1]
...
%i999 = llvm.insertvalue %v999, %i998[999]
%e0 = llvm.extractvalue %i999[0]
%e1 = llvm.extractvalue %i999[1]
...
%e999 = llvm.extractvalue %i999[999]
```

In such cases, the resolution can be performed much faster using a canonicalisation pattern on `InsertValueOp` that is applied to `%i999`, because we can collect all of the `ExtractValueOp`s that use it, and then do a single traversal of the chain to resolve them.

With this change, most of the resolution happens as part of the `InsertValueOp` canonicalisation step, and there is much less work to do when `ExtractValueOp::fold` is run. 

Note that, for now, this leaves `ExtractValueOp::fold` as-is, so whether we see the speedup depends on the order in which patterns are applied. The speedup requires patterns to be applied in top-down order, so that the `InsertValueOp` pattern redirects the `ExtractValueOp`s before their folds run. Top-down is the default for the canonicaliser pass. I am separately working on simplifying `ExtractValueOp::fold` to do less traversal, but that requires some care to ensure existing cases are not pessimised.



---
Full diff: https://github.com/llvm/llvm-project/pull/176478.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+75) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 971710fa3ee13..6789ca22c3d5f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op<
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   string llvmInstName = "InsertValue";
   string llvmBuilder = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..78c5989e2beb0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2027,6 +2027,81 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
 // InsertValueOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// Update any ExtractValueOps using a given InsertValueOp to instead read from
+/// the closest InsertValueOp in the chain leading up to the current op that
+/// writes to the same member. This traversal could be done entirely in
+/// ExtractValueOp::fold, but doing it here significantly speeds things up
+/// because we can handle several ExtractValueOps with a single traversal.
+struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertValueOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    // Map each position in the top-level struct to the ExtractOps that read
+    // from it.
+    DenseMap<int64_t, SmallVector<ExtractValueOp, 4>> posToExtractOps;
+    auto insertBaseIdx = insertOp.getPosition()[0];
+    for (auto &use : insertOp->getUses()) {
+      if (auto extractOp = dyn_cast<ExtractValueOp>(use.getOwner())) {
+        auto baseIdx = extractOp.getPosition()[0];
+        // We can skip reads of the member that insertOp writes to since they
+        // will not be updated.
+        if (baseIdx == insertBaseIdx)
+          continue;
+        posToExtractOps[baseIdx].push_back(extractOp);
+      }
+    }
+    // Walk up the chain of insertions and try to resolve the remaining
+    // extractions that access the same member.
+    Value nextContainer = insertOp.getContainer();
+    while (!posToExtractOps.empty()) {
+      auto curInsert =
+          dyn_cast_or_null<InsertValueOp>(nextContainer.getDefiningOp());
+      if (!curInsert)
+        break;
+      nextContainer = curInsert.getContainer();
+
+      // Check if any extractions read the member written by this insertion.
+      auto curInsertBaseIdx = curInsert.getPosition()[0];
+      auto it = posToExtractOps.find(curInsertBaseIdx);
+      if (it == posToExtractOps.end())
+        continue;
+
+      // Update the ExtractOps to read from the current insertion.
+      for (auto &extractOp : it->second) {
+        rewriter.modifyOpInPlace(extractOp, [&] {
+          extractOp.getContainerMutable().assign(curInsert);
+        });
+      }
+      // The entry should never be empty if it exists, so if we are at this
+      // point, set changed to true.
+      assert(!it->second.empty());
+      changed |= true;
+      posToExtractOps.erase(it);
+    }
+    // There was no insertion along the chain that wrote the member accessed by
+    // these extracts. So we can update them to use the top of the chain.
+    for (auto &[baseIdx, extracts] : posToExtractOps) {
+      for (auto &extractOp : extracts) {
+        rewriter.modifyOpInPlace(extractOp, [&] {
+          extractOp.getContainerMutable().assign(nextContainer);
+        });
+      }
+      assert(!extracts.empty() && "Empty list in map");
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+} // namespace
+
+void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<ResolveExtractValueSource>(context);
+}
+
 /// Infer the value type from the container type and position.
 static ParseResult
 parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,

``````````

</details>


https://github.com/llvm/llvm-project/pull/176478

