[Mlir-commits] [mlir] [mlir][transform] Fix handling of transitive includes in interpreter. (PR #67241)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 23 08:25:07 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may themselves include other sequences. If such sequences were not *also* declared in the main transform module, the interpreter would fail to resolve them. Forward-declaring all of them is undesirable, as it defeats the purpose of encapsulating them in library modules.

This PR extends the loading of missing symbols as follows: `defineDeclaredSymbols` not only inserts the definitions that are forward-declared in the main module, but also scans each inserted definition for further dependencies, which are then processed in the same way as the forward declarations from the main module.

---
Full diff: https://github.com/llvm/llvm-project/pull/67241.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+61-8) 
- (added) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir (+25) 
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir (+5) 


``````````diff
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..aa2b1157c254b47 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -311,6 +311,9 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
   auto readOnlyName =
       StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
 
+  // Collect symbols missing in the block.
+  SmallVector<SymbolOpInterface> missingSymbols;
+  LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n");
   for (Operation &op : llvm::make_early_inc_range(block)) {
     LLVM_DEBUG(DBGS() << op << "\n");
     auto symbol = dyn_cast<SymbolOpInterface>(op);
@@ -318,25 +321,33 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       continue;
     if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
       continue;
+    LLVM_DEBUG(DBGS() << "  -> symbol missing\n");
+    missingSymbols.push_back(symbol);
+  }
 
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+  // Resolve missing symbols until they are all resolved.
+  while (!missingSymbols.empty()) {
+    SymbolOpInterface symbol = missingSymbols.pop_back_val();
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol @"
+                      << symbol.getNameAttr().getValue() << ": ");
+    SymbolTable definitionsSymbolTable(definitions);
+    Operation *externalSymbol =
+        definitionsSymbolTable.lookup(symbol.getNameAttr());
     if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
         externalSymbol->getRegion(0).empty()) {
       LLVM_DEBUG(llvm::dbgs() << "not found\n");
       continue;
     }
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(symbol.getOperation());
     auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
     if (!symbolFunc || !externalSymbolFunc) {
       LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
       continue;
     }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from "
+                            << externalSymbol->getLoc() << "\n");
     if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
       return symbolFunc.emitError()
              << "external definition has a mismatching signature ("
@@ -367,10 +378,52 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       }
     }
 
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
+    OpBuilder builder(symbol);
+    builder.setInsertionPoint(symbol);
     builder.clone(*externalSymbol);
     symbol->erase();
+
+    LLVM_DEBUG(DBGS() << "scanning definition of @"
+                      << externalSymbolFunc.getNameAttr().getValue()
+                      << " for symbol usages\n");
+    externalSymbolFunc.walk([&](CallOpInterface callOp) {
+      LLVM_DEBUG(DBGS() << "  found symbol usage in:\n" << callOp << "\n");
+      CallInterfaceCallable callable = callOp.getCallableForCallee();
+      if (!isa<SymbolRefAttr>(callable)) {
+        LLVM_DEBUG(DBGS() << "    not a 'SymbolRefAttr'\n");
+        return WalkResult::advance();
+      }
+
+      StringRef callableSymbol =
+          cast<SymbolRefAttr>(callable).getLeafReference();
+      LLVM_DEBUG(DBGS() << "    looking for @" << callableSymbol
+                        << " in definitions: ");
+
+      Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol);
+      if (!isa<SymbolRefAttr>(callable)) {
+        LLVM_DEBUG(llvm::dbgs() << "not found\n");
+        return WalkResult::advance();
+      }
+      LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from "
+                              << callableOp->getLoc() << "\n");
+
+      if (!block.getParent() || !block.getParent()->getParentOp()) {
+        LLVM_DEBUG(DBGS() << "could not get parent of provided block");
+        return WalkResult::advance();
+      }
+
+      SymbolTable targetSymbolTable(block.getParent()->getParentOp());
+      if (targetSymbolTable.lookup(callableSymbol)) {
+        LLVM_DEBUG(DBGS() << "    symbol @" << callableSymbol
+                          << " already present in target\n");
+        return WalkResult::advance();
+      }
+
+      LLVM_DEBUG(DBGS() << "    cloning op into target\n");
+      builder.clone(*callableOp);
+
+      return WalkResult::advance();
+    });
   }
 
   return success();
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
new file mode 100644
index 000000000000000..3a122ce2f77c3a8
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option. Repeated application of the
+// same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @bar(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @bar failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 1149bda98ab8527..9aa2d46d5abb995 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,6 +1,11 @@
 // RUN: mlir-opt %s
 
 module attributes {transform.with_named_sequence} {
+  transform.named_sequence @bar(%arg0: !transform.any_op) {
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+
   transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
     transform.yield

``````````

</details>


https://github.com/llvm/llvm-project/pull/67241


More information about the Mlir-commits mailing list