[Mlir-commits] [mlir] [mlir][transform] Fix handling of transitive include in interpreter. (PR #67560)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 29 03:10:45 PDT 2023
Ingo Müller <ingomueller at google.com>
Message-ID:
In-Reply-To: <llvm/llvm-project/pull/67560/mlir at github.com>
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This is a new attempt at fixing transitive includes in the transform dialect interpreter (next to #67241) and a preparation for being able to load multiple transform library files (#67120).
Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may themselves include other sequences. If such sequences were not *also* declared in the main transform module, the interpreter would thus fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules.
This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of `FunctionOpInterface`) with compatible signatures, one is external, and the other one is public (or also external), then they are merged; (2) if one of them is private, that one is renamed; and (3) an error is raised otherwise.
TODO:
- [x] Enable running interpreter several times (see discussion below).
- [x] Fix several smaller issues marked with `XXX`.
---
Patch is 27.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67560.diff
6 Files Affected:
- (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+252-60)
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir (+6)
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir (+19-4)
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir (+49-12)
- (added) mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir (+14)
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir (+42-1)
``````````diff
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..3dd172335160fb3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -50,6 +50,8 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
constexpr static llvm::StringLiteral
kTransformDialectTagTransformContainerValue = "transform_container";
+namespace {
+
/// Utility to parse the content of a `transformFileName` MLIR file containing
/// a transform dialect specification.
static LogicalResult
@@ -302,80 +304,270 @@ static void performOptionalDebugActions(
transform->removeAttr(kTransformDialectTagAttrName);
}
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
- MLIRContext &ctx = *definitions->getContext();
- auto consumedName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
- auto readOnlyName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
- for (Operation &op : llvm::make_early_inc_range(block)) {
- LLVM_DEBUG(DBGS() << op << "\n");
- auto symbol = dyn_cast<SymbolOpInterface>(op);
- if (!symbol)
- continue;
- if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
- continue;
-
- LLVM_DEBUG(DBGS() << "looking for definition of symbol "
- << symbol.getNameAttr() << ":");
- SymbolTable symbolTable(definitions);
- Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
- if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
- externalSymbol->getRegion(0).empty()) {
- LLVM_DEBUG(llvm::dbgs() << "not found\n");
- continue;
+/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
+/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
+/// `uniqueId` is used to generate a unique name in the context of the caller.
+LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp,
+ SymbolTable &symbolTable,
+ SymbolTable &otherSymbolTable, int &uniqueId) {
+ assert(symbolTable.lookup(op.getNameAttr()) == op &&
+ "symbol table does not contain op");
+ assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
+ "other symbol table does not contain other op");
+
+ // Determine new name that is unique in both symbol tables.
+ StringAttr oldName = op.getNameAttr();
+ StringAttr newName;
+ {
+ MLIRContext *context = op->getContext();
+ SmallString<64> prefix = oldName.getValue();
+ prefix.push_back('_');
+ while (true) {
+ newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+ if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
+ break;
+ }
}
+ }
+
+ // Apply renaming.
+ LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
+ Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+ if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
+ InFlightDiagnostic diag =
+ emitError(op->getLoc(),
+ Twine("failed to rename symbol to @") + newName.getValue());
+ diag.attachNote(otherOp->getLoc())
+ << "attempted renaming due to collision with this op";
+ return diag;
+ }
+
+ // Change the symbol in the op itself and update the symbol table.
+ symbolTable.remove(op);
+ SymbolTable::setSymbolName(op, newName);
+ symbolTable.insert(op);
+
+ assert(symbolTable.lookup(newName) == op &&
+ "symbol table does not resolve to renamed op");
+ assert(symbolTable.lookup(oldName) == nullptr &&
+ "symbol table still resolves old name");
+
+ return success();
+}
- auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
- auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
- if (!symbolFunc || !externalSymbolFunc) {
- LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+/// Return whether `func1` can be merged into `func2`.
+bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+ return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergeable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+ assert(canMergeInto(func1, func2));
+ assert(func1->getParentOp() == func2->getParentOp() &&
+ "expected func1 and func2 to be in the same parent op");
+
+ MLIRContext *context = func1->getContext();
+ auto consumedName = StringAttr::get(
+ context, transform::TransformDialect::kArgConsumedAttrName);
+ auto readOnlyName = StringAttr::get(
+ context, transform::TransformDialect::kArgReadOnlyAttrName);
+
+ // Check that function signatures match.
+ if (func1.getFunctionType() != func2.getFunctionType()) {
+ return func1.emitError()
+ << "external definition has a mismatching signature ("
+ << func2.getFunctionType() << ")";
+ }
+
+ // Check and merge argument attributes.
+ for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+ bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+ bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+ bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+ bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+ if (!isExternalConsumed && !isExternalReadonly) {
+ if (isConsumed)
+ func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+ else if (isReadonly)
+ func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
continue;
}
- LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
- if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
- return symbolFunc.emitError()
- << "external definition has a mismatching signature ("
- << externalSymbolFunc.getFunctionType() << ")";
+ if ((isExternalConsumed && !isConsumed) ||
+ (isExternalReadonly && !isReadonly)) {
+ return func1.emitError()
+ << "external definition has mismatching consumption "
+ "annotations for argument #"
+ << i;
}
+ }
+
+ // `func1` is the external one, so we can remove it.
+ assert(func1.isExternal());
+ func1->erase();
+
+ return success();
+}
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `FunctionOpInterface`, where collisions are allowed if one of
+/// the two is external and the other is public or also external, in which case
+/// the other op is preserved (or either one if both are external).
+static LogicalResult mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other) {
+ assert(target->hasTrait<OpTrait::SymbolTable>() &&
+ "requires target to implement the 'SymbolTable' trait");
+ assert(other->hasTrait<OpTrait::SymbolTable>() &&
+ "requires target to implement the 'SymbolTable' trait");
+
+ SymbolTable targetSymbolTable(target);
+ SymbolTable otherSymbolTable(*other);
+
+ int uniqueId = 0;
+
+ // Step 1:
+ //
+ // Rename private symbols in both ops in order to resolve conflicts that can
+ // be resolved that way.
+ LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+ for (auto [symbolTable, otherSymbolTable] : llvm::zip(
+ SmallVector<SymbolTable *>{&targetSymbolTable, &otherSymbolTable},
+ SmallVector<SymbolTable *>{&otherSymbolTable, &targetSymbolTable})) {
+ Operation *symbolTableOp = symbolTable->getOp();
+ for (Operation &op : symbolTableOp->getRegion(0).front()) {
+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+ if (!symbolOp)
+ continue;
+ StringAttr name = symbolOp.getNameAttr();
+ LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
- for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
- bool isExternalConsumed =
- externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isExternalReadonly =
- externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- if (!isExternalConsumed && !isExternalReadonly) {
- if (isConsumed)
- externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
- else if (isReadonly)
- externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+ // Check if there is a colliding op in the other module.
+ auto collidingOp =
+ cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+ if (!collidingOp)
continue;
+
+ LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+
+ // Collisions are fine if both ops are functions and can be merged.
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+ collidingFuncOp =
+ dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+ funcOp && collidingFuncOp) {
+ if (canMergeInto(funcOp, collidingFuncOp) ||
+ canMergeInto(collidingFuncOp, funcOp)) {
+ LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+ "will be merged\n");
+ continue;
+ }
+
+ // If they can't be merged, proceed like any other collision.
+ LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
}
- if ((isExternalConsumed && !isConsumed) ||
- (isExternalReadonly && !isReadonly)) {
- return symbolFunc.emitError()
- << "external definition has mismatching consumption annotations "
- "for argument #"
- << i;
+ // Collision can be resolved if one of the ops is private.
+ if (symbolOp.isPrivate()) {
+ if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+ *otherSymbolTable, uniqueId)))
+ return failure();
+ continue;
+ }
+ if (collidingOp.isPrivate()) {
+ if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+ *symbolTable, uniqueId)))
+ return failure();
+ continue;
}
+
+ LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+ InFlightDiagnostic diag =
+ emitError(symbolOp->getLoc(),
+ Twine("doubly defined symbol @") + name.getValue());
+ diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+ return diag;
}
+ }
- OpBuilder builder(&op);
- builder.setInsertionPoint(&op);
- builder.clone(*externalSymbol);
- symbol->erase();
+ for (auto *op : SmallVector<Operation *>{target, *other}) {
+ if (failed(mlir::verify(op)))
+ return emitError(op->getLoc(),
+ "failed to verify input op after renaming");
}
+ // Step 2:
+ //
+ // Move all ops from `other` into target and merge public symbols.
+ LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+ {
+ SmallVector<SymbolOpInterface> opsToMove;
+ for (Operation &op : other->getRegion(0).front()) {
+ if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+ opsToMove.push_back(symbol);
+ }
+
+ for (SymbolOpInterface op : opsToMove) {
+ // Remember potentially colliding op in the target module.
+ auto collidingOp = cast_or_null<SymbolOpInterface>(
+ targetSymbolTable.lookup(op.getNameAttr()));
+
+ // Move op even if we get a collision.
+ LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+ op->moveAfter(&target->getRegion(0).front(),
+ target->getRegion(0).front().begin());
+
+ // If there is no collision, we are done.
+ if (!collidingOp) {
+ LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+ continue;
+ }
+
+ // The two colliding ops must both be functions; any other collision was
+ // either resolved by renaming or reported as an error in step 1.
+ auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+ auto collidingFuncOp =
+ cast<FunctionOpInterface>(collidingOp.getOperation());
+
+ // Both ops are in the target module now and can be treated symmetrically,
+ // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+ if (!canMergeInto(funcOp, collidingFuncOp)) {
+ std::swap(funcOp, collidingFuncOp);
+ }
+ assert(canMergeInto(funcOp, collidingFuncOp));
+
+ LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+ << collidingFuncOp.getLoc() << ":\n"
+ << collidingFuncOp << "\n");
+
+ // Update symbol table. This works with or without the previous `swap`.
+ targetSymbolTable.remove(funcOp);
+ targetSymbolTable.insert(collidingFuncOp);
+ assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+ // Do the actual merging.
+ if (failed(mergeInto(funcOp, collidingFuncOp))) {
+ return failure();
+ }
+
+ assert(succeeded(mlir::verify(target)));
+ }
+ }
+
+ if (failed(mlir::verify(target)))
+ return emitError(target->getLoc(),
+ "failed to verify target op after merging symbols");
+
+ LLVM_DEBUG(DBGS() << "done merging ops\n");
return success();
}
+} // namespace
+
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *target, StringRef passName,
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -438,8 +630,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
- if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
- libraryModule->get())))
+ if (failed(
+ mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
+ libraryModule->get()->clone())))
return failure();
}
@@ -499,8 +692,7 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
return success();
if (module && *module) {
- if (failed(defineDeclaredSymbols(*module->get().getBody(),
- parsedLibrary.get())))
+ if (failed(mergeSymbolsInto(module->get(), std::move(parsedLibrary))))
return failure();
} else {
libraryModule =
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index 3d4cb0776982934..dd8d141e994da0e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -11,4 +11,10 @@
// expected-remark @below {{message}}
// expected-remark @below {{unannotated}}
+// expected-remark @below {{internal colliding (without suffix)}}
+// expected-remark @below {{internal colliding_0}}
+// expected-remark @below {{internal colliding_1}}
+// expected-remark @below {{internal colliding_3}}
+// expected-remark @below {{internal colliding_4}}
+// expected-remark @below {{internal colliding_5}}
module {}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index b21abbbdfd6d045..7452deb39b6c18d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,16 +1,16 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file
-// The definition of the @foo named sequence is provided in another file. It
+// The definition of the @print_message named sequence is provided in another file. It
// will be included because of the pass option.
module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
- transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
+ transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
- include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
+ include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
}
}
@@ -37,3 +37,18 @@ module attributes {transform.with_named_sequence} {
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
}
}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{doubly defined symbol @print_message}}
+ transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..a9083fe3e70788a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -4,29 +4,66 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN: --verify-diagnostics --split-input-file | FileCheck %s
-
-// The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// The definition of the @print_message named sequence is provided in another
+// file. It will be included because of the pass option. Subsequent application
+// of the same pass works but only without the library file (since the first
+// application loads external symbols and loading them again would make them
+// clash).
// Note that the same diagnostic produced twice at the same location only
// needs to be matched once.
// expected-remark @below {{message}}
// expected-remark @below {{unannotated}}
+// ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/67560
More information about the Mlir-commits
mailing list