[Mlir-commits] [mlir] [mlir][transform] Fix handling of transitive includes in interpreter. (PR #67241)

Ingo Müller llvmlistbot at llvm.org
Mon Sep 25 01:03:54 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/67241

>From 2b7d6b4d4eaf5fe7159f93196e17423771d21fb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 15:10:30 +0000
Subject: [PATCH] [mlir][transform] Fix handling of transitive includes in
 interpreter.

Until now, the interpreter would only load those symbols from the
provided library files that were declared in the main transform module.
However, sequences in the library may in turn include other sequences.
If such sequences were not *also* declared in the main transform module,
the interpreter would fail to resolve them. Forward-declaring all of
them is undesirable as it defeats the purpose of encapsulation into
library modules.

This PR extends the loading of missing symbols as follows: in
`defineDeclaredSymbols`, not only are the definitions of symbols that are
forward-declared in the main module inserted, but each inserted definition
is also scanned for further dependencies, which are then processed in the
same way as the forward declarations from the main module.
---
 .../TransformInterpreterPassBase.cpp          | 72 ++++++++++++++++---
 ...reter-external-symbol-decl-transitive.mlir | 27 +++++++
 .../test-interpreter-external-symbol-def.mlir |  5 ++
 3 files changed, 95 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..3c993e417b67efe 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -311,6 +311,9 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
   auto readOnlyName =
       StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
 
+  // Collect symbols missing in the block.
+  SmallVector<SymbolOpInterface> missingSymbols;
+  LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n");
   for (Operation &op : llvm::make_early_inc_range(block)) {
     LLVM_DEBUG(DBGS() << op << "\n");
     auto symbol = dyn_cast<SymbolOpInterface>(op);
@@ -318,25 +321,33 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       continue;
     if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
       continue;
+    LLVM_DEBUG(DBGS() << "  -> symbol missing\n");
+    missingSymbols.push_back(symbol);
+  }
 
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+  // Resolve missing symbols from the library until the worklist is empty.
+  SymbolTable definitionsSymbolTable(definitions);
+  while (!missingSymbols.empty()) {
+    SymbolOpInterface symbol = missingSymbols.pop_back_val();
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol @"
+                      << symbol.getNameAttr().getValue() << ": ");
+    Operation *externalSymbol =
+        definitionsSymbolTable.lookup(symbol.getNameAttr());
     if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
         externalSymbol->getRegion(0).empty()) {
       LLVM_DEBUG(llvm::dbgs() << "not found\n");
       continue;
     }
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(symbol.getOperation());
     auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
     if (!symbolFunc || !externalSymbolFunc) {
       LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
       continue;
     }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from "
+                            << externalSymbol->getLoc() << "\n");
     if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
       return symbolFunc.emitError()
              << "external definition has a mismatching signature ("
@@ -367,10 +378,53 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       }
     }
 
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
-    builder.clone(*externalSymbol);
+    OpBuilder builder(symbol);
+    builder.setInsertionPoint(symbol);
+    Operation *newSymbol = builder.clone(*externalSymbol);
+    builder.setInsertionPoint(newSymbol);
     symbol->erase();
+
+    LLVM_DEBUG(DBGS() << "scanning definition of @"
+                      << externalSymbolFunc.getNameAttr().getValue()
+                      << " for symbol usages\n");
+    externalSymbolFunc.walk([&](CallOpInterface callOp) {
+      LLVM_DEBUG(DBGS() << "  found symbol usage in:\n" << callOp << "\n");
+      CallInterfaceCallable callable = callOp.getCallableForCallee();
+      if (!isa<SymbolRefAttr>(callable)) {
+        LLVM_DEBUG(DBGS() << "    not a 'SymbolRefAttr'\n");
+        return WalkResult::advance();
+      }
+
+      StringRef callableSymbol =
+          cast<SymbolRefAttr>(callable).getLeafReference();
+      LLVM_DEBUG(DBGS() << "    looking for @" << callableSymbol
+                        << " in definitions: ");
+
+      Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol);
+      if (!callableOp) {
+        LLVM_DEBUG(llvm::dbgs() << "not found\n");
+        return WalkResult::advance();
+      }
+      LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from "
+                              << callableOp->getLoc() << "\n");
+
+      if (!block.getParent() || !block.getParent()->getParentOp()) {
+        LLVM_DEBUG(DBGS() << "could not get parent of provided block\n");
+        return WalkResult::advance();
+      }
+
+      SymbolTable targetSymbolTable(block.getParent()->getParentOp());
+      if (targetSymbolTable.lookup(callableSymbol)) {
+        LLVM_DEBUG(DBGS() << "    symbol @" << callableSymbol
+                          << " already present in target\n");
+        return WalkResult::advance();
+      }
+
+      // Declare the callee so the worklist resolves and scans it in turn.
+      missingSymbols.push_back(
+          cast<SymbolOpInterface>(builder.cloneWithoutRegions(*callableOp)));
+      return WalkResult::advance();
+    });
   }
 
   return success();
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
new file mode 100644
index 000000000000000..0e9fa7c59bc4155
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @bar named sequence is provided in another file and is
+// loaded through the `transform-library-file-name` pass option. That sequence
+// includes another named sequence, @foo, which must be made available here as
+// well, even though it is not declared in this file. Repeated application of
+// the pass, with or without the library option, should not be a problem. A
+// diagnostic emitted twice at the same location only needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+  // CHECK-DAG: transform.named_sequence @foo
+  // CHECK-DAG: transform.named_sequence @bar
+  transform.named_sequence private @bar(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @bar failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 1149bda98ab8527..9aa2d46d5abb995 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,6 +1,11 @@
 // RUN: mlir-opt %s
 
 module attributes {transform.with_named_sequence} {
+  // @bar includes @foo, exercising transitive loading of library symbols.
+  transform.named_sequence @bar(%arg0: !transform.any_op) {
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
   transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
     transform.yield



More information about the Mlir-commits mailing list