[Mlir-commits] [mlir] [mlir] Speed up resolving ExtractValueOp source (PR #176478)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 19 08:22:16 PST 2026


https://github.com/neildhar updated https://github.com/llvm/llvm-project/pull/176478

>From 7594019392a7418de79c67a538edcde165b3d528 Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Fri, 16 Jan 2026 12:46:05 -0800
Subject: [PATCH] [mlir] Speed up resolving ExtractValueOp source

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td |  1 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 98 +++++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 971710fa3ee13..6789ca22c3d5f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -940,6 +940,7 @@ def LLVM_InsertValueOp : LLVM_Op<
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   string llvmInstName = "InsertValue";
   string llvmBuilder = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..6281f0d6e0b09 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2027,6 +2027,104 @@ void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
 // InsertValueOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// Redirect ExtractValueOps reading from a given InsertValueOp to the closest
+/// InsertValueOp up the container chain that writes the same top-level member
+/// (or to the chain's root container if no such insertion exists). This could
+/// be done entirely in ExtractValueOp::fold, but doing it here is much faster
+/// because several ExtractValueOps are handled with a single chain traversal.
+/// For instance, in this example:
+///   %i0 = llvm.insertvalue %v0, %undef[0]
+///   %i1 = llvm.insertvalue %v1, %i0[1]
+///   ...
+///   %i999 = llvm.insertvalue %v999, %i998[999]
+///   %e0 = llvm.extractvalue %i999[0]
+///   %e1 = llvm.extractvalue %i999[1]
+///   ...
+///   %e999 = llvm.extractvalue %i999[999]
+/// Individually running the folder on each extractvalue would require
+/// traversing the insertvalue chain 1000 times, but running this pattern on the
+/// InsertValueOp would allow us to achieve the same result with a single
+/// traversal. The resulting IR after this pattern will then be:
+///   %i0 = llvm.insertvalue %v0, %undef[0]
+///   %i1 = llvm.insertvalue %v1, %i0[1]
+///   ...
+///   %i999 = llvm.insertvalue %v999, %i998[999]
+///   %e0 = llvm.extractvalue %i0[0]
+///   %e1 = llvm.extractvalue %i1[1]
+///   ...
+///   %e999 = llvm.extractvalue %i999[999]
+struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertValueOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    // Map each top-level struct position to the ExtractValueOps reading it.
+    // For the doc-comment example above, this map is empty when visiting
+    // %i0 - %i998. For %i999, it will contain:
+    //   0 -> { %e0 }, 1 -> { %e1 }, ... 999 -> { %e999 }
+    DenseMap<int64_t, SmallVector<ExtractValueOp, 4>> posToExtractOps;
+    int64_t insertBaseIdx = insertOp.getPosition()[0];
+    for (Operation *user : insertOp->getUsers()) {
+      if (auto extractOp = dyn_cast<ExtractValueOp>(user)) {
+        int64_t baseIdx = extractOp.getPosition()[0];
+        // Reads of the member that insertOp itself writes already have the
+        // closest possible source, so leave them alone.
+        if (baseIdx == insertBaseIdx)
+          continue;
+        posToExtractOps[baseIdx].push_back(extractOp);
+      }
+    }
+    // Walk up the chain of insertions and try to resolve the remaining
+    // extractions that access the same member.
+    Value nextContainer = insertOp.getContainer();
+    while (!posToExtractOps.empty()) {
+      auto curInsert = nextContainer.getDefiningOp<InsertValueOp>();
+      // Stop once the chain reaches a value not produced by an insertvalue.
+      if (!curInsert)
+        break;
+      nextContainer = curInsert.getContainer();
+
+      // Check if any extractions read the member written by this insertion.
+      int64_t curInsertBaseIdx = curInsert.getPosition()[0];
+      auto it = posToExtractOps.find(curInsertBaseIdx);
+      if (it == posToExtractOps.end())
+        continue;
+
+      // Update the ExtractValueOps to read from the current insertion.
+      for (ExtractValueOp extractOp : it->second) {
+        rewriter.modifyOpInPlace(extractOp, [&] {
+          extractOp.getContainerMutable().assign(curInsert);
+        });
+      }
+      // Entries are only created by push_back, so a present entry is never
+      // empty and at least one op was just rewritten.
+      assert(!it->second.empty());
+      changed = true;
+      posToExtractOps.erase(it);
+    }
+    // No insertion along the chain wrote the member these extracts read, so
+    // redirect them to the value at the root of the chain.
+    for (auto &[baseIdx, extracts] : posToExtractOps) {
+      for (ExtractValueOp extractOp : extracts) {
+        rewriter.modifyOpInPlace(extractOp, [&] {
+          extractOp.getContainerMutable().assign(nextContainer);
+        });
+      }
+      assert(!extracts.empty() && "Empty list in map");
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+} // namespace
+
+void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *ctx) {
+  results.add<ResolveExtractValueSource>(ctx);
+}
+
 /// Infer the value type from the container type and position.
 static ParseResult
 parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,



More information about the Mlir-commits mailing list