[Mlir-commits] [mlir] b055e6d - Add a new interface method `getAsmBlockName()` on OpAsmOpInterface to control block names
Mehdi Amini
llvmlistbot at llvm.org
Fri Feb 11 00:46:17 PST 2022
Author: Mehdi Amini
Date: 2022-02-11T08:46:08Z
New Revision: b055e6d313651a766c97a1c82c09bcadb7c450c7
URL: https://github.com/llvm/llvm-project/commit/b055e6d313651a766c97a1c82c09bcadb7c450c7
DIFF: https://github.com/llvm/llvm-project/commit/b055e6d313651a766c97a1c82c09bcadb7c450c7.diff
LOG: Add a new interface method `getAsmBlockName()` on OpAsmOpInterface to control block names
This allows operations to control the block ids used by the printer in nested regions.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D115849
Added:
Modified:
mlir/include/mlir/IR/OpAsmInterface.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/pretty_printed_region_op.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index c13e59fb1b466..94e4bc8389915 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -62,6 +62,36 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
),
"", "return;"
>,
+ InterfaceMethod<[{
+ Get the names to use for the blocks in regions directly attached to this
+ operation, by invoking `setNameFn` on each block that should be named.
+
+ For example, if this operation has multiple blocks:
+
+ ```mlir
+ some.op() ({
+ ^bb0:
+ ...
+ ^bb1:
+ ...
+ })
+ ```
+
+ the method can invoke `setNameFn` on each of the blocks, allowing the op
+ to print:
+
+ ```mlir
+ some.op() ({
+ ^custom_foo_name:
+ ...
+ ^custom_bar_name:
+ ...
+ })
+ ```
+ }],
+ "void", "getAsmBlockNames",
+ (ins "::mlir::OpAsmSetBlockNameFn":$setNameFn), "", ";"
+ >,
StaticInterfaceMethod<[{
Return the default dialect used when printing/parsing operations in
regions nested under this operation. This allows for eliding the dialect
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 20310201fd8e4..0218e6be15bff 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1322,6 +1322,10 @@ class OpAsmParser : public AsmParser {
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
+/// A functor used to set the name of blocks in regions directly nested under
+/// an operation.
+using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
+
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ba5628272a31d..3346702423c1a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -791,6 +791,13 @@ void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
//===----------------------------------------------------------------------===//
namespace {
+/// Info about block printing: a number which is its position in the visitation
+/// order, and a name that is used to print references to it, e.g. ^bb42.
+struct BlockInfo {
+ int ordering;
+ StringRef name;
+};
+
/// This class manages the state of SSA value names.
class SSANameState {
public:
@@ -808,8 +815,8 @@ class SSANameState {
/// operation, or empty if none exist.
ArrayRef<int> getOpResultGroups(Operation *op);
- /// Get the ID for the given block.
- unsigned getBlockID(Block *block);
+ /// Get the info for the given block.
+ BlockInfo getBlockInfo(Block *block);
/// Renumber the arguments for the specified region to the same names as the
/// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
@@ -846,8 +853,9 @@ class SSANameState {
/// value of this map are the result numbers that start a result group.
DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
- /// This is the block ID for each block in the current.
- DenseMap<Block *, unsigned> blockIDs;
+ /// This maps blocks to their visitation number in the current region as well
+ /// as the string representing their name.
+ DenseMap<Block *, BlockInfo> blockNames;
/// This keeps track of all of the non-numeric names that are in flight,
/// allowing us to check for duplicates.
@@ -967,9 +975,10 @@ ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
}
-unsigned SSANameState::getBlockID(Block *block) {
- auto it = blockIDs.find(block);
- return it != blockIDs.end() ? it->second : NameSentinel;
+BlockInfo SSANameState::getBlockInfo(Block *block) {
+ auto it = blockNames.find(block);
+ BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
+ return it != blockNames.end() ? it->second : invalidBlock;
}
void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
@@ -1021,7 +1030,16 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
for (auto &block : region) {
// Each block gets a unique ID, and all of the operations within it get
// numbered as well.
- blockIDs[&block] = nextBlockID++;
+ auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
+ if (blockInfoIt.second) {
+ // This block hasn't been named through `getAsmBlockNames`, so use the
+ // default `^bbNNN` format.
+ std::string name;
+ llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
+ blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
+ }
+ blockInfoIt.first->second.ordering = nextBlockID++;
+
numberValuesInBlock(block);
}
}
@@ -1048,11 +1066,6 @@ void SSANameState::numberValuesInBlock(Block &block) {
}
void SSANameState::numberValuesInOp(Operation &op) {
- unsigned numResults = op.getNumResults();
- if (numResults == 0)
- return;
- Value resultBegin = op.getResult(0);
-
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
auto setResultNameFn = [&](Value result, StringRef name) {
@@ -1064,11 +1077,34 @@ void SSANameState::numberValuesInOp(Operation &op) {
if (int resultNo = result.cast<OpResult>().getResultNumber())
resultGroups.push_back(resultNo);
};
+ // Operations can customize the printing of block names in OpAsmOpInterface.
+ auto setBlockNameFn = [&](Block *block, StringRef name) {
+ assert(block->getParentOp() == &op &&
+ "getAsmBlockArgumentNames callback invoked on a block not directly "
+ "nested under the current operation");
+ assert(!blockNames.count(block) && "block numbered multiple times");
+ SmallString<16> tmpBuffer{"^"};
+ name = sanitizeIdentifier(name, tmpBuffer);
+ if (name.data() != tmpBuffer.data()) {
+ tmpBuffer.append(name);
+ name = tmpBuffer.str();
+ }
+ name = name.copy(usedNameAllocator);
+ blockNames[block] = {-1, name};
+ };
+
if (!printerFlags.shouldPrintGenericOpForm()) {
- if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
+ if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
+ asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
+ }
}
+ unsigned numResults = op.getNumResults();
+ if (numResults == 0)
+ return;
+ Value resultBegin = op.getResult(0);
+
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
++nextValueID;
@@ -2609,11 +2645,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
}
void OperationPrinter::printBlockName(Block *block) {
- auto id = state->getSSANameState().getBlockID(block);
- if (id != SSANameState::NameSentinel)
- os << "^bb" << id;
- else
- os << "^INVALIDBLOCK";
+ os << state->getSSANameState().getBlockInfo(block).name;
}
void OperationPrinter::print(Block *block, bool printBlockArgs,
@@ -2647,18 +2679,18 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
os << " // pred: ";
printBlockName(pred);
} else {
- // We want to print the predecessors in increasing numeric order, not in
+ // We want to print the predecessors in a stable order, not in
// whatever order the use-list is in, so gather and sort them.
- SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
+ SmallVector<BlockInfo, 4> predIDs;
for (auto *pred : block->getPredecessors())
- predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
- llvm::array_pod_sort(predIDs.begin(), predIDs.end());
+ predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
+ llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
+ return lhs.ordering < rhs.ordering;
+ });
os << " // " << predIDs.size() << " preds: ";
- interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
- printBlockName(pred.second);
- });
+ interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
}
os << newLine;
}
diff --git a/mlir/test/IR/pretty_printed_region_op.mlir b/mlir/test/IR/pretty_printed_region_op.mlir
index 7bdc06f581260..220733f528c57 100644
--- a/mlir/test/IR/pretty_printed_region_op.mlir
+++ b/mlir/test/IR/pretty_printed_region_op.mlir
@@ -36,7 +36,6 @@ func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
-
func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
// CHECK-LOCATION: "test.pretty_printed_region"(%arg1, %arg0)
// CHECK-LOCATION: ^bb0(%arg[[x:[0-9]+]]: f32 loc("foo"), %arg[[y:[0-9]+]]: f32 loc("foo")
@@ -47,3 +46,29 @@ func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
%res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("foo")
return %res : f32
}
+
+// -----
+
+// This tests the behavior of custom block names:
+// operations like `test.block_names` can define custom names for blocks in
+// nested regions.
+// CHECK-CUSTOM-LABEL: func @block_names
+func @block_names(%bool : i1) {
+ // CHECK: test.block_names
+ test.block_names {
+ // CHECK-CUSTOM: br ^foo1
+ // CHECK-GENERIC: cf.br{{.*}}^bb1
+ cf.br ^foo1
+ // CHECK-CUSTOM: ^foo1:
+ // CHECK-GENERIC: ^bb1:
+ ^foo1:
+ // CHECK-CUSTOM: br ^foo2
+ // CHECK-GENERIC: cf.br{{.*}}^bb2
+ cf.br ^foo2
+ // CHECK-CUSTOM: ^foo2:
+ // CHECK-GENERIC: ^bb2:
+ ^foo2:
+ "test.return"() : () -> ()
+ }
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3a14be7373721..5b7d973429269 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -660,6 +660,24 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
let assemblyFormat = "regions attr-dict-with-keyword";
}
+// This is used to test the OpAsmOpInterface::getAsmBlockNames() feature:
+// blocks nested in a region under this op will have a name defined by the
+// interface.
+def AsmBlockNameOp : TEST_Op<"block_names", [OpAsmOpInterface]> {
+ let regions = (region AnyRegion:$body);
+ let extraClassDeclaration = [{
+ void getAsmBlockNames(mlir::OpAsmSetBlockNameFn setNameFn) {
+ std::string name;
+ int count = 0;
+ for (::mlir::Block &block : getRegion().getBlocks()) {
+ name = "foo" + std::to_string(count++);
+ setNameFn(&block, name);
+ }
+ }
+ }];
+ let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
// This operation requires its return type to have the trait 'TestTypeTrait'.
def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
let results = (outs AnyType);
More information about the Mlir-commits
mailing list