[Mlir-commits] [mlir] [mlir][transform] Fix handling of transitive include in interpreter. (PR #67560)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Oct 2 10:46:45 PDT 2023


Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>
Message-ID:
In-Reply-To: <llvm/llvm-project/pull/67560/mlir at github.com>


================
@@ -302,80 +304,268 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
-  MLIRContext &ctx = *definitions->getContext();
-  auto consumedName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
-  auto readOnlyName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    LLVM_DEBUG(DBGS() << op << "\n");
-    auto symbol = dyn_cast<SymbolOpInterface>(op);
-    if (!symbol)
-      continue;
-    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
-      continue;
-
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
-    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
-        externalSymbol->getRegion(0).empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "not found\n");
-      continue;
+/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
+/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
+/// `uniqueId` is used to generate a unique name in the context of the caller.
+LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp,
+                             SymbolTable &symbolTable,
+                             SymbolTable &otherSymbolTable, int &uniqueId) {
+  assert(symbolTable.lookup(op.getNameAttr()) == op &&
+         "symbol table does not contain op");
+  assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
+         "other symbol table does not contain other op");
+
+  // Determine new name that is unique in both symbol tables.
+  StringAttr oldName = op.getNameAttr();
+  StringAttr newName;
+  {
+    MLIRContext *context = op->getContext();
+    SmallString<64> prefix = oldName.getValue();
+    prefix.push_back('_');
+    while (true) {
+      newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+      if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
+        break;
+      }
     }
+  }
+
+  // Apply renaming.
+  LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
+  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
+    InFlightDiagnostic diag =
+        emitError(op->getLoc(),
+                  Twine("failed to rename symbol to @") + newName.getValue());
+    diag.attachNote(otherOp->getLoc())
+        << "attempted renaming due to collision with this op";
+    return diag;
+  }
+
+  // Change the symbol in the op itself and update the symbol table.
+  symbolTable.remove(op);
+  SymbolTable::setSymbolName(op, newName);
+  symbolTable.insert(op);
+
+  assert(symbolTable.lookup(newName) == op &&
+         "symbol table does not resolve to renamed op");
+  assert(symbolTable.lookup(oldName) == nullptr &&
+         "symbol table still resolves old name");
+
+  return success();
+}
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
-    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
-    if (!symbolFunc || !externalSymbolFunc) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+/// Return whether `func1` can be merged into `func2`.
+bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergeable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  assert(canMergeInto(func1, func2));
+  assert(func1->getParentOp() == func2->getParentOp() &&
+         "expected func1 and func2 to be in the same parent op");
+
+  MLIRContext *context = func1->getContext();
+  auto consumedName = StringAttr::get(
+      context, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName = StringAttr::get(
+      context, transform::TransformDialect::kArgReadOnlyAttrName);
+
+  // Check that function signatures match.
+  if (func1.getFunctionType() != func2.getFunctionType()) {
+    return func1.emitError()
+           << "external definition has a mismatching signature ("
+           << func2.getFunctionType() << ")";
+  }
+
+  // Check and merge argument attributes.
+  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+    if (!isExternalConsumed && !isExternalReadonly) {
+      if (isConsumed)
+        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+      else if (isReadonly)
+        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
       continue;
     }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
-    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
-      return symbolFunc.emitError()
-             << "external definition has a mismatching signature ("
-             << externalSymbolFunc.getFunctionType() << ")";
+    if ((isExternalConsumed && !isConsumed) ||
+        (isExternalReadonly && !isReadonly)) {
+      return func1.emitError()
+             << "external definition has mismatching consumption "
+                "annotations for argument #"
+             << i;
     }
+  }
+
+  // `func1` is the external one, so we can remove it.
+  assert(func1.isExternal());
+  func1->erase();
 
-    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
-      bool isExternalConsumed =
-          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isExternalReadonly =
-          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      if (!isExternalConsumed && !isExternalReadonly) {
-        if (isConsumed)
-          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
-        else if (isReadonly)
-          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+  return success();
+}
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `SymbolOpInterface`, where collisions are allowed if at least
+/// one of the two is external, in which case the other op is preserved (or
+/// any one of the two if both are external).
+static LogicalResult mergeSymbolsInto(Operation *target,
+                                      OwningOpRef<Operation *> other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
+
+  int uniqueId = 0;
+
+  // Step 1:
+  //
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  for (auto [symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *>{&otherSymbolTable, &targetSymbolTable})) {
----------------
ftynse wrote:

This is one of the rare places where it makes sense to explicitly specify the number of inline (stack-allocated) elements of the `SmallVector`.

https://github.com/llvm/llvm-project/pull/67560


More information about the Mlir-commits mailing list