[Mlir-commits] [mlir] 7f9e9c7 - Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpInterface
Mehdi Amini
llvmlistbot at llvm.org
Sun Dec 19 23:19:15 PST 2021
Author: Mehdi Amini
Date: 2021-12-20T07:18:01Z
New Revision: 7f9e9c7fc3416d243d59712864bdb5c5725aed7f
URL: https://github.com/llvm/llvm-project/commit/7f9e9c7fc3416d243d59712864bdb5c5725aed7f
DIFF: https://github.com/llvm/llvm-project/commit/7f9e9c7fc3416d243d59712864bdb5c5725aed7f.diff
LOG: Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpInterface
This method is more suitable as an op interface method: the names seem intrinsic
to individual instances of the operation rather than to the dialect.
This also removes the restriction that the hook applies only to the entry block:
it now receives a Region and may name the arguments of any block in that region.
Differential Revision: https://reviews.llvm.org/D116018
Added:
Modified:
mlir/include/mlir/IR/OpAsmInterface.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
index 50b98b04dbee5..b49e12ea9a85e 100644
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -49,7 +49,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
}],
"void", "getAsmResultNames",
(ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
- "", ";"
+ "", "return;"
+ >,
+ InterfaceMethod<[{
+ Get a special name to use when printing the block arguments for a region
+ immediately nested under this operation.
+ }],
+ "void", "getAsmBlockArgumentNames",
+ (ins
+ "::mlir::Region&":$region,
+ "::mlir::OpAsmSetValueNameFn":$setNameFn
+ ),
+ "", "return;"
>,
StaticInterfaceMethod<[{
Return the default dialect used when printing/parsing operations in
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 4b763f9efe36b..2cd09abb1dc24 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1348,11 +1348,6 @@ class OpAsmDialectInterface
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
virtual void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const {}
-
- /// Get a special name to use when printing the entry block arguments of the
- /// region contained by an operation in this dialect.
- virtual void getAsmBlockArgumentNames(Block *block,
- OpAsmSetValueNameFn setNameFn) const {}
};
} // namespace mlir
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index fe214f2aca2a4..fe6893d2147fe 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1006,6 +1006,20 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
}
void SSANameState::numberValuesInRegion(Region &region) {
+ auto setBlockArgNameFn = [&](Value arg, StringRef name) {
+ assert(!valueIDs.count(arg) && "arg numbered multiple times");
+ assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
+ "arg not defined in current region");
+ setValueName(arg, name);
+ };
+
+ if (!printerFlags.shouldPrintGenericOpForm()) {
+ if (Operation *op = region.getParentOp()) {
+ if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
+ asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ }
+ }
+
// Number the values within this region in a breadth-first order.
unsigned nextBlockID = 0;
for (auto &block : region) {
@@ -1017,23 +1031,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
}
void SSANameState::numberValuesInBlock(Block &block) {
- auto setArgNameFn = [&](Value arg, StringRef name) {
- assert(!valueIDs.count(arg) && "arg numbered multiple times");
- assert(arg.cast<BlockArgument>().getOwner() == &block &&
- "arg not defined in 'block'");
- setValueName(arg, name);
- };
-
- bool isEntryBlock = block.isEntryBlock();
- if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
- if (auto *op = block.getParentOp()) {
- if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
- asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
- }
- }
-
// Number the block arguments. We give entry block arguments a special name
// 'arg'.
+ bool isEntryBlock = block.isEntryBlock();
SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
llvm::raw_svector_ostream specialName(specialNameBuffer);
for (auto arg : block.getArguments()) {
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index aae98baf1c26c..b46aaf1069979 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -105,20 +105,6 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
setNameFn(asmOp, "result");
}
-
- void getAsmBlockArgumentNames(Block *block,
- OpAsmSetValueNameFn setNameFn) const final {
- auto op = block->getParentOp();
- auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
- if (!arrayAttr)
- return;
- auto args = block->getArguments();
- auto e = std::min(arrayAttr.size(), args.size());
- for (unsigned i = 0; i < e; ++i) {
- if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
- setNameFn(args[i], strAttr.getValue());
- }
- }
};
struct TestDialectFoldInterface : public DialectFoldInterface {
@@ -848,6 +834,19 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
return parser.parseRegion(*body, ivsInfo, argTypes);
}
+void PolyForOp::getAsmBlockArgumentNames(Region &region,
+ OpAsmSetValueNameFn setNameFn) {
+ auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
+ if (!arrayAttr)
+ return;
+ auto args = getRegion().front().getArguments();
+ auto e = std::min(arrayAttr.size(), args.size());
+ for (unsigned i = 0; i < e; ++i) {
+ if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
+ setNameFn(args[i], strAttr.getValue());
+ }
+}
+
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7dca8165b0d23..80b568c743b01 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1667,13 +1667,16 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
let printer = [{ return ::print(p, *this); }];
}
-def PolyForOp : TEST_Op<"polyfor">
+def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
{
let summary = "polyfor operation";
let description = [{
Test op with multiple region arguments, each argument of index type.
}];
-
+ let extraClassDeclaration = [{
+ void getAsmBlockArgumentNames(mlir::Region &region,
+ mlir::OpAsmSetValueNameFn setNameFn);
+ }];
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
}
More information about the Mlir-commits mailing list