[Mlir-commits] [mlir] 035e12e - [MLIR] [ODS] Allowing attr-dict in custom directive

John Demme llvmlistbot at llvm.org
Tue Oct 27 18:25:51 PDT 2020


Author: John Demme
Date: 2020-10-28T01:24:16Z
New Revision: 035e12e66449576051fb0ac91ff84786f330ad95

URL: https://github.com/llvm/llvm-project/commit/035e12e66449576051fb0ac91ff84786f330ad95
DIFF: https://github.com/llvm/llvm-project/commit/035e12e66449576051fb0ac91ff84786f330ad95.diff

LOG: [MLIR] [ODS] Allowing attr-dict in custom directive

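Enhance tblgen's declarative assembly format to allow `attr-dict` in
custom directives. As part of this change, the generated call to every
`print<UserDirective>` function now passes the op as the second argument
(after the `OpAsmPrinter &`), so existing custom printer functions must be
updated to accept it.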

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D89772

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index f44765204dc0..d24f4711da5f 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -799,12 +799,12 @@ when generating the C++ code for the format. The `UserDirective` is an
 identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
 would result in calls to `parseMyDirective` and `printMyDirective` wihtin the
 parser and printer respectively. `Params` may be any combination of variables
-(i.e. Attribute, Operand, Successor, etc.) and type directives. The type
-directives must refer to a variable, but that variable need not also be a
-parameter to the custom directive.
+(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`.
+The type directives must refer to a variable, but that variable need not also
+be a parameter to the custom directive.
 
-The arguments to the `parse<UserDirective>` method is firstly a reference to the
-`OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
+The arguments to the `parse<UserDirective>` method are firstly a reference to
+the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
 corresponding to the parameters specified in the format. The mapping of
 declarative parameter to `parse` method argument is detailed below:
 
@@ -829,12 +829,14 @@ declarative parameter to `parse` method argument is detailed below:
     -   Single: `Type`
     -   Optional: `Type`
     -   Variadic: `const SmallVectorImpl<Type> &`
+*   `attr-dict` Directive: `NamedAttrList &`
 
 When a variable is optional, the value should only be specified if the variable
 is present. Otherwise, the value should remain `None` or null.
 
-The arguments to the `print<UserDirective>` method is firstly a reference to the
-`OpAsmPrinter`(`OpAsmPrinter &`), and secondly a set of output parameters
+The arguments to the `print<UserDirective>` method are firstly a reference to
+the `OpAsmPrinter`(`OpAsmPrinter &`), secondly the op (e.g. `FooOp op`, or
+alternatively `Operation *op`), and finally a set of input parameters
 corresponding to the parameters specified in the format. The mapping of
 declarative parameter to `print` method argument is detailed below:
 
@@ -859,6 +861,7 @@ declarative parameter to `print` method argument is detailed below:
     -   Single: `Type`
     -   Optional: `Type`
     -   Variadic: `TypeRange`
+*   `attr-dict` Directive: `const MutableDictionaryAttr &`
 
 When a variable is optional, the provided value may be null.
 

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 5ef3931ae3fd..36ef1af8cfa7 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -271,8 +271,8 @@ static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
   return success();
 }
 
-static void printAwaitResultType(OpAsmPrinter &p, Type operandType,
-                                 Type resultType) {
+static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
+                                 Type operandType, Type resultType) {
   p << operandType;
 }
 

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7abefd7a5499..9514d3e0a04b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -493,7 +493,7 @@ parseLaunchFuncOperands(OpAsmParser &parser,
                                          argTypes, argAttrs, isVariadic);
 }
 
-static void printLaunchFuncOperands(OpAsmPrinter &printer,
+static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                     OperandRange operands, TypeRange types) {
   if (operands.empty())
     return;
@@ -846,7 +846,8 @@ static ParseResult parseAsyncDependencies(
                                  OpAsmParser::Delimiter::OptionalSquare);
 }
 
-static void printAsyncDependencies(OpAsmPrinter &printer, Type asyncTokenType,
+static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
+                                   Type asyncTokenType,
                                    OperandRange asyncDependencies) {
   if (asyncTokenType)
     printer << "async ";

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index d34e997644a5..d2013d8c6941 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -385,19 +385,24 @@ static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
   return success();
 }
 
+static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
+                                                NamedAttrList &attrs) {
+  return parser.parseOptionalAttrDict(attrs);
+}
+
 //===----------------------------------------------------------------------===//
 // Printing
 
-static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand,
-                                         Value optOperand,
+static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
+                                         Value operand, Value optOperand,
                                          OperandRange varOperands) {
   printer << operand;
   if (optOperand)
     printer << ", " << optOperand;
   printer << " -> (" << varOperands << ")";
 }
-static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
-                                        Type optOperandType,
+static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
+                                        Type operandType, Type optOperandType,
                                         TypeRange varOperandTypes) {
   printer << " : " << operandType;
   if (optOperandType)
@@ -405,23 +410,23 @@ static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
   printer << " -> (" << varOperandTypes << ")";
 }
 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
-                                             Type operandType,
+                                             Operation *op, Type operandType,
                                              Type optOperandType,
                                              TypeRange varOperandTypes) {
   printer << " type_refs_capture ";
-  printCustomDirectiveResults(printer, operandType, optOperandType,
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                               varOperandTypes);
 }
-static void
-printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
-                                     Value optOperand, OperandRange varOperands,
-                                     Type operandType, Type optOperandType,
-                                     TypeRange varOperandTypes) {
-  printCustomDirectiveOperands(printer, operand, optOperand, varOperands);
-  printCustomDirectiveResults(printer, operandType, optOperandType,
+static void printCustomDirectiveOperandsAndTypes(
+    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
+    OperandRange varOperands, Type operandType, Type optOperandType,
+    TypeRange varOperandTypes) {
+  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                               varOperandTypes);
 }
-static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
+static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
+                                        Region &region,
                                         MutableArrayRef<Region> varRegions) {
   printer.printRegion(region);
   if (!varRegions.empty()) {
@@ -430,14 +435,14 @@ static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
       printer.printRegion(region);
   }
 }
-static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
+static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
                                            Block *successor,
                                            SuccessorRange varSuccessors) {
   printer << successor;
   if (!varSuccessors.empty())
     printer << ", " << varSuccessors.front();
 }
-static void printCustomDirectiveAttributes(OpAsmPrinter &printer,
+static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
                                            Attribute attribute,
                                            Attribute optAttribute) {
   printer << attribute;
@@ -445,6 +450,10 @@ static void printCustomDirectiveAttributes(OpAsmPrinter &printer,
     printer << ", " << optAttribute;
 }
 
+static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
+                                         MutableDictionaryAttr attrs) {
+  printer.printOptionalAttrDict(attrs.getAttrs());
+}
 //===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 48b37719b0e3..144f418c730f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1638,6 +1638,14 @@ def FormatCustomDirectiveAttributes
   }];
 }
 
+def FormatCustomDirectiveAttrDict
+    : TEST_Op<"format_custom_directive_attrdict"> {
+  let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$optAttr);
+  let assemblyFormat = [{
+    custom<CustomDirectiveAttrDict>( attr-dict )
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AllTypesMatch type inference
 

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 01acd591b6fe..7d156ec4c933 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -863,7 +863,8 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
   body << ", ";
   if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
     body << attr->getVar()->name << "Attr";
-
+  } else if (isa<AttrDictDirective>(&param)) {
+    body << "result.attributes";
   } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
     StringRef name = operand->getVar()->name;
     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
@@ -1502,12 +1503,18 @@ static void genSpacePrinter(OpMethodBody &body, bool &shouldEmitSpace,
 /// Generate the printer for a custom directive.
 static void genCustomDirectivePrinter(CustomDirective *customDir,
                                       OpMethodBody &body) {
-  body << "  print" << customDir->getName() << "(p";
+  body << "  print" << customDir->getName() << "(p, *this";
   for (Element &param : customDir->getArguments()) {
     body << ", ";
     if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
       body << attr->getVar()->name << "Attr()";
 
+    } else if (isa<AttrDictDirective>(&param)) {
+      // Enforce the const-ness since getMutableAttrDict() returns a reference
+      // into the Operation's `attr` member.
+      body << "(const "
+              "MutableDictionaryAttr&)getOperation()->getMutableAttrDict()";
+
     } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
       body << operand->getVar()->name << "()";
 
@@ -2776,8 +2783,9 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
     return ::mlir::failure();
 
   // Verify that the element can be placed within a custom directive.
-  if (!isa<TypeRefDirective, TypeDirective, AttributeVariable, OperandVariable,
-           RegionVariable, SuccessorVariable>(parameters.back().get())) {
+  if (!isa<TypeRefDirective, TypeDirective, AttrDictDirective,
+           AttributeVariable, OperandVariable, RegionVariable,
+           SuccessorVariable>(parameters.back().get())) {
     return emitError(childLoc, "only variables and types may be used as "
                                "parameters to a custom directive");
   }


        


More information about the Mlir-commits mailing list