[Mlir-commits] [mlir] ddc90da - [mlir] Printing oilist element

Shraiysh Vaishay llvmlistbot at llvm.org
Mon Mar 21 22:18:13 PDT 2022


Author: Shraiysh Vaishay
Date: 2022-03-22T10:48:03+05:30
New Revision: ddc90da478483437f26b4e27f8561cf37436a129

URL: https://github.com/llvm/llvm-project/commit/ddc90da478483437f26b4e27f8561cf37436a129
DIFF: https://github.com/llvm/llvm-project/commit/ddc90da478483437f26b4e27f8561cf37436a129.diff

LOG: [mlir] Printing oilist element

This patch attempts to deduce whether an oilist element must be printed
based on the optional arguments it contains. This especially helps when
creating an operation programmatically, because with the current
implementation the inferred unit attributes must be added manually for
the clauses to be printed correctly.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D121579

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/test/IR/traits.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 726152f8a7b3f..cf8d3370da1b9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -199,7 +199,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
                 $allocate_vars, type($allocate_vars),
                 $allocators_vars, type($allocators_vars)
               ) `)`
-          | `nowait`
+          | `nowait` $nowait
     ) $region attr-dict
   }];
 
@@ -438,7 +438,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
     oilist( `if` `(` $if_expr `)`
           | `device` `(` $device `:` type($device) `)`
           | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
-          | `nowait`
+          | `nowait` $nowait
     ) $region attr-dict
   }];
 }

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 2bd0ca6f3f020..bcd76797413d4 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -498,6 +498,10 @@ func @succeededOilistTrivial() {
   test.oilist_with_keywords_only keyword otherKeyword
   // CHECK: test.oilist_with_keywords_only keyword otherKeyword
   test.oilist_with_keywords_only otherKeyword keyword
+  // CHECK: test.oilist_with_keywords_only thirdKeyword
+  test.oilist_with_keywords_only thirdKeyword
+  // CHECK: test.oilist_with_keywords_only keyword thirdKeyword
+  test.oilist_with_keywords_only keyword thirdKeyword
   return
 }
 
@@ -550,7 +554,7 @@ func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) {
   test.oilist_custom private (%arg0, %arg1 : i32, i32)
   // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait
   test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait
-  // CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) nowait reduction (%arg1)
+  // CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) reduction (%arg1) nowait
   test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32)
   return
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b675a5515c994..5bb397353af28 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -656,9 +656,12 @@ def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
 
 // Ops related to OIList primitive
 def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> {
+  let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword,
+                       UnitAttr:$diffNameUnitAttrKeyword);
   let assemblyFormat = [{
-    oilist( `keyword`
-          | `otherKeyword`) attr-dict
+    oilist( `keyword` $keyword
+          | `otherKeyword` $otherKeyword
+          | `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict
   }];
 }
 
@@ -690,8 +693,8 @@ def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> {
                        UnitAttr:$nowait);
   let assemblyFormat = [{
     oilist( `private` `(` $arg0 `:` type($arg0) `)`
-          | `nowait`
           | `reduction` custom<CustomOptionalOperand>($optOperand)
+          | `nowait` $nowait
     ) attr-dict
   }];
 }

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 3bf9ad99a34b8..fb54dcb3d85f8 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -207,6 +207,18 @@ class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
     return llvm::zip(getLiteralElements(), getParsingElements());
   }
 
+  /// If the parsing element is a single UnitAttr element, then it returns the
+  /// attribute variable. Otherwise, returns nullptr.
+  AttributeVariable *
+  getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
+    if (pelement.size() == 1) {
+      auto attrElem = dyn_cast<AttributeVariable>(pelement[0]);
+      if (attrElem && attrElem->isUnitAttr())
+        return attrElem;
+    }
+    return nullptr;
+  }
+
 private:
   /// A vector of `LiteralElement` objects. Each element stores the keyword
   /// for one case of oilist element. For example, an oilist element along with
@@ -684,7 +696,6 @@ const char *oilistParserCode = R"(
              "oilist directive";
   }
   {0}Clause = true;
-  result.addAttribute("{0}", UnitAttr::get(parser.getContext()));
 )";
 
 namespace {
@@ -778,9 +789,11 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
       genElementParserStorage(childElement, op, body);
 
   } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
-    for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements())
-      for (FormatElement *element : pelement)
-        genElementParserStorage(element, op, body);
+    for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
+      if (!oilist->getUnitAttrParsingElement(pelement))
+        for (FormatElement *element : pelement)
+          genElementParserStorage(element, op, body);
+    }
 
   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
     for (FormatElement *paramElement : custom->getArguments())
@@ -1180,11 +1193,16 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
       body << "if (succeeded(parser.parseOptional";
       genLiteralParser(lelement->getSpelling(), body);
       body << ")) {\n";
-      StringRef attrName = lelement->getSpelling();
-      body << formatv(oilistParserCode, attrName);
-      inferredAttributes.insert(attrName);
-      for (FormatElement *el : pelement)
-        genElementParser(el, body, attrTypeCtx);
+      StringRef lelementName = lelement->getSpelling();
+      body << formatv(oilistParserCode, lelementName);
+      if (AttributeVariable *unitAttrElem =
+              oilist->getUnitAttrParsingElement(pelement)) {
+        body << "  result.addAttribute(\"" << unitAttrElem->getVar()->name
+             << "\", UnitAttr::get(parser.getContext()));\n";
+      } else {
+        for (FormatElement *el : pelement)
+          genElementParser(el, body, attrTypeCtx);
+      }
       body << "    } else ";
     }
     body << " {\n";
@@ -1873,6 +1891,31 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
       });
 }
 
+void collect(FormatElement *element,
+             SmallVectorImpl<VariableElement *> &variables) {
+  TypeSwitch<FormatElement *>(element)
+      .Case([&](VariableElement *var) { variables.emplace_back(var); })
+      .Case([&](CustomDirective *ele) {
+        for (FormatElement *arg : ele->getArguments())
+          collect(arg, variables);
+      })
+      .Case([&](OptionalElement *ele) {
+        for (FormatElement *arg : ele->getThenElements())
+          collect(arg, variables);
+        for (FormatElement *arg : ele->getElseElements())
+          collect(arg, variables);
+      })
+      .Case([&](FunctionalTypeDirective *funcType) {
+        collect(funcType->getInputs(), variables);
+        collect(funcType->getResults(), variables);
+      })
+      .Case([&](OIListElement *oilist) {
+        for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
+          for (FormatElement *arg_ : arg)
+            collect(arg_, variables);
+      });
+}
+
 void OperationFormat::genElementPrinter(FormatElement *element,
                                         MethodBody &body, Operator &op,
                                         bool &shouldEmitSpace,
@@ -1939,13 +1982,44 @@ void OperationFormat::genElementPrinter(FormatElement *element,
       LiteralElement *lelement = std::get<0>(clause);
       ArrayRef<FormatElement *> pelement = std::get<1>(clause);
 
-      body << "  if ((*this)->hasAttrOfType<UnitAttr>(\""
-           << lelement->getSpelling() << "\")) {\n";
+      SmallVector<VariableElement *> vars;
+      for (FormatElement *el : pelement)
+        collect(el, vars);
+      body << "  if (false";
+      for (VariableElement *var : vars) {
+        TypeSwitch<FormatElement *>(var)
+            .Case([&](AttributeVariable *attrEle) {
+              body << " || " << op.getGetterName(attrEle->getVar()->name)
+                   << "Attr()";
+            })
+            .Case([&](OperandVariable *ele) {
+              if (ele->getVar()->isVariadic()) {
+                body << " || " << op.getGetterName(ele->getVar()->name)
+                     << "().size()";
+              } else {
+                body << " || " << op.getGetterName(ele->getVar()->name) << "()";
+              }
+            })
+            .Case([&](ResultVariable *ele) {
+              if (ele->getVar()->isVariadic()) {
+                body << " || " << op.getGetterName(ele->getVar()->name)
+                     << "().size()";
+              } else {
+                body << " || " << op.getGetterName(ele->getVar()->name) << "()";
+              }
+            })
+            .Case([&](RegionVariable *reg) {
+              body << " || " << op.getGetterName(reg->getVar()->name) << "()";
+            });
+      }
+
+      body << ") {\n";
       genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
                         lastWasPunctuation);
-      for (FormatElement *element : pelement) {
-        genElementPrinter(element, body, op, shouldEmitSpace,
-                          lastWasPunctuation);
+      if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
+        for (FormatElement *element : pelement)
+          genElementPrinter(element, body, op, shouldEmitSpace,
+                            lastWasPunctuation);
       }
       body << "  }\n";
     }
@@ -2866,51 +2940,45 @@ OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
 
 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
                                                          SMLoc loc) {
-  return TypeSwitch<FormatElement *, LogicalResult>(element)
-      // Only optional attributes can be within an oilist parsing group.
-      .Case([&](AttributeVariable *attrEle) {
-        if (!attrEle->getVar()->attr.isOptional())
-          return emitError(loc, "only optional attributes can be used to "
-                                "in an oilist parsing group");
-        return success();
-      })
-      // Only optional-like(i.e. variadic) operands can be within an oilist
-      // parsing group.
-      .Case([&](OperandVariable *ele) {
-        if (!ele->getVar()->isVariableLength())
-          return emitError(loc, "only variable length operands can be "
-                                "used within an oilist parsing group");
-        return success();
-      })
-      // Only optional-like(i.e. variadic) results can be within an oilist
-      // parsing group.
-      .Case([&](ResultVariable *ele) {
-        if (!ele->getVar()->isVariableLength())
-          return emitError(loc, "only variable length results can be "
-                                "used within an oilist parsing group");
-        return success();
-      })
-      .Case([&](RegionVariable *) {
-        // TODO: When ODS has proper support for marking "optional" regions, add
-        // a check here.
-        return success();
-      })
-      .Case([&](TypeDirective *ele) {
-        return verifyOIListParsingElement(ele->getArg(), loc);
-      })
-      .Case([&](FunctionalTypeDirective *ele) {
-        if (failed(verifyOIListParsingElement(ele->getInputs(), loc)))
-          return failure();
-        return verifyOIListParsingElement(ele->getResults(), loc);
-      })
-      // Literals, whitespace, and custom directives may be used.
-      .Case<LiteralElement, WhitespaceElement, CustomDirective,
-            FunctionalTypeDirective, OptionalElement>(
-          [&](FormatElement *) { return success(); })
-      .Default([&](FormatElement *) {
-        return emitError(loc, "only literals, types, and variables can be "
-                              "used within an oilist group");
-      });
+  SmallVector<VariableElement *> vars;
+  collect(element, vars);
+  for (VariableElement *elem : vars) {
+    LogicalResult res =
+        TypeSwitch<FormatElement *, LogicalResult>(elem)
+            // Only optional attributes can be within an oilist parsing group.
+            .Case([&](AttributeVariable *attrEle) {
+              if (!attrEle->getVar()->attr.isOptional() &&
+                  !attrEle->getVar()->attr.hasDefaultValue())
+                return emitError(loc, "only optional attributes can be used in "
+                                      "an oilist parsing group");
+              return success();
+            })
+            // Only optional-like(i.e. variadic) operands can be within an
+            // oilist parsing group.
+            .Case([&](OperandVariable *ele) {
+              if (!ele->getVar()->isVariableLength())
+                return emitError(loc, "only variable length operands can be "
+                                      "used within an oilist parsing group");
+              return success();
+            })
+            // Only optional-like(i.e. variadic) results can be within an oilist
+            // parsing group.
+            .Case([&](ResultVariable *ele) {
+              if (!ele->getVar()->isVariableLength())
+                return emitError(loc, "only variable length results can be "
+                                      "used within an oilist parsing group");
+              return success();
+            })
+            .Case([&](RegionVariable *) { return success(); })
+            .Default([&](FormatElement *) {
+              return emitError(loc,
+                               "only literals, types, and variables can be "
+                               "used within an oilist group");
+            });
+    if (failed(res))
+      return failure();
+  }
+  return success();
 }
 
 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,


        


More information about the Mlir-commits mailing list