[Mlir-commits] [mlir] [mlir][transform] Fix handling of transitive include in interpreter. (PR #67560)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 11 22:52:34 PDT 2023


Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>
Message-ID:
In-Reply-To: <llvm/llvm-project/pull/67560/mlir at github.com>


================
@@ -302,80 +304,268 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
-  MLIRContext &ctx = *definitions->getContext();
-  auto consumedName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
-  auto readOnlyName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    LLVM_DEBUG(DBGS() << op << "\n");
-    auto symbol = dyn_cast<SymbolOpInterface>(op);
-    if (!symbol)
-      continue;
-    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
-      continue;
-
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
-    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
-        externalSymbol->getRegion(0).empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "not found\n");
-      continue;
+/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
+/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
+/// `uniqueId` is used to generate a unique name in the context of the caller.
+LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp,
+                             SymbolTable &symbolTable,
+                             SymbolTable &otherSymbolTable, int &uniqueId) {
+  assert(symbolTable.lookup(op.getNameAttr()) == op &&
+         "symbol table does not contain op");
+  assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
+         "other symbol table does not contain other op");
+
+  // Determine new name that is unique in both symbol tables.
+  StringAttr oldName = op.getNameAttr();
+  StringAttr newName;
+  {
+    MLIRContext *context = op->getContext();
+    SmallString<64> prefix = oldName.getValue();
+    prefix.push_back('_');
+    while (true) {
+      newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+      if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
+        break;
+      }
     }
+  }
+
+  // Apply renaming.
+  LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
+  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
+    InFlightDiagnostic diag =
+        emitError(op->getLoc(),
+                  Twine("failed to rename symbol to @") + newName.getValue());
+    diag.attachNote(otherOp->getLoc())
+        << "attempted renaming due to collision with this op";
+    return diag;
+  }
+
+  // Change the symbol in the op itself and update the symbol table.
+  symbolTable.remove(op);
+  SymbolTable::setSymbolName(op, newName);
+  symbolTable.insert(op);
+
+  assert(symbolTable.lookup(newName) == op &&
+         "symbol table does not resolve to renamed op");
+  assert(symbolTable.lookup(oldName) == nullptr &&
+         "symbol table still resolves old name");
+
+  return success();
+}
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
-    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
-    if (!symbolFunc || !externalSymbolFunc) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+/// Return whether `func1` can be merged into `func2`.
+bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
----------------
qcolombet wrote:

Kind of late to the party, but I agree that the meaning is difficult to grasp currently.
More specifically, I believe the comment doesn't describe what "mergeable" means.

For me, the confusing part is that "merging" implies that the result is bigger. So, before reading the code, I was wondering if we were doing some sort of function concatenation :).
<digression>LLVM has an optimization pass called "merge globals" that puts global variables in the same global structure to allow faster symbol resolution, and that's the kind of thing this function name brings to mind for me.</digression>

Maybe `canCoalesceSymbols`?

https://github.com/llvm/llvm-project/pull/67560


More information about the Mlir-commits mailing list