[Mlir-commits] [mlir] 035e12e - [MLIR] [ODS] Allowing attr-dict in custom directive
John Demme
llvmlistbot at llvm.org
Tue Oct 27 18:25:51 PDT 2020
Author: John Demme
Date: 2020-10-28T01:24:16Z
New Revision: 035e12e66449576051fb0ac91ff84786f330ad95
URL: https://github.com/llvm/llvm-project/commit/035e12e66449576051fb0ac91ff84786f330ad95
DIFF: https://github.com/llvm/llvm-project/commit/035e12e66449576051fb0ac91ff84786f330ad95.diff
LOG: [MLIR] [ODS] Allowing attr-dict in custom directive
Enhance tblgen's declarative assembly format to allow `attr-dict` in
custom directives.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D89772
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index f44765204dc0..d24f4711da5f 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -799,12 +799,12 @@ when generating the C++ code for the format. The `UserDirective` is an
identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
would result in calls to `parseMyDirective` and `printMyDirective` wihtin the
parser and printer respectively. `Params` may be any combination of variables
-(i.e. Attribute, Operand, Successor, etc.) and type directives. The type
-directives must refer to a variable, but that variable need not also be a
-parameter to the custom directive.
+(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`.
+The type directives must refer to a variable, but that variable need not also
+be a parameter to the custom directive.
-The arguments to the `parse<UserDirective>` method is firstly a reference to the
-`OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
+The arguments to the `parse<UserDirective>` method are firstly a reference to
+the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
corresponding to the parameters specified in the format. The mapping of
declarative parameter to `parse` method argument is detailed below:
@@ -829,12 +829,14 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `Type`
- Optional: `Type`
- Variadic: `const SmallVectorImpl<Type> &`
+* `attr-dict` Directive: `NamedAttrList &`
When a variable is optional, the value should only be specified if the variable
is present. Otherwise, the value should remain `None` or null.
-The arguments to the `print<UserDirective>` method is firstly a reference to the
-`OpAsmPrinter`(`OpAsmPrinter &`), and secondly a set of output parameters
+The arguments to the `print<UserDirective>` method is firstly a reference to
+the `OpAsmPrinter`(`OpAsmPrinter &`), second the op (e.g. `FooOp op` which
+can be `Operation *op` alternatively), and finally a set of output parameters
corresponding to the parameters specified in the format. The mapping of
declarative parameter to `print` method argument is detailed below:
@@ -859,6 +861,7 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Type`
- Optional: `Type`
- Variadic: `TypeRange`
+* `attr-dict` Directive: `const MutableDictionaryAttr&`
When a variable is optional, the provided value may be null.
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 5ef3931ae3fd..36ef1af8cfa7 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -271,8 +271,8 @@ static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
return success();
}
-static void printAwaitResultType(OpAsmPrinter &p, Type operandType,
- Type resultType) {
+static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
+ Type operandType, Type resultType) {
p << operandType;
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7abefd7a5499..9514d3e0a04b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -493,7 +493,7 @@ parseLaunchFuncOperands(OpAsmParser &parser,
argTypes, argAttrs, isVariadic);
}
-static void printLaunchFuncOperands(OpAsmPrinter &printer,
+static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
OperandRange operands, TypeRange types) {
if (operands.empty())
return;
@@ -846,7 +846,8 @@ static ParseResult parseAsyncDependencies(
OpAsmParser::Delimiter::OptionalSquare);
}
-static void printAsyncDependencies(OpAsmPrinter &printer, Type asyncTokenType,
+static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
+ Type asyncTokenType,
OperandRange asyncDependencies) {
if (asyncTokenType)
printer << "async ";
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index d34e997644a5..d2013d8c6941 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -385,19 +385,24 @@ static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
return success();
}
+static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
+ NamedAttrList &attrs) {
+ return parser.parseOptionalAttrDict(attrs);
+}
+
//===----------------------------------------------------------------------===//
// Printing
-static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand,
- Value optOperand,
+static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
+ Value operand, Value optOperand,
OperandRange varOperands) {
printer << operand;
if (optOperand)
printer << ", " << optOperand;
printer << " -> (" << varOperands << ")";
}
-static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
- Type optOperandType,
+static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
+ Type operandType, Type optOperandType,
TypeRange varOperandTypes) {
printer << " : " << operandType;
if (optOperandType)
@@ -405,23 +410,23 @@ static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
printer << " -> (" << varOperandTypes << ")";
}
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
- Type operandType,
+ Operation *op, Type operandType,
Type optOperandType,
TypeRange varOperandTypes) {
printer << " type_refs_capture ";
- printCustomDirectiveResults(printer, operandType, optOperandType,
+ printCustomDirectiveResults(printer, op, operandType, optOperandType,
varOperandTypes);
}
-static void
-printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
- Value optOperand, OperandRange varOperands,
- Type operandType, Type optOperandType,
- TypeRange varOperandTypes) {
- printCustomDirectiveOperands(printer, operand, optOperand, varOperands);
- printCustomDirectiveResults(printer, operandType, optOperandType,
+static void printCustomDirectiveOperandsAndTypes(
+ OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
+ OperandRange varOperands, Type operandType, Type optOperandType,
+ TypeRange varOperandTypes) {
+ printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
+ printCustomDirectiveResults(printer, op, operandType, optOperandType,
varOperandTypes);
}
-static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
+static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
+ Region &region,
MutableArrayRef<Region> varRegions) {
printer.printRegion(region);
if (!varRegions.empty()) {
@@ -430,14 +435,14 @@ static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
printer.printRegion(region);
}
}
-static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
+static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
Block *successor,
SuccessorRange varSuccessors) {
printer << successor;
if (!varSuccessors.empty())
printer << ", " << varSuccessors.front();
}
-static void printCustomDirectiveAttributes(OpAsmPrinter &printer,
+static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
Attribute attribute,
Attribute optAttribute) {
printer << attribute;
@@ -445,6 +450,10 @@ static void printCustomDirectiveAttributes(OpAsmPrinter &printer,
printer << ", " << optAttribute;
}
+static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
+ MutableDictionaryAttr attrs) {
+ printer.printOptionalAttrDict(attrs.getAttrs());
+}
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 48b37719b0e3..144f418c730f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1638,6 +1638,14 @@ def FormatCustomDirectiveAttributes
}];
}
+def FormatCustomDirectiveAttrDict
+ : TEST_Op<"format_custom_directive_attrdict"> {
+ let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$optAttr);
+ let assemblyFormat = [{
+ custom<CustomDirectiveAttrDict>( attr-dict )
+ }];
+}
+
//===----------------------------------------------------------------------===//
// AllTypesMatch type inference
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 01acd591b6fe..7d156ec4c933 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -863,7 +863,8 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
body << ", ";
if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
body << attr->getVar()->name << "Attr";
-
+ } else if (isa<AttrDictDirective>(&param)) {
+ body << "result.attributes";
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
StringRef name = operand->getVar()->name;
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
@@ -1502,12 +1503,18 @@ static void genSpacePrinter(OpMethodBody &body, bool &shouldEmitSpace,
/// Generate the printer for a custom directive.
static void genCustomDirectivePrinter(CustomDirective *customDir,
OpMethodBody &body) {
- body << " print" << customDir->getName() << "(p";
+ body << " print" << customDir->getName() << "(p, *this";
for (Element &param : customDir->getArguments()) {
body << ", ";
if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
body << attr->getVar()->name << "Attr()";
+ } else if (isa<AttrDictDirective>(&param)) {
+ // Enforce the const-ness since getMutableAttrDict() returns a reference
+ // into the Operations `attr` member.
+ body << "(const "
+ "MutableDictionaryAttr&)getOperation()->getMutableAttrDict()";
+
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
body << operand->getVar()->name << "()";
@@ -2776,8 +2783,9 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
return ::mlir::failure();
// Verify that the element can be placed within a custom directive.
- if (!isa<TypeRefDirective, TypeDirective, AttributeVariable, OperandVariable,
- RegionVariable, SuccessorVariable>(parameters.back().get())) {
+ if (!isa<TypeRefDirective, TypeDirective, AttrDictDirective,
+ AttributeVariable, OperandVariable, RegionVariable,
+ SuccessorVariable>(parameters.back().get())) {
return emitError(childLoc, "only variables and types may be used as "
"parameters to a custom directive");
}
More information about the Mlir-commits
mailing list