[Mlir-commits] [mlir] [mlir][transform] Fix handling of transitive include in interpreter. (PR #67560)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Oct 2 10:46:47 PDT 2023
Ingo Müller <ingomueller at google.com>,
Ingo Müller <ingomueller at google.com>
Message-ID:
In-Reply-To: <llvm/llvm-project/pull/67560/mlir at github.com>
================
@@ -302,80 +304,268 @@ static void performOptionalDebugActions(
transform->removeAttr(kTransformDialectTagAttrName);
}
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
- MLIRContext &ctx = *definitions->getContext();
- auto consumedName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
- auto readOnlyName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
- for (Operation &op : llvm::make_early_inc_range(block)) {
- LLVM_DEBUG(DBGS() << op << "\n");
- auto symbol = dyn_cast<SymbolOpInterface>(op);
- if (!symbol)
- continue;
- if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
- continue;
-
- LLVM_DEBUG(DBGS() << "looking for definition of symbol "
- << symbol.getNameAttr() << ":");
- SymbolTable symbolTable(definitions);
- Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
- if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
- externalSymbol->getRegion(0).empty()) {
- LLVM_DEBUG(llvm::dbgs() << "not found\n");
- continue;
+/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
+/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
+/// The new name has the form `<old name>_<N>`, where `N` is drawn from the
+/// caller-owned counter `uniqueId`; the counter is incremented until the
+/// candidate name is unused in both symbol tables. All uses of the symbol in
+/// the nearest symbol table enclosing `op` are updated, and `symbolTable` is
+/// kept in sync with the renamed op. Emits an error (with a note pointing at
+/// `otherOp`) and returns failure if the uses cannot be updated.
+static LogicalResult renameToUnique(SymbolOpInterface op,
+ SymbolOpInterface otherOp,
+ SymbolTable &symbolTable,
+ SymbolTable &otherSymbolTable, int &uniqueId) {
+ assert(symbolTable.lookup(op.getNameAttr()) == op &&
+ "symbol table does not contain op");
+ assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
+ "other symbol table does not contain other op");
+
+ // Determine new name that is unique in both symbol tables.
+ StringAttr oldName = op.getNameAttr();
+ StringAttr newName;
+ {
+ MLIRContext *context = op->getContext();
+ SmallString<64> prefix = oldName.getValue();
+ prefix.push_back('_');
+ // Keep bumping the shared counter until the candidate is free in both
+ // tables; the counter persists across calls so later renames start higher.
+ while (true) {
+ newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+ if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
+ break;
+ }
}
+ }
+
+ // Apply renaming.
+ LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
+ Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+ if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
+ InFlightDiagnostic diag =
+ emitError(op->getLoc(),
+ Twine("failed to rename symbol to @") + newName.getValue());
+ diag.attachNote(otherOp->getLoc())
+ << "attempted renaming due to collision with this op";
+ return diag;
+ }
+
+ // Change the symbol in the op itself and update the symbol table.
+ symbolTable.remove(op);
+ SymbolTable::setSymbolName(op, newName);
+ symbolTable.insert(op);
+
+ assert(symbolTable.lookup(newName) == op &&
+ "symbol table does not resolve to renamed op");
+ assert(symbolTable.lookup(oldName) == nullptr &&
+ "symbol table still resolves old name");
+
+ return success();
+}
- auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
- auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
- if (!symbolFunc || !externalSymbolFunc) {
- LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+/// Return whether `func1` can be merged into `func2`. This requires `func1`
+/// to be a declaration (i.e., external) and `func2` to be either public or a
+/// declaration itself.
+static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+ return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergeable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+///
+/// The function types of the two ops must be identical. For each argument, if
+/// `func2` carries neither a consumed nor a readonly annotation, the
+/// annotation of `func1` (if any) is copied to `func2`; if `func2` carries an
+/// annotation that `func1` lacks, an error is emitted on `func1`.
+/// NOTE(review): on such an error, annotations already copied to `func2` for
+/// earlier arguments are not rolled back, and `func1` is not erased.
+static LogicalResult mergeInto(FunctionOpInterface func1,
+ FunctionOpInterface func2) {
+ assert(canMergeInto(func1, func2) && "expected mergeable functions");
+ assert(func1->getParentOp() == func2->getParentOp() &&
+ "expected func1 and func2 to be in the same parent op");
+
+ MLIRContext *context = func1->getContext();
+ auto consumedName = StringAttr::get(
+ context, transform::TransformDialect::kArgConsumedAttrName);
+ auto readOnlyName = StringAttr::get(
+ context, transform::TransformDialect::kArgReadOnlyAttrName);
+
+ // Check that function signatures match.
+ if (func1.getFunctionType() != func2.getFunctionType()) {
+ return func1.emitError()
+ << "external definition has a mismatching signature ("
+ << func2.getFunctionType() << ")";
+ }
+
+ // Check and merge argument attributes.
+ for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+ bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+ bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+ bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+ bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+ // `func2` is unannotated for this argument: inherit `func1`'s annotation.
+ if (!isExternalConsumed && !isExternalReadonly) {
+ if (isConsumed)
+ func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+ else if (isReadonly)
+ func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
continue;
}
- LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
- if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
- return symbolFunc.emitError()
- << "external definition has a mismatching signature ("
- << externalSymbolFunc.getFunctionType() << ")";
+ // `func2` is annotated: `func1` must carry at least the same annotations.
+ if ((isExternalConsumed && !isConsumed) ||
+ (isExternalReadonly && !isReadonly)) {
+ return func1.emitError()
+ << "external definition has mismatching consumption "
+ "annotations for argument #"
+ << i;
}
+ }
+
+ // `func1` is the external one, so we can remove it.
+ assert(func1.isExternal());
+ func1->erase();
- for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
- bool isExternalConsumed =
- externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isExternalReadonly =
- externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- if (!isExternalConsumed && !isExternalReadonly) {
- if (isConsumed)
- externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
- else if (isReadonly)
- externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+ return success();
+}
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `SymbolOpInterface`, where collisions are allowed if at least
+/// one of the two is external, in which case the other op preserved (or any one
+/// of the two if both are external).
+static LogicalResult mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other) {
+ assert(target->hasTrait<OpTrait::SymbolTable>() &&
+ "requires target to implement the 'SymbolTable' trait");
+ assert(other->hasTrait<OpTrait::SymbolTable>() &&
+ "requires target to implement the 'SymbolTable' trait");
+
+ SymbolTable targetSymbolTable(target);
+ SymbolTable otherSymbolTable(*other);
+
+ int uniqueId = 0;
+
+ // Step 1:
+ //
+ // Rename private symbols in both ops in order to resolve conflicts that can
+ // be resolved that way.
+ LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+ for (auto [symbolTable, otherSymbolTable] : llvm::zip(
+ SmallVector<SymbolTable *>{&targetSymbolTable, &otherSymbolTable},
+ SmallVector<SymbolTable *>{&otherSymbolTable, &targetSymbolTable})) {
+ Operation *symbolTableOp = symbolTable->getOp();
+ for (Operation &op : symbolTableOp->getRegion(0).front()) {
+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+ if (!symbolOp)
+ continue;
+ StringAttr name = symbolOp.getNameAttr();
+ LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
+
+ // Check if there is a colliding op in the other module.
+ auto collidingOp =
+ cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+ if (!collidingOp)
continue;
+
+ LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+
+ // Collisions are fine if both opt are functions and can be merged.
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+ collidingFuncOp =
+ dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+ funcOp && collidingFuncOp) {
+ if (canMergeInto(funcOp, collidingFuncOp) ||
+ canMergeInto(collidingFuncOp, funcOp)) {
+ LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+ "will be merged\n");
+ continue;
+ }
+
+ // If they can't be merged, proceed like any other collision.
+ LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
}
- if ((isExternalConsumed && !isConsumed) ||
- (isExternalReadonly && !isReadonly)) {
- return symbolFunc.emitError()
- << "external definition has mismatching consumption annotations "
- "for argument #"
- << i;
+ // Collision can be resolved if one of the ops is private.
+ if (symbolOp.isPrivate()) {
+ if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+ *otherSymbolTable, uniqueId)))
+ return failure();
+ continue;
+ }
+ if (collidingOp.isPrivate()) {
+ if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+ *symbolTable, uniqueId)))
+ return failure();
+ continue;
}
+
+ LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+ InFlightDiagnostic diag =
+ emitError(symbolOp->getLoc(),
+ Twine("doubly defined symbol @") + name.getValue());
+ diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+ return diag;
}
+ }
- OpBuilder builder(&op);
- builder.setInsertionPoint(&op);
- builder.clone(*externalSymbol);
- symbol->erase();
+ for (auto *op : SmallVector<Operation *>{target, *other}) {
+ if (failed(mlir::verify(op)))
+ return emitError(op->getLoc(),
+ "failed to verify input op after renaming");
----------------
ftynse wrote:
I don't think we should be running local verification here. Has it caught any errors?
Maybe we should consider making "merge" a separate pass that runs before the interpreter, so that the verifier has an opportunity to run in between (and can be disabled in cases where compile time matters).
https://github.com/llvm/llvm-project/pull/67560
More information about the Mlir-commits
mailing list