[Mlir-commits] [mlir] 7f9e9c7 - Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpInterface

Mehdi Amini llvmlistbot at llvm.org
Sun Dec 19 23:19:15 PST 2021


Author: Mehdi Amini
Date: 2021-12-20T07:18:01Z
New Revision: 7f9e9c7fc3416d243d59712864bdb5c5725aed7f

URL: https://github.com/llvm/llvm-project/commit/7f9e9c7fc3416d243d59712864bdb5c5725aed7f
DIFF: https://github.com/llvm/llvm-project/commit/7f9e9c7fc3416d243d59712864bdb5c5725aed7f.diff

LOG: Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpInterface

This method is better suited to an op interface: it seems intrinsic to
individual operation instances rather than to the dialect.
Also remove the restriction that the hook applies only to the entry block: it may now name the arguments of any block in the region.

Differential Revision: https://reviews.llvm.org/D116018

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpAsmInterface.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 50b98b04dbee5..b49e12ea9a85e 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -49,7 +49,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
       }],
       "void", "getAsmResultNames",
       (ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
-      "", ";"
+      "", "return;"
+    >,
+    InterfaceMethod<[{
+        Get a special name to use when printing the block arguments for a region
+        immediately nested under this operation.
+      }],
+      "void", "getAsmBlockArgumentNames",
+      (ins
+        "::mlir::Region&":$region,
+        "::mlir::OpAsmSetValueNameFn":$setNameFn
+      ),
+      "", "return;"
     >,
     StaticInterfaceMethod<[{
       Return the default dialect used when printing/parsing operations in

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 4b763f9efe36b..2cd09abb1dc24 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1348,11 +1348,6 @@ class OpAsmDialectInterface
   /// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
   virtual void getAsmResultNames(Operation *op,
                                  OpAsmSetValueNameFn setNameFn) const {}
-
-  /// Get a special name to use when printing the entry block arguments of the
-  /// region contained by an operation in this dialect.
-  virtual void getAsmBlockArgumentNames(Block *block,
-                                        OpAsmSetValueNameFn setNameFn) const {}
 };
 } // namespace mlir
 

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index fe214f2aca2a4..fe6893d2147fe 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1006,6 +1006,20 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
 }
 
 void SSANameState::numberValuesInRegion(Region &region) {
+  auto setBlockArgNameFn = [&](Value arg, StringRef name) {
+    assert(!valueIDs.count(arg) && "arg numbered multiple times");
+    assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
+           "arg not defined in current region");
+    setValueName(arg, name);
+  };
+
+  if (!printerFlags.shouldPrintGenericOpForm()) {
+    if (Operation *op = region.getParentOp()) {
+      if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+        asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+    }
+  }
+
   // Number the values within this region in a breadth-first order.
   unsigned nextBlockID = 0;
   for (auto &block : region) {
@@ -1017,23 +1031,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
 }
 
 void SSANameState::numberValuesInBlock(Block &block) {
-  auto setArgNameFn = [&](Value arg, StringRef name) {
-    assert(!valueIDs.count(arg) && "arg numbered multiple times");
-    assert(arg.cast<BlockArgument>().getOwner() == &block &&
-           "arg not defined in 'block'");
-    setValueName(arg, name);
-  };
-
-  bool isEntryBlock = block.isEntryBlock();
-  if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
-    if (auto *op = block.getParentOp()) {
-      if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
-        asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
-    }
-  }
-
   // Number the block arguments. We give entry block arguments a special name
   // 'arg'.
+  bool isEntryBlock = block.isEntryBlock();
   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
   llvm::raw_svector_ostream specialName(specialNameBuffer);
   for (auto arg : block.getArguments()) {

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index aae98baf1c26c..b46aaf1069979 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -105,20 +105,6 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
       setNameFn(asmOp, "result");
   }
-
-  void getAsmBlockArgumentNames(Block *block,
-                                OpAsmSetValueNameFn setNameFn) const final {
-    auto op = block->getParentOp();
-    auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
-    if (!arrayAttr)
-      return;
-    auto args = block->getArguments();
-    auto e = std::min(arrayAttr.size(), args.size());
-    for (unsigned i = 0; i < e; ++i) {
-      if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
-        setNameFn(args[i], strAttr.getValue());
-    }
-  }
 };
 
 struct TestDialectFoldInterface : public DialectFoldInterface {
@@ -848,6 +834,19 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
   return parser.parseRegion(*body, ivsInfo, argTypes);
 }
 
+void PolyForOp::getAsmBlockArgumentNames(Region &region,
+                                         OpAsmSetValueNameFn setNameFn) {
+  auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
+  if (!arrayAttr)
+    return;
+  auto args = getRegion().front().getArguments();
+  auto e = std::min(arrayAttr.size(), args.size());
+  for (unsigned i = 0; i < e; ++i) {
+    if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
+      setNameFn(args[i], strAttr.getValue());
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Test removing op with inner ops.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7dca8165b0d23..80b568c743b01 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1667,13 +1667,16 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
   let printer = [{ return ::print(p, *this); }];
 }
 
-def PolyForOp : TEST_Op<"polyfor">
+def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
 {
   let summary =  "polyfor operation";
   let description = [{
     Test op with multiple region arguments, each argument of index type.
   }];
-
+  let extraClassDeclaration = [{
+    void getAsmBlockArgumentNames(mlir::Region &region,
+                                  mlir::OpAsmSetValueNameFn setNameFn);
+  }];
   let regions = (region SizedRegion<1>:$region);
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }


        


More information about the Mlir-commits mailing list