[Mlir-commits] [mlir] b055e6d - Add a new interface method `getAsmBlockNames()` on OpAsmOpInterface to control block names

Mehdi Amini llvmlistbot at llvm.org
Fri Feb 11 00:46:17 PST 2022


Author: Mehdi Amini
Date: 2022-02-11T08:46:08Z
New Revision: b055e6d313651a766c97a1c82c09bcadb7c450c7

URL: https://github.com/llvm/llvm-project/commit/b055e6d313651a766c97a1c82c09bcadb7c450c7
DIFF: https://github.com/llvm/llvm-project/commit/b055e6d313651a766c97a1c82c09bcadb7c450c7.diff

LOG: Add a new interface method `getAsmBlockNames()` on OpAsmOpInterface to control block names

This allows operations to control the block names used by the printer in their directly nested regions.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D115849

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpAsmInterface.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/IR/pretty_printed_region_op.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index c13e59fb1b466..94e4bc8389915 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -62,6 +62,36 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
       ),
       "", "return;"
     >,
+    InterfaceMethod<[{
+        Get the names to use for the blocks inside the regions attached to this
+        operation.
+
+        For example if this operation has multiple blocks:
+
+        ```mlir
+          some.op() ({
+            ^bb0:
+              ...
+            ^bb1:
+              ...
+          })
+        ```
+
+        the method can call `setNameFn` on each of the blocks, allowing the op
+        to print:
+
+        ```mlir
+          some.op() ({
+            ^custom_foo_name:
+              ...
+            ^custom_bar_name:
+              ...
+          })
+        ```
+      }],
+      "void", "getAsmBlockNames",
+      (ins "::mlir::OpAsmSetBlockNameFn":$setNameFn), "", ";"
+    >,
     StaticInterfaceMethod<[{
       Return the default dialect used when printing/parsing operations in
       regions nested under this operation. This allows for eliding the dialect

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 20310201fd8e4..0218e6be15bff 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1322,6 +1322,10 @@ class OpAsmParser : public AsmParser {
 /// operation. See 'getAsmResultNames' below for more details.
 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
 
+/// A functor used to set the name of blocks in regions directly nested under
+/// an operation.
+using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
+
 class OpAsmDialectInterface
     : public DialectInterface::Base<OpAsmDialectInterface> {
 public:

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ba5628272a31d..3346702423c1a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -791,6 +791,13 @@ void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
 //===----------------------------------------------------------------------===//
 
 namespace {
+/// Info about block printing: a number which is its position in the visitation
+/// order, and a name that is used to print references to it, e.g. ^bb42.
+struct BlockInfo {
+  int ordering;
+  StringRef name;
+};
+
 /// This class manages the state of SSA value names.
 class SSANameState {
 public:
@@ -808,8 +815,8 @@ class SSANameState {
   /// operation, or empty if none exist.
   ArrayRef<int> getOpResultGroups(Operation *op);
 
-  /// Get the ID for the given block.
-  unsigned getBlockID(Block *block);
+  /// Get the info for the given block.
+  BlockInfo getBlockInfo(Block *block);
 
   /// Renumber the arguments for the specified region to the same names as the
   /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
@@ -846,8 +853,9 @@ class SSANameState {
   /// value of this map are the result numbers that start a result group.
   DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
 
-  /// This is the block ID for each block in the current.
-  DenseMap<Block *, unsigned> blockIDs;
+  /// This maps blocks to their visitation number in the current region as well
+  /// as the string representing their name.
+  DenseMap<Block *, BlockInfo> blockNames;
 
   /// This keeps track of all of the non-numeric names that are in flight,
   /// allowing us to check for duplicates.
@@ -967,9 +975,10 @@ ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
 }
 
-unsigned SSANameState::getBlockID(Block *block) {
-  auto it = blockIDs.find(block);
-  return it != blockIDs.end() ? it->second : NameSentinel;
+BlockInfo SSANameState::getBlockInfo(Block *block) {
+  auto it = blockNames.find(block);
+  BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
+  return it != blockNames.end() ? it->second : invalidBlock;
 }
 
 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
@@ -1021,7 +1030,16 @@ void SSANameState::numberValuesInRegion(Region &region) {
   for (auto &block : region) {
     // Each block gets a unique ID, and all of the operations within it get
     // numbered as well.
-    blockIDs[&block] = nextBlockID++;
+    auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
+    if (blockInfoIt.second) {
+      // This block hasn't been named through `getAsmBlockNames`, use
+      // default `^bbNNN` format.
+      std::string name;
+      llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
+      blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
+    }
+    blockInfoIt.first->second.ordering = nextBlockID++;
+
     numberValuesInBlock(block);
   }
 }
@@ -1048,11 +1066,6 @@ void SSANameState::numberValuesInBlock(Block &block) {
 }
 
 void SSANameState::numberValuesInOp(Operation &op) {
-  unsigned numResults = op.getNumResults();
-  if (numResults == 0)
-    return;
-  Value resultBegin = op.getResult(0);
-
   // Function used to set the special result names for the operation.
   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
   auto setResultNameFn = [&](Value result, StringRef name) {
@@ -1064,11 +1077,34 @@ void SSANameState::numberValuesInOp(Operation &op) {
     if (int resultNo = result.cast<OpResult>().getResultNumber())
       resultGroups.push_back(resultNo);
   };
+  // Operations can customize the printing of block names in OpAsmOpInterface.
+  auto setBlockNameFn = [&](Block *block, StringRef name) {
+    assert(block->getParentOp() == &op &&
+           "getAsmBlockArgumentNames callback invoked on a block not directly "
+           "nested under the current operation");
+    assert(!blockNames.count(block) && "block numbered multiple times");
+    SmallString<16> tmpBuffer{"^"};
+    name = sanitizeIdentifier(name, tmpBuffer);
+    if (name.data() != tmpBuffer.data()) {
+      tmpBuffer.append(name);
+      name = tmpBuffer.str();
+    }
+    name = name.copy(usedNameAllocator);
+    blockNames[block] = {-1, name};
+  };
+
   if (!printerFlags.shouldPrintGenericOpForm()) {
-    if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
+    if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
+      asmInterface.getAsmBlockNames(setBlockNameFn);
       asmInterface.getAsmResultNames(setResultNameFn);
+    }
   }
 
+  unsigned numResults = op.getNumResults();
+  if (numResults == 0)
+    return;
+  Value resultBegin = op.getResult(0);
+
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;
@@ -2609,11 +2645,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
 }
 
 void OperationPrinter::printBlockName(Block *block) {
-  auto id = state->getSSANameState().getBlockID(block);
-  if (id != SSANameState::NameSentinel)
-    os << "^bb" << id;
-  else
-    os << "^INVALIDBLOCK";
+  os << state->getSSANameState().getBlockInfo(block).name;
 }
 
 void OperationPrinter::print(Block *block, bool printBlockArgs,
@@ -2647,18 +2679,18 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
       os << "  // pred: ";
       printBlockName(pred);
     } else {
-      // We want to print the predecessors in increasing numeric order, not in
+      // We want to print the predecessors in a stable order, not in
       // whatever order the use-list is in, so gather and sort them.
-      SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
+      SmallVector<BlockInfo, 4> predIDs;
       for (auto *pred : block->getPredecessors())
-        predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
-      llvm::array_pod_sort(predIDs.begin(), predIDs.end());
+        predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
+      llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
+        return lhs.ordering < rhs.ordering;
+      });
 
       os << "  // " << predIDs.size() << " preds: ";
 
-      interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
-        printBlockName(pred.second);
-      });
+      interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
     }
     os << newLine;
   }

diff  --git a/mlir/test/IR/pretty_printed_region_op.mlir b/mlir/test/IR/pretty_printed_region_op.mlir
index 7bdc06f581260..220733f528c57 100644
--- a/mlir/test/IR/pretty_printed_region_op.mlir
+++ b/mlir/test/IR/pretty_printed_region_op.mlir
@@ -36,7 +36,6 @@ func @pretty_printed_region_op(%arg0 : f32, %arg1 : f32) -> (f32) {
 
 // -----
 
-
 func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
 // CHECK-LOCATION: "test.pretty_printed_region"(%arg1, %arg0)
 // CHECK-LOCATION:   ^bb0(%arg[[x:[0-9]+]]: f32 loc("foo"), %arg[[y:[0-9]+]]: f32 loc("foo")
@@ -47,3 +46,29 @@ func @pretty_printed_region_op_deferred_loc(%arg0 : f32, %arg1 : f32) -> (f32) {
   %res = test.pretty_printed_region %arg1, %arg0 start special.op end : (f32, f32) -> (f32) loc("foo")
   return %res : f32
 }
+
+// -----
+
+// This tests the behavior of custom block names:
+// operations like `test.block_names` can define custom names for blocks in
+// nested regions.
+// CHECK-CUSTOM-LABEL: func @block_names
+func @block_names(%bool : i1) {
+  // CHECK: test.block_names
+  test.block_names {
+    // CHECK-CUSTOM: br ^foo1
+    // CHECK-GENERIC: cf.br{{.*}}^bb1
+    cf.br ^foo1
+  // CHECK-CUSTOM: ^foo1:
+  // CHECK-GENERIC: ^bb1:
+  ^foo1:
+    // CHECK-CUSTOM: br ^foo2
+    // CHECK-GENERIC: cf.br{{.*}}^bb2
+    cf.br ^foo2
+  // CHECK-CUSTOM: ^foo2:
+  // CHECK-GENERIC: ^bb2:
+  ^foo2:
+     "test.return"() : () -> ()
+  }
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3a14be7373721..5b7d973429269 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -660,6 +660,24 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
   let assemblyFormat = "regions attr-dict-with-keyword";
 }
 
+// This is used to test the OpAsmOpInterface::getAsmBlockNames() feature:
+// blocks nested in a region under this op will have a name defined by the
+// interface.
+def AsmBlockNameOp : TEST_Op<"block_names", [OpAsmOpInterface]> {
+ let regions = (region AnyRegion:$body);
+  let extraClassDeclaration = [{
+    void getAsmBlockNames(mlir::OpAsmSetBlockNameFn setNameFn) {
+      std::string name;
+      int count = 0;
+      for (::mlir::Block &block : getRegion().getBlocks()) {
+        name = "foo" + std::to_string(count++);
+        setNameFn(&block, name);
+      }
+    }
+  }];
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This operation requires its return type to have the trait 'TestTypeTrait'.
 def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
   let results = (outs AnyType);


        


More information about the Mlir-commits mailing list