[Mlir-commits] [mlir] fc9feee - [mlir][LLVM] Speed up `extractvalue(insertvalue)` canonicalization (#176478)

llvmlistbot at llvm.org
Mon Jan 19 08:51:58 PST 2026


Author: neildhar
Date: 2026-01-19T17:51:54+01:00
New Revision: fc9feee1efa870cbcf17cfee9bc15e57ecbccfd7

URL: https://github.com/llvm/llvm-project/commit/fc9feee1efa870cbcf17cfee9bc15e57ecbccfd7
DIFF: https://github.com/llvm/llvm-project/commit/fc9feee1efa870cbcf17cfee9bc15e57ecbccfd7.diff

LOG: [mlir][LLVM] Speed up `extractvalue(insertvalue)` canonicalization (#176478)

The current `ExtractValueOp::fold` implementation traverses the entire
chain of `InsertValueOp`s leading up to it. This can be extremely slow
when a large number of `ExtractValueOp`s use values from the same chain.

This PR improves this significantly in cases where a large number of the
`ExtractValueOp`s are actually reading from the same `InsertValueOp`.
That is, for patterns like:

```
%i0 = llvm.insertvalue %v0, %undef[0]
%i1 = llvm.insertvalue %v1, %i0[1]
...
%i999 = llvm.insertvalue %v999, %i998[999]
%e0 = llvm.extractvalue %i999[0]
%e1 = llvm.extractvalue %i999[1]
...
%e999 = llvm.extractvalue %i999[999]
```

In such cases, the resolution can be performed much faster using a
canonicalisation pattern on `InsertValueOp` that is applied to `%i999`,
because we can collect all of the `ExtractValueOp`s that use it, and
then do a single traversal of the chain to resolve them.
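
The single-traversal idea can be sketched outside of MLIR. The following is
a minimal toy model, not the code in this commit: `ToyInsert` and
`resolveExtracts` are hypothetical names, the chain is a plain vector, and
`-1` stands for the value at the top of the chain.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Toy stand-in for an llvm.insertvalue op (hypothetical, not an MLIR type).
// `position` is the top-level member it writes; `container` is the index of
// the previous insert in the chain, or -1 for the value at the top.
struct ToyInsert {
  std::int64_t position;
  int container;
};

// For each extract of `chain[start]` (given by the member position it reads),
// return the index of the closest insert that writes that member, or -1 if
// no insert in the chain does. All extracts are resolved in a single walk.
std::vector<int>
resolveExtracts(const std::vector<ToyInsert> &chain, int start,
                const std::vector<std::int64_t> &extractPositions) {
  assert(start >= 0 && start < static_cast<int>(chain.size()));
  std::vector<int> source(extractPositions.size(), -1);

  // Group the pending extracts by the member they read.
  std::unordered_map<std::int64_t, std::vector<std::size_t>> pending;
  for (std::size_t i = 0; i < extractPositions.size(); ++i) {
    if (extractPositions[i] == chain[start].position)
      source[i] = start; // Already reads from the right insert.
    else
      pending[extractPositions[i]].push_back(i);
  }

  // Walk up the chain once; stop as soon as nothing is left to resolve.
  int cur = chain[start].container;
  while (cur != -1 && !pending.empty()) {
    auto it = pending.find(chain[cur].position);
    if (it != pending.end()) {
      for (std::size_t i : it->second)
        source[i] = cur;
      pending.erase(it);
    }
    cur = chain[cur].container;
  }
  return source;
}
```

For the 1000-element chain above, calling `resolveExtracts` on the last
insert with positions 0 through 999 maps each extract `p` to insert `p`
after one pass over the chain.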

With this change, most of the resolution happens as part of the
`InsertValueOp` canonicalisation step, and there is much less work to do
when `ExtractValueOp::fold` is run.

Note that for now, this leaves `ExtractValueOp::fold` as-is, so the
order in which patterns are applied affects whether we see the speedup.
The speedup requires patterns to be applied in top-down order, which is
the default for the canonicaliser pass. I am separately working on
simplifying `ExtractValueOp::fold` to do less traversal, but that
requires some care to ensure existing cases are not pessimised.
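
To get a rough feel for why the order matters, here is a simplified cost
model. It is an assumption for illustration, not a measurement: it only
counts how many insertions are visited for a chain of `n` insertions with
one extraction per member, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Visits made when each extract walks the chain on its own: the extract of
// member p starts at insert n-1 and stops at insert p, i.e. n - p visits.
std::uint64_t perExtractVisits(std::uint64_t n) {
  std::uint64_t visits = 0;
  for (std::uint64_t p = 0; p < n; ++p)
    visits += n - p;
  // Sanity check against the closed form n(n+1)/2.
  assert(visits == n * (n + 1) / 2);
  return visits;
}

// Visits made by one shared walk that starts at the last insert's container
// and ends at the first insert: n - 1 visits.
std::uint64_t batchedVisits(std::uint64_t n) { return n == 0 ? 0 : n - 1; }
```

Under this model, `n = 1000` gives 500500 visits when every extraction
walks the chain separately and 999 visits for the shared walk.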

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 971710fa3ee13..6789ca22c3d5f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op<
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   string llvmInstName = "InsertValue";
   string llvmBuilder = [{

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..6281f0d6e0b09 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2027,6 +2027,104 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
 // InsertValueOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// Update any ExtractValueOps using a given InsertValueOp to instead read from
+/// the closest InsertValueOp in the chain leading up to the current op that
+/// writes to the same member. This traversal could be done entirely in
+/// ExtractValueOp::fold, but doing it here significantly speeds things up
+/// because we can handle several ExtractValueOps with a single traversal.
+/// For instance, in this example:
+///   %i0 = llvm.insertvalue %v0, %undef[0]
+///   %i1 = llvm.insertvalue %v1, %0[1]
+///   ...
+///   %i999 = llvm.insertvalue %v999, %998[999]
+///   %e0 = llvm.extractvalue %i999[0]
+///   %e1 = llvm.extractvalue %i999[1]
+///   ...
+///   %e999 = llvm.extractvalue %i999[999]
+/// Individually running the folder on each extractvalue would require
+/// traversing the insertvalue chain 1000 times, but running this pattern on the
+/// InsertValueOp would allow us to achieve the same result with a single
+/// traversal. The resulting IR after this pattern will then be:
+///   %i0 = llvm.insertvalue %v0, %undef[0]
+///   %i1 = llvm.insertvalue %v1, %0[1]
+///   ...
+///   %i999 = llvm.insertvalue %v999, %998[999]
+///   %e0 = llvm.extractvalue %i0[0]
+///   %e1 = llvm.extractvalue %i1[1]
+///   ...
+///   %e999 = llvm.extractvalue %i999[999]
+struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertValueOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    // Map each position in the top-level struct to the ExtractOps that read
+    // from it. For the example in the doc-comment above this map will be empty
+    // when we visit ops %i0 - %i998. For %i999, it will contain:
+    //   0 -> { %e0 }, 1 -> { %e1 }, ... 999-> { %e999 }
+    DenseMap<int64_t, SmallVector<ExtractValueOp, 4>> posToExtractOps;
+    auto insertBaseIdx = insertOp.getPosition()[0];
+    for (auto &use : insertOp->getUses()) {
+      if (auto extractOp = dyn_cast<ExtractValueOp>(use.getOwner())) {
+        auto baseIdx = extractOp.getPosition()[0];
+        // We can skip reads of the member that insertOp writes to since they
+        // will not be updated.
+        if (baseIdx == insertBaseIdx)
+          continue;
+        posToExtractOps[baseIdx].push_back(extractOp);
+      }
+    }
+    // Walk up the chain of insertions and try to resolve the remaining
+    // extractions that access the same member.
+    Value nextContainer = insertOp.getContainer();
+    while (!posToExtractOps.empty()) {
+      auto curInsert =
+          dyn_cast_or_null<InsertValueOp>(nextContainer.getDefiningOp());
+      if (!curInsert)
+        break;
+      nextContainer = curInsert.getContainer();
+
+      // Check if any extractions read the member written by this insertion.
+      auto curInsertBaseIdx = curInsert.getPosition()[0];
+      auto it = posToExtractOps.find(curInsertBaseIdx);
+      if (it == posToExtractOps.end())
+        continue;
+
+      // Update the ExtractOps to read from the current insertion.
+      for (auto &extractOp : it->second) {
+        rewriter.modifyOpInPlace(extractOp, [&] {
+          extractOp.getContainerMutable().assign(curInsert);
+        });
+      }
+      // The entry should never be empty if it exists, so if we are at this
+      // point, set changed to true.
+      assert(!it->second.empty());
+      changed |= true;
+      posToExtractOps.erase(it);
+    }
+    // There was no insertion along the chain that wrote the member accessed by
+    // these extracts. So we can update them to use the top of the chain.
+    for (auto &[baseIdx, extracts] : posToExtractOps) {
+      for (auto &extractOp : extracts) {
+        rewriter.modifyOpInPlace(extractOp, [&] {
+          extractOp.getContainerMutable().assign(nextContainer);
+        });
+      }
+      assert(!extracts.empty() && "Empty list in map");
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+} // namespace
+
+void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<ResolveExtractValueSource>(context);
+}
+
 /// Infer the value type from the container type and position.
 static ParseResult
 parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,


        

