[Mlir-commits] [mlir] [mlir][transform] Handle multiple library preloading passes. (PR #69705)
Ingo Müller
llvmlistbot at llvm.org
Tue Oct 24 22:37:54 PDT 2023
https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/69705
>From b6a98688503e443432f13c981c559b55caad62c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 20 Oct 2023 10:13:44 +0000
Subject: [PATCH 1/6] Move module merging utils to dialect target. (NFC)
This is a preparatory commit for fixing the handling of multiple
transform library modules. The plan is to merge every newly loaded
module into the global library module rather than keeping around a list.
Since the dialect will call the merging code, which in turn depends on
the dialect, the two depend on each other and thus have to live in the
same CMake target.
---
.../include/mlir/Dialect/Transform/IR/Utils.h | 40 +++
.../Transforms/TransformInterpreterUtils.h | 14 -
mlir/lib/Dialect/Transform/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/Transform/IR/Utils.cpp | 244 ++++++++++++++++++
.../TransformInterpreterPassBase.cpp | 1 +
.../Transforms/TransformInterpreterUtils.cpp | 227 +---------------
6 files changed, 287 insertions(+), 240 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Transform/IR/Utils.h
create mode 100644 mlir/lib/Dialect/Transform/IR/Utils.cpp
diff --git a/mlir/include/mlir/Dialect/Transform/IR/Utils.h b/mlir/include/mlir/Dialect/Transform/IR/Utils.h
new file mode 100644
index 000000000000000..e35c0ada07a4ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/Utils.h
@@ -0,0 +1,40 @@
+//===- Utils.h - Utils related to the transform dialect ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_UTILS_H
+#define MLIR_DIALECT_TRANSFORM_IR_UTILS_H
+
+namespace mlir {
+class InFlightDiagnostic;
+class Operation;
+template <typename>
+class OwningOpRef;
+
+namespace transform {
+namespace detail {
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of instances
+/// of `FunctionOpInterface`, where collisions are allowed if at least one of
+/// the two is external and the other one is public or external. In that case,
+/// an external one is erased. On failure, `target` may be left modified.
+// TODO: Reconsider cloning individual ops rather than forcing users of the
+// function to clone (or move) `other` in order to improve efficiency.
+// This might primarily make sense if we can also prune the symbols that
+// are merged to a subset (such as those that are actually used).
+InFlightDiagnostic mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other);
+
+} // namespace detail
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 0cf87c8b3c21cba..9881428f4f7bec5 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -69,20 +69,6 @@ TransformOpInterface findTransformEntryPoint(
Operation *root, ModuleOp module,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
-/// Merge all symbols from `other` into `target`. Both ops need to implement the
-/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
-/// modified by this function and might not verify after the function returns.
-/// Upon merging, private symbols may be renamed in order to avoid collisions in
-/// the result. Public symbols may not collide, with the exception of
-/// instances of `SymbolOpInterface`, where collisions are allowed if at least
-/// one of the two is external, in which case the other op preserved (or any one
-/// of the two if both are external).
-// TODO: Reconsider cloning individual ops rather than forcing users of the
-// function to clone (or move) `other` in order to improve efficiency.
-// This might primarily make sense if we can also prune the symbols that
-// are merged to a subset (such as those that are actually used).
-InFlightDiagnostic mergeSymbolsInto(Operation *target,
- OwningOpRef<Operation *> other);
} // namespace detail
/// Standalone util to apply the named sequence `transformRoot` to `payload` IR.
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 7e8802bc755eedf..34083b2fd7aab37 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTransformDialect
TransformInterfaces.cpp
TransformOps.cpp
TransformTypes.cpp
+ Utils.cpp
DEPENDS
MLIRMatchInterfacesIncGen
diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
new file mode 100644
index 000000000000000..d66639042295a4c
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -0,0 +1,244 @@
+//===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+/// Return whether `func1` can be merged into `func2`. For that to work,
+/// `func1` has to be a declaration (i.e., it has to be external) and `func2`
+/// either has to be a declaration as well or it has to be public (otherwise,
+/// it wouldn't be visible to `func1`).
+static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+ return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergeable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns successfully.
+static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
+ FunctionOpInterface func2) {
+ assert(canMergeInto(func1, func2));
+ assert(func1->getParentOp() == func2->getParentOp() &&
+ "expected func1 and func2 to be in the same parent op");
+
+ // Check that function signatures match.
+ if (func1.getFunctionType() != func2.getFunctionType()) {
+ return func1.emitError()
+ << "external definition has a mismatching signature ("
+ << func2.getFunctionType() << ")";
+ }
+
+ // Check and merge the consumed/readonly argument attributes.
+ MLIRContext *context = func1->getContext();
+ auto *td = context->getLoadedDialect<transform::TransformDialect>();
+ StringAttr consumedName = td->getConsumedAttrName();
+ StringAttr readOnlyName = td->getReadOnlyAttrName();
+ for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+ bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+ bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+ bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+ bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+ if (!isExternalConsumed && !isExternalReadonly) {
+ if (isConsumed)
+ func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+ else if (isReadonly)
+ func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
+ continue;
+ }
+
+ if ((isExternalConsumed && !isConsumed) ||
+ (isExternalReadonly && !isReadonly)) {
+ return func1.emitError()
+ << "external definition has mismatching consumption "
+ "annotations for argument #"
+ << i;
+ }
+ }
+
+ // `func1` is the external one, so we can remove it.
+ assert(func1.isExternal());
+ func1->erase();
+
+ return InFlightDiagnostic();
+}
+
+InFlightDiagnostic
+transform::detail::mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other) {
+ assert(target->hasTrait<OpTrait::SymbolTable>() &&
+ "requires target to implement the 'SymbolTable' trait");
+ assert(other->hasTrait<OpTrait::SymbolTable>() &&
+ "requires other to implement the 'SymbolTable' trait");
+
+ SymbolTable targetSymbolTable(target);
+ SymbolTable otherSymbolTable(*other);
+
+ // Step 1:
+ //
+ // Rename private symbols in both ops in order to resolve conflicts that can
+ // be resolved that way.
+ LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+ // TODO: Do we *actually* need to test in both directions?
+ for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
+ SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
+ SmallVector<SymbolTable *, 2>{&otherSymbolTable,
+ &targetSymbolTable})) {
+ Operation *symbolTableOp = symbolTable->getOp();
+ for (Operation &op : symbolTableOp->getRegion(0).front()) {
+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+ if (!symbolOp)
+ continue;
+ StringAttr name = symbolOp.getNameAttr();
+ LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
+
+ // Check if there is a colliding op in the other module.
+ auto collidingOp =
+ cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+ if (!collidingOp)
+ continue;
+
+ LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+
+ // Collisions are fine if both ops are functions and can be merged.
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+ collidingFuncOp =
+ dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+ funcOp && collidingFuncOp) {
+ if (canMergeInto(funcOp, collidingFuncOp) ||
+ canMergeInto(collidingFuncOp, funcOp)) {
+ LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+ "will be merged\n");
+ continue;
+ }
+
+ // If they can't be merged, proceed like any other collision.
+ LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+ }
+
+ // Collision can be resolved by renaming if one of the ops is private.
+ auto renameToUnique =
+ [&](SymbolOpInterface op, SymbolOpInterface otherOp,
+ SymbolTable &symbolTable,
+ SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
+ LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+ FailureOr<StringAttr> maybeNewName =
+ symbolTable.renameToUnique(op, {&otherSymbolTable});
+ if (failed(maybeNewName)) {
+ InFlightDiagnostic diag = op->emitError("failed to rename symbol");
+ diag.attachNote(otherOp->getLoc())
+ << "attempted renaming due to collision with this op";
+ return diag;
+ }
+ LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
+ << "\n");
+ return InFlightDiagnostic();
+ };
+
+ if (symbolOp.isPrivate()) {
+ InFlightDiagnostic diag = renameToUnique(
+ symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
+ if (failed(diag))
+ return diag;
+ continue;
+ }
+ if (collidingOp.isPrivate()) {
+ InFlightDiagnostic diag = renameToUnique(
+ collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
+ if (failed(diag))
+ return diag;
+ continue;
+ }
+ LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+ InFlightDiagnostic diag = symbolOp.emitError()
+ << "doubly defined symbol @" << name.getValue();
+ diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+ return diag;
+ }
+ }
+
+ // TODO: This duplicates pass infrastructure. We should split this pass into
+ // several and let the pass infrastructure do the verification.
+ for (auto *op : SmallVector<Operation *>{target, *other}) {
+ if (failed(mlir::verify(op)))
+ return op->emitError() << "failed to verify input op after renaming";
+ }
+
+ // Step 2:
+ //
+ // Move all ops from `other` into target and merge public symbols.
+ LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+ {
+ SmallVector<SymbolOpInterface> opsToMove;
+ for (Operation &op : other->getRegion(0).front()) {
+ if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+ opsToMove.push_back(symbol);
+ }
+
+ for (SymbolOpInterface op : opsToMove) {
+ // Remember potentially colliding op in the target module.
+ auto collidingOp = cast_or_null<SymbolOpInterface>(
+ targetSymbolTable.lookup(op.getNameAttr()));
+
+ // Move op even if we get a collision.
+ LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+ op->moveBefore(&target->getRegion(0).front(),
+ target->getRegion(0).front().end());
+
+ // If there is no collision, we are done.
+ if (!collidingOp) {
+ LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+ continue;
+ }
+
+ // The two colliding ops must both be functions because we have already
+ // emitted errors otherwise earlier.
+ auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+ auto collidingFuncOp =
+ cast<FunctionOpInterface>(collidingOp.getOperation());
+
+ // Both ops are in the target module now and can be treated
+ // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
+ // `collidingFuncOp`.
+ if (!canMergeInto(funcOp, collidingFuncOp)) {
+ std::swap(funcOp, collidingFuncOp);
+ }
+ assert(canMergeInto(funcOp, collidingFuncOp));
+
+ LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+ << collidingFuncOp.getLoc() << ":\n"
+ << collidingFuncOp << "\n");
+
+ // Update symbol table. This works with or without the previous `swap`.
+ targetSymbolTable.remove(funcOp);
+ targetSymbolTable.insert(collidingFuncOp);
+ assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+ // Do the actual merging.
+ {
+ InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
+ if (failed(diag))
+ return diag;
+ }
+ }
+ }
+
+ if (failed(mlir::verify(target)))
+ return target->emitError()
+ << "failed to verify target op after merging symbols";
+
+ LLVM_DEBUG(DBGS() << "done merging ops\n");
+ return InFlightDiagnostic();
+}
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index e14461eaf124abc..a2f9e502e7235fa 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index b6b6256223ed23d..9f01cccc92737d4 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
@@ -186,232 +187,6 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
return success();
}
-/// Return whether `func1` can be merged into `func2`. For that to work
-/// `func1` has to be a declaration (aka has to be external) and `func2`
-/// either has to be a declaration as well, or it has to be public (otherwise,
-/// it wouldn't be visible by `func1`).
-static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
- return func1.isExternal() && (func2.isPublic() || func2.isExternal());
-}
-
-/// Merge `func1` into `func2`. The two ops must be inside the same parent op
-/// and mergable according to `canMergeInto`. The function erases `func1` such
-/// that only `func2` exists when the function returns.
-static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
- FunctionOpInterface func2) {
- assert(canMergeInto(func1, func2));
- assert(func1->getParentOp() == func2->getParentOp() &&
- "expected func1 and func2 to be in the same parent op");
-
- // Check that function signatures match.
- if (func1.getFunctionType() != func2.getFunctionType()) {
- return func1.emitError()
- << "external definition has a mismatching signature ("
- << func2.getFunctionType() << ")";
- }
-
- // Check and merge argument attributes.
- MLIRContext *context = func1->getContext();
- auto *td = context->getLoadedDialect<transform::TransformDialect>();
- StringAttr consumedName = td->getConsumedAttrName();
- StringAttr readOnlyName = td->getReadOnlyAttrName();
- for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
- bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
- bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
- bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
- bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
- if (!isExternalConsumed && !isExternalReadonly) {
- if (isConsumed)
- func2.setArgAttr(i, consumedName, UnitAttr::get(context));
- else if (isReadonly)
- func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
- continue;
- }
-
- if ((isExternalConsumed && !isConsumed) ||
- (isExternalReadonly && !isReadonly)) {
- return func1.emitError()
- << "external definition has mismatching consumption "
- "annotations for argument #"
- << i;
- }
- }
-
- // `func1` is the external one, so we can remove it.
- assert(func1.isExternal());
- func1->erase();
-
- return InFlightDiagnostic();
-}
-
-InFlightDiagnostic
-transform::detail::mergeSymbolsInto(Operation *target,
- OwningOpRef<Operation *> other) {
- assert(target->hasTrait<OpTrait::SymbolTable>() &&
- "requires target to implement the 'SymbolTable' trait");
- assert(other->hasTrait<OpTrait::SymbolTable>() &&
- "requires target to implement the 'SymbolTable' trait");
-
- SymbolTable targetSymbolTable(target);
- SymbolTable otherSymbolTable(*other);
-
- // Step 1:
- //
- // Rename private symbols in both ops in order to resolve conflicts that can
- // be resolved that way.
- LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
- // TODO: Do we *actually* need to test in both directions?
- for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
- SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
- SmallVector<SymbolTable *, 2>{&otherSymbolTable,
- &targetSymbolTable})) {
- Operation *symbolTableOp = symbolTable->getOp();
- for (Operation &op : symbolTableOp->getRegion(0).front()) {
- auto symbolOp = dyn_cast<SymbolOpInterface>(op);
- if (!symbolOp)
- continue;
- StringAttr name = symbolOp.getNameAttr();
- LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
-
- // Check if there is a colliding op in the other module.
- auto collidingOp =
- cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
- if (!collidingOp)
- continue;
-
- LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
-
- // Collisions are fine if both opt are functions and can be merged.
- if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
- collidingFuncOp =
- dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
- funcOp && collidingFuncOp) {
- if (canMergeInto(funcOp, collidingFuncOp) ||
- canMergeInto(collidingFuncOp, funcOp)) {
- LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
- "will be merged\n");
- continue;
- }
-
- // If they can't be merged, proceed like any other collision.
- LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
- }
-
- // Collision can be resolved by renaming if one of the ops is private.
- auto renameToUnique =
- [&](SymbolOpInterface op, SymbolOpInterface otherOp,
- SymbolTable &symbolTable,
- SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
- LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
- FailureOr<StringAttr> maybeNewName =
- symbolTable.renameToUnique(op, {&otherSymbolTable});
- if (failed(maybeNewName)) {
- InFlightDiagnostic diag = op->emitError("failed to rename symbol");
- diag.attachNote(otherOp->getLoc())
- << "attempted renaming due to collision with this op";
- return diag;
- }
- LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
- << "\n");
- return InFlightDiagnostic();
- };
-
- if (symbolOp.isPrivate()) {
- InFlightDiagnostic diag = renameToUnique(
- symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
- if (failed(diag))
- return diag;
- continue;
- }
- if (collidingOp.isPrivate()) {
- InFlightDiagnostic diag = renameToUnique(
- collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
- if (failed(diag))
- return diag;
- continue;
- }
- LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
- InFlightDiagnostic diag = symbolOp.emitError()
- << "doubly defined symbol @" << name.getValue();
- diag.attachNote(collidingOp->getLoc()) << "previously defined here";
- return diag;
- }
- }
-
- // TODO: This duplicates pass infrastructure. We should split this pass into
- // several and let the pass infrastructure do the verification.
- for (auto *op : SmallVector<Operation *>{target, *other}) {
- if (failed(mlir::verify(op)))
- return op->emitError() << "failed to verify input op after renaming";
- }
-
- // Step 2:
- //
- // Move all ops from `other` into target and merge public symbols.
- LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
- {
- SmallVector<SymbolOpInterface> opsToMove;
- for (Operation &op : other->getRegion(0).front()) {
- if (auto symbol = dyn_cast<SymbolOpInterface>(op))
- opsToMove.push_back(symbol);
- }
-
- for (SymbolOpInterface op : opsToMove) {
- // Remember potentially colliding op in the target module.
- auto collidingOp = cast_or_null<SymbolOpInterface>(
- targetSymbolTable.lookup(op.getNameAttr()));
-
- // Move op even if we get a collision.
- LLVM_DEBUG(DBGS() << " moving @" << op.getName());
- op->moveBefore(&target->getRegion(0).front(),
- target->getRegion(0).front().end());
-
- // If there is no collision, we are done.
- if (!collidingOp) {
- LLVM_DEBUG(llvm::dbgs() << " without collision\n");
- continue;
- }
-
- // The two colliding ops must both be functions because we have already
- // emitted errors otherwise earlier.
- auto funcOp = cast<FunctionOpInterface>(op.getOperation());
- auto collidingFuncOp =
- cast<FunctionOpInterface>(collidingOp.getOperation());
-
- // Both ops are in the target module now and can be treated
- // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
- // `collidingFuncOp`.
- if (!canMergeInto(funcOp, collidingFuncOp)) {
- std::swap(funcOp, collidingFuncOp);
- }
- assert(canMergeInto(funcOp, collidingFuncOp));
-
- LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
- << collidingFuncOp.getLoc() << ":\n"
- << collidingFuncOp << "\n");
-
- // Update symbol table. This works with or without the previous `swap`.
- targetSymbolTable.remove(funcOp);
- targetSymbolTable.insert(collidingFuncOp);
- assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
-
- // Do the actual merging.
- {
- InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
- if (failed(diag))
- return diag;
- }
- }
- }
-
- if (failed(mlir::verify(target)))
- return target->emitError()
- << "failed to verify target op after merging symbols";
-
- LLVM_DEBUG(DBGS() << "done merging ops\n");
- return InFlightDiagnostic();
-}
-
LogicalResult transform::applyTransformNamedSequence(
Operation *payload, Operation *transformRoot, ModuleOp transformModule,
const TransformOptions &options) {
>From d2fa23614cdd8dbbc0cec2a054d4a567dbe61b28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 20 Oct 2023 10:25:31 +0000
Subject: [PATCH 2/6] Fix some unrelated issues found by @ftynse.
---
.../Transform/Transforms/TransformInterpreterUtils.h | 2 +-
.../Transform/Transforms/TransformInterpreterUtils.cpp | 9 +--------
2 files changed, 2 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 9881428f4f7bec5..1737d72838e9b33 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -32,7 +32,7 @@ namespace detail {
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
-LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
+LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> paths,
MLIRContext *context,
SmallVectorImpl<std::string> &fileNames);
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 9f01cccc92737d4..b74aac3665930ce 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -39,7 +39,7 @@ using namespace mlir;
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
LogicalResult transform::detail::expandPathsToMLIRFiles(
- ArrayRef<std::string> &paths, MLIRContext *context,
+ ArrayRef<std::string> paths, MLIRContext *context,
SmallVectorImpl<std::string> &fileNames) {
for (const std::string &path : paths) {
auto loc = FileLineColLoc::get(context, path, 0, 0);
@@ -90,12 +90,6 @@ LogicalResult transform::detail::expandPathsToMLIRFiles(
LogicalResult transform::detail::parseTransformModuleFromFile(
MLIRContext *context, llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule) {
- if (transformFileName.empty()) {
- LLVM_DEBUG(
- DBGS() << "no transform file name specified, assuming the transform "
- "module is embedded in the IR next to the top-level\n");
- return success();
- }
// Parse transformFileName content into a ModuleOp.
std::string errorMessage;
auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
@@ -173,7 +167,6 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
{
mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
UnitAttr::get(context));
- IRRewriter rewriter(context);
// TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(transform::detail::mergeSymbolsInto(
>From 822d127a0e530e87153f17830dd7605629a04eb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Mon, 16 Oct 2023 12:47:50 +0000
Subject: [PATCH 3/6] [mlir][transform] Handle multiple library preloading
passes.
The transform dialect stores a "library module" that the preload pass
can populate. Until now, each pass registered an additional module by
simply pushing it to a vector; however, the interpreter only used the
first of them. This commit turns the registration into "loading", i.e.,
each newly added module gets merged into the existing one. This allows
the loading to be split into several passes, and using the library in
the interpreter now takes all of them into account. While this design
avoids repeated merging every time the library is accessed, it requires
that the implementation of module merging lives in the
`TransformDialect` target (since the merging code and the dialect
depend on each other).
This resolves #69111.
---
.../Dialect/Transform/IR/TransformDialect.td | 34 +++++++++----------
.../Dialect/Transform/IR/TransformDialect.cpp | 16 +++++++++
.../Transforms/PreloadLibraryPass.cpp | 4 ++-
.../Transforms/TransformInterpreterUtils.cpp | 8 ++---
.../Dialect/Transform/preload-library.mlir | 12 +++++++
mlir/unittests/Dialect/Transform/Preload.cpp | 5 ++-
6 files changed, 53 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index ad6804673b770ca..211663258bfb133 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -78,20 +78,19 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
- /// Appends the given module as a transform symbol library available to
- /// all dialect users.
- void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
- library) {
- libraryModules.push_back(std::move(library));
- }
-
- /// Returns a range of registered library modules.
- auto getLibraryModules() const {
- return ::llvm::map_range(
- libraryModules,
- [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
- return library.get();
- });
+ /// Creates the initially empty transform symbol library module.
+ void initializeLibraryModule();
+
+ /// Merges `library` into the transform symbol library module.
+ LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
+ library);
+
+ /// Returns the transform symbol library module available to all dialect
+ /// users, or a null module if it has not been created.
+ ModuleOp getLibraryModule() const {
+ if (libraryModule)
+ return libraryModule.get();
+ return ModuleOp();
+ }
private:
@@ -153,10 +152,9 @@ def Transform_Dialect : Dialect {
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
- /// Modules containing symbols, e.g. named sequences, that will be
- /// resolved by the interpreter when used.
- ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
- libraryModules;
+ /// Module containing symbols, e.g. named sequences, that will be resolved
+ /// by the interpreter when used.
+ ::mlir::OwningOpRef<::mlir::ModuleOp> libraryModule;
}];
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 32c56e903268f74..0104c5fe70c7183 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SCCIterator.h"
@@ -65,6 +66,7 @@ void transform::TransformDialect::initialize() {
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
+ initializeLibraryModule();
}
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
@@ -89,6 +91,20 @@ void transform::TransformDialect::printType(Type type,
it->getSecond()(type, printer);
}
+void transform::TransformDialect::initializeLibraryModule() {
+ MLIRContext *ctx = getContext();
+ Location loc =
+ FileLineColLoc::get(ctx, "<transform-dialect-library-module>", 0, 0);
+ libraryModule = ModuleOp::create(loc, "__transform_library");
+ libraryModule.get()->setAttr(kWithNamedSequenceAttrName,
+ UnitAttr::get(ctx));
+}
+
+LogicalResult transform::TransformDialect::loadIntoLibraryModule(
+ ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
+ return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
+}
+
void transform::TransformDialect::reportDuplicateTypeRegistration(
StringRef mnemonic) {
std::string buffer;
diff --git a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
index d2e7108c0288623..4ade414dba833a8 100644
--- a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
@@ -33,7 +33,9 @@ class PreloadLibraryPass
// TODO: investigate using a resource blob if some ownership mode allows it.
auto *dialect =
getContext().getOrLoadDialect<transform::TransformDialect>();
- dialect->registerLibraryModule(std::move(mergedParsedLibraries));
+ if (failed(
+ dialect->loadIntoLibraryModule(std::move(mergedParsedLibraries))))
+ signalPassFailure();
}
};
} // namespace
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index b74aac3665930ce..00486ad68004f88 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -107,12 +107,8 @@ LogicalResult transform::detail::parseTransformModuleFromFile(
}
ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
- auto preloadedLibraryRange =
- context->getOrLoadDialect<transform::TransformDialect>()
- ->getLibraryModules();
- if (!preloadedLibraryRange.empty())
- return *preloadedLibraryRange.begin();
- return ModuleOp();
+ return context->getOrLoadDialect<transform::TransformDialect>()
+ ->getLibraryModule();
}
transform::TransformOpInterface
diff --git a/mlir/test/Dialect/Transform/preload-library.mlir b/mlir/test/Dialect/Transform/preload-library.mlir
index 9beefa44d673d9f..a0fe7a6c9b2d3c5 100644
--- a/mlir/test/Dialect/Transform/preload-library.mlir
+++ b/mlir/test/Dialect/Transform/preload-library.mlir
@@ -3,6 +3,18 @@
// RUN: -transform-interpreter=entry-point=private_helper \
// RUN: -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
+// RUN: -transform-interpreter=entry-point=private_helper \
+// RUN: -split-input-file -verify-diagnostics
+
+// RUN: mlir-opt %s \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
+// RUN: -transform-interpreter=entry-point=private_helper \
+// RUN: -split-input-file -verify-diagnostics
+
// expected-remark @below {{message}}
module {}
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
index 7d66de0fe48ef80..5ab53ce9844b825 100644
--- a/mlir/unittests/Dialect/Transform/Preload.cpp
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/DialectRegistry.h"
@@ -67,7 +68,9 @@ TEST(Preload, ContextPreloadConstructedLibrary) {
OwningOpRef<ModuleOp> transformLibrary =
parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
EXPECT_TRUE(transformLibrary) << "failed to parse transform module";
- dialect->registerLibraryModule(std::move(transformLibrary));
+ LogicalResult diag =
+ dialect->loadIntoLibraryModule(std::move(transformLibrary));
+ EXPECT_TRUE(succeeded(diag));
ModuleOp retrievedTransformLibrary =
transform::detail::getPreloadedTransformModule(&context);
>From 5c7302ade698d8a4a0419064c7ba059247b4a600 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Tue, 24 Oct 2023 10:29:24 +0000
Subject: [PATCH 4/6] Fix the issue found by @ftynse. Also make the init
 function private.
---
.../mlir/Dialect/Transform/IR/TransformDialect.td | 7 ++++---
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp | 10 +++++-----
2 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 211663258bfb133..33c2e7ce0e6d62a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -78,9 +78,6 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
- /// Loads the given module into the transform symbol library module.
- void initializeLibraryModule();
-
/// Loads the given module into the transform symbol library module.
LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library);
@@ -94,6 +91,10 @@ def Transform_Dialect : Dialect {
}
private:
+ /// Initializes the transform symbol library module. Must be called from
+ /// `TransformDialect::initialize` for the library module to work.
+ void initializeLibraryModule();
+
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
template <typename... OpTys>
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 0104c5fe70c7183..fb355bc97192682 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -91,6 +91,11 @@ void transform::TransformDialect::printType(Type type,
it->getSecond()(type, printer);
}
+LogicalResult transform::TransformDialect::loadIntoLibraryModule(
+ ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
+ return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
+}
+
void transform::TransformDialect::initializeLibraryModule() {
MLIRContext *context = getContext();
auto loc =
@@ -100,11 +105,6 @@ void transform::TransformDialect::initializeLibraryModule() {
UnitAttr::get(context));
}
-LogicalResult transform::TransformDialect::loadIntoLibraryModule(
- ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
- return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
-}
-
void transform::TransformDialect::reportDuplicateTypeRegistration(
StringRef mnemonic) {
std::string buffer;
>From 8fda307649e3f37222a824a0ecd1845c69cd6a4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Tue, 24 Oct 2023 10:38:45 +0000
Subject: [PATCH 5/6] Revert one of the "unrelated fixes".
That fix had removed the check for an empty file name in
`parseTransformModuleFromFile`. However, the function is still called in
cases where the file name may be empty; in that case, it should return
success early instead of failing.
---
.../Transform/Transforms/TransformInterpreterUtils.cpp | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 00486ad68004f88..c92472acf279481 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -90,6 +90,12 @@ LogicalResult transform::detail::expandPathsToMLIRFiles(
LogicalResult transform::detail::parseTransformModuleFromFile(
MLIRContext *context, llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule) {
+ if (transformFileName.empty()) {
+ LLVM_DEBUG(
+ DBGS() << "no transform file name specified, assuming the transform "
+ "module is embedded in the IR next to the top-level\n");
+ return success();
+ }
// Parse transformFileName content into a ModuleOp.
std::string errorMessage;
auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
>From 911eb7c9dc62c5f9d3fa9d1f2fd994dfb1a9f0ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Wed, 25 Oct 2023 05:24:00 +0000
Subject: [PATCH 6/6] Fix test that broke after rebasing on top of #69329.
---
mlir/test/Dialect/Transform/preload-library.mlir | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Transform/preload-library.mlir b/mlir/test/Dialect/Transform/preload-library.mlir
index a0fe7a6c9b2d3c5..657bf50b35831fc 100644
--- a/mlir/test/Dialect/Transform/preload-library.mlir
+++ b/mlir/test/Dialect/Transform/preload-library.mlir
@@ -4,14 +4,14 @@
// RUN: -split-input-file -verify-diagnostics
// RUN: mlir-opt %s \
-// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
-// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}include%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}include%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
// RUN: -transform-interpreter=entry-point=private_helper \
// RUN: -split-input-file -verify-diagnostics
// RUN: mlir-opt %s \
-// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
-// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}include%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
+// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}include%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
// RUN: -transform-interpreter=entry-point=private_helper \
// RUN: -split-input-file -verify-diagnostics
More information about the Mlir-commits
mailing list