[Mlir-commits] [mlir] 761bc83 - [mlir][ods] Default-valued parameters in attribute or type defs

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 15 11:02:15 PST 2022


Author: Mogball
Date: 2022-02-15T19:02:11Z
New Revision: 761bc83af4ee70c6c5156f73ed88947a5d9f013f

URL: https://github.com/llvm/llvm-project/commit/761bc83af4ee70c6c5156f73ed88947a5d9f013f
DIFF: https://github.com/llvm/llvm-project/commit/761bc83af4ee70c6c5156f73ed88947a5d9f013f.diff

LOG: [mlir][ods] Default-valued parameters in attribute or type defs

Optional parameters with `defaultValue` set will be populated with that value if they aren't encountered during parsing. Moreover, parameters equal to their default values are elided when printing.

Depends on D118210

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D118544

Added: 
    

Modified: 
    mlir/docs/Tutorials/DefiningAttributesAndTypes.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 1501b54af5279..3260ac10896f4 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -518,6 +518,37 @@ printed as `(5 : i32)`. If it is not present, it will be `x`. Directives that
 are used inside optional groups are allowed only if all captured parameters are
 also optional.
 
+#### Default-Valued Parameters
+
+Optional parameters can be given default values by setting `defaultValue`, a
+string of the C++ default value, or by using `DefaultValuedParameter`. If a
+value for the parameter was not encountered during parsing, it is set to this
+default value. If a parameter is equal to its default value, it is not printed.
+Equality is checked using the parameter's `comparator` field; if none is
+specified, the C++ equality operator is used.
+
+For example:
+
+```
+let parameters = (ins DefaultValuedParameter<"Optional<int>", "5">:$a);
+let mnemonic = "default_valued";
+let assemblyFormat = "(`<` $a^ `>`)?";
+```
+
+Which will look like:
+
+```
+!test.default_valued     // a = 5
+!test.default_valued<10> // a = 10
+```
+
+For optional `Attribute` or `Type` parameters, the current MLIR context is
+available through `$_ctx`. E.g.
+
+```
+DefaultValuedParameter<"IntegerType", "IntegerType::get($_ctx, 32)">
+```
+
 ### Assembly Format Directives
 
 Attribute and type assembly formats have the following directives:

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index bb108f714a789..5ad8bd45b339e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -3117,7 +3117,8 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
 class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   // Custom memory allocation code for storage constructor.
   code allocator = ?;
-  // Custom comparator used to compare two instances for equality.
+  // Comparator used to compare two instances for equality. By default, it uses
+  // the C++ equality operator.
   code comparator = ?;
   // The C++ type of this parameter.
   string cppType = type;
@@ -3143,6 +3144,14 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   // must be default constructible and be contextually convertible to `bool`.
   // Any `Optional<T>` and any attribute type satisfies these requirements.
   bit isOptional = 0;
+  // Provide a default value for the parameter. Parameters with default values
+  // are considered optional. If a value was not parsed for the parameter, it
+  // will be set to the default value. Parameters equal to their default values
+  // are elided when printing. Equality is checked using the `comparator` field,
+  // which by default is the C++ equality operator. The current MLIR context is
+  // made available through `$_ctx`, e.g., for constructing default values for
+  // attributes and types.
+  string defaultValue = ?;
 }
 class AttrParameter<string type, string desc, string accessorType = "">
  : AttrOrTypeParameter<type, desc, accessorType>;
@@ -3193,6 +3202,13 @@ class OptionalParameter<string type, string desc = ""> :
   let isOptional = 1;
 }
 
+// A parameter with a default value.
+class DefaultValuedParameter<string type, string value, string desc = ""> :
+    AttrOrTypeParameter<type, desc> {
+  let isOptional = 1;
+  let defaultValue = value;
+}
+
 // This is a special parameter used for AttrDefs that represents a `mlir::Type`
 // that is also used as the value `Type` of the attribute. Only one parameter
 // of the attribute may be of this type.

diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index dcfdc8ab28a63..557f60b2686eb 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -52,6 +52,9 @@ class AttrOrTypeParameter {
   explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
       : def(def), index(index) {}
 
+  /// Returns true if the parameter is anonymous (has no name).
+  bool isAnonymous() const;
+
   /// Get the parameter name.
   StringRef getName() const;
 
@@ -59,7 +62,7 @@ class AttrOrTypeParameter {
   Optional<StringRef> getAllocator() const;
 
   /// If specified, get the custom comparator code for this parameter.
-  Optional<StringRef> getComparator() const;
+  StringRef getComparator() const;
 
   /// Get the C++ type of this parameter.
   StringRef getCppType() const;
@@ -85,6 +88,9 @@ class AttrOrTypeParameter {
   /// Returns true if the parameter is optional.
   bool isOptional() const;
 
+  /// Get the default value of the parameter if it has one.
+  Optional<StringRef> getDefaultValue() const;
+
   /// Return the underlying def of this parameter.
   llvm::Init *getDef() const;
 

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 2d92a5837c79f..3c0cfa029c700 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -187,6 +187,10 @@ auto AttrOrTypeParameter::getDefValue(StringRef name) const {
   return result;
 }
 
+bool AttrOrTypeParameter::isAnonymous() const {
+  return !def->getArgName(index);
+}
+
 StringRef AttrOrTypeParameter::getName() const {
   return def->getArgName(index)->getValue();
 }
@@ -195,8 +199,9 @@ Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
   return getDefValue<llvm::StringInit>("allocator");
 }
 
-Optional<StringRef> AttrOrTypeParameter::getComparator() const {
-  return getDefValue<llvm::StringInit>("comparator");
+StringRef AttrOrTypeParameter::getComparator() const {
+  return getDefValue<llvm::StringInit>("comparator")
+      .getValueOr("$_lhs == $_rhs");
 }
 
 StringRef AttrOrTypeParameter::getCppType() const {
@@ -239,7 +244,13 @@ StringRef AttrOrTypeParameter::getSyntax() const {
 }
 
 bool AttrOrTypeParameter::isOptional() const {
-  return getDefValue<llvm::BitInit>("isOptional").getValueOr(false);
+  // Parameters with default values are automatically optional.
+  return getDefValue<llvm::BitInit>("isOptional").getValueOr(false) ||
+         getDefaultValue().hasValue();
+}
+
+Optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
+  return getDefValue<llvm::StringInit>("defaultValue");
 }
 
 llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index b8dce7ec50f93..de4969353ad01 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -311,7 +311,7 @@ def TestTypeOptionalGroup : Test_Type<"TestTypeOptionalGroup"> {
 }
 
 def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> {
-  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+  let parameters = (ins DefaultValuedParameter<"mlir::Optional<int>", "10">:$a,
                         OptionalParameter<"mlir::Optional<int>">:$b);
   let mnemonic = "optional_group_params";
   let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`";
@@ -330,4 +330,37 @@ def TestTypeSpaces : Test_Type<"TestTypeSpaceS"> {
   let assemblyFormat = "`<` ` ` $a `\\n` `(` `)` `` `(` `)` $b `>`";
 }
 
+class DefaultValuedAPFloat<string value>
+    : DefaultValuedParameter<"llvm::Optional<llvm::APFloat>",
+                             "llvm::Optional<llvm::APFloat>(" # value # ")"> {
+  let comparator = "$_lhs->bitwiseIsEqual(*$_rhs)";
+  let parser = [{ [&]() -> mlir::FailureOr<llvm::Optional<llvm::APFloat>> {
+    mlir::FloatAttr attr;
+    auto result = $_parser.parseOptionalAttribute(attr);
+    if (result.hasValue() && mlir::succeeded(*result))
+      return {attr.getValue()};
+    if (!result.hasValue())
+      return llvm::Optional<llvm::APFloat>();
+    return mlir::failure();
+  }() }];
+  let printer = "$_printer << *$_self";
+}
+
+def TestTypeAPFloat : Test_Type<"TestTypeAPFloat"> {
+  let parameters = (ins
+    DefaultValuedAPFloat<"APFloat::getZero(APFloat::IEEEdouble())">:$a
+  );
+  let mnemonic = "ap_float";
+  let assemblyFormat = "`<` $a `>`";
+}
+
+def TestTypeDefaultValuedType : Test_Type<"TestTypeDefaultValuedType"> {
+  let parameters = (ins
+    DefaultValuedParameter<"mlir::IntegerType",
+                           "mlir::IntegerType::get($_ctx, 32)">:$type
+  );
+  let mnemonic = "default_valued_type";
+  let assemblyFormat = "`<` (`(` $type^ `)`)? `>`";
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index 47f25cb20406b..b626ca44f36a1 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -44,6 +44,11 @@ attributes {
 // CHECK: !test.optional_group_struct<(a = 10, b = 5)>
 // CHECK: !test.spaces< 5
 // CHECK-NEXT: ()() 6>
+// CHECK: !test.ap_float<5.000000e+00>
+// CHECK: !test.ap_float<>
+// CHECK: !test.default_valued_type<(i64)>
+// CHECK: !test.default_valued_type<>
+
 func private @test_roundtrip_default_parsers_struct(
   !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
 ) -> (
@@ -70,5 +75,9 @@ func private @test_roundtrip_default_parsers_struct(
   !test.optional_group_struct<x>,
   !test.optional_group_struct<(b = 5)>,
   !test.optional_group_struct<(b = 5, a = 10)>,
-  !test.spaces<5 ()() 6>
+  !test.spaces<5 ()() 6>,
+  !test.ap_float<5.0>,
+  !test.ap_float<>,
+  !test.default_valued_type<(i64)>,
+  !test.default_valued_type<>
 )

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 4ed281d488db5..f0c1aac4af446 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -389,10 +389,13 @@ def TypeC : TestType<"TestE"> {
   let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
 }
 
+// TYPE: ::mlir::Type TestFType::parse(::mlir::AsmParser &odsParser) {
+// TYPE:   _result_a.getValueOr(int())
+
 // TYPE: void TestFType::print(::mlir::AsmPrinter &odsPrinter) const {
 // TYPE if (getA()) {
-// TYPE   printer << ' ';
-// TYPE   printer.printStrippedAttrOrType(getA());
+// TYPE   odsPrinter << ' ';
+// TYPE   odsPrinter.printStrippedAttrOrType(getA());
 def TypeD : TestType<"TestF"> {
   let parameters = (ins OptionalParameter<"int">:$a);
   let mnemonic = "type_f";
@@ -406,7 +409,7 @@ def TypeD : TestType<"TestF"> {
 // TYPE:     if (odsParser.parseComma())
 // TYPE:       return {};
 
-// TYPE: if (getA())
+// TYPE: if ((getA()))
 // TYPE:   odsPrinter.printStrippedAttrOrType(getA());
 // TYPE: odsPrinter << ", ";
 // TYPE: odsPrinter.printStrippedAttrOrType(getB());
@@ -426,7 +429,7 @@ def TypeE : TestType<"TestG"> {
 // TYPE:     return {};
 
 // TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const {
-// TYPE:   if (getA()) {
+// TYPE:   if ((getA())) {
 // TYPE:     odsPrinter << "a = ";
 // TYPE:     odsPrinter.printStrippedAttrOrType(getA());
 // TYPE:     odsPrinter << ", ";
@@ -469,9 +472,9 @@ def TypeG : TestType<"TestI"> {
 // TYPE:     return {};
 
 // TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const {
-// TYPE:   if (getB()) {
+// TYPE:   if ((getB())) {
 // TYPE:     odsPrinter << "(";
-// TYPE:     if (getB())
+// TYPE:     if ((getB()))
 // TYPE:       odsPrinter.printStrippedAttrOrType(getB());
 // TYPE:     odsPrinter << ")";
 // TYPE:   } else {
@@ -484,3 +487,15 @@ def TypeH : TestType<"TestJ"> {
   let mnemonic = "type_j";
   let assemblyFormat = "(`(` $b^ `)`) : (`x`)? $a";
 }
+
+// TYPE: ::mlir::Type TestKType::parse(::mlir::AsmParser &odsParser) {
+// TYPE:   _result_a.getValueOr(10)
+
+// TYPE: void TestKType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE:   if ((getA() && !(getA() == 10)))
+
+def TypeI : TestType<"TestK"> {
+  let parameters = (ins DefaultValuedParameter<"int", "10">:$a);
+  let mnemonic = "type_k";
+  let assemblyFormat = "$a";
+}

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 7e2b147478e17..c6db1c0418846 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -179,6 +179,11 @@ DefGen::DefGen(const AttrOrTypeDef &def)
     : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
       valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
       defType(isa<AttrDef>(def) ? "Attr" : "Type") {
+  // Check that all parameters have names.
+  for (const AttrOrTypeParameter &param : def.getParameters())
+    if (param.isAnonymous())
+      llvm::PrintFatalError("all parameters must have a name");
+
   // If a storage class is needed, create one.
   if (def.getNumParameters() > 0)
     storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
@@ -535,8 +540,7 @@ void DefGen::emitEquals() {
                                  ? "getType()"
                                  : it.value().getName()},
                     {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
-    Optional<StringRef> comparator = it.value().getComparator();
-    body << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &ctx);
+    body << tgfmt(it.value().getComparator(), &ctx);
   };
   llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
 }

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 77087b6b4f698..35ab19942ba13 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -56,6 +56,23 @@ class ParameterElement
   /// Returns the name of the parameter.
   StringRef getName() const { return param.getName(); }
 
+  /// Generate the code to check whether the parameter should be printed.
+  auto genPrintGuard(FmtContext &ctx) const {
+    return [&](raw_ostream &os) -> raw_ostream & {
+      std::string self = getParameterAccessorName(getName()) + "()";
+      ctx.withSelf(self);
+      os << tgfmt("($_self", &ctx);
+      if (llvm::Optional<StringRef> defaultValue =
+              getParam().getDefaultValue()) {
+        // Use the `comparator` field if it exists, else the equality operator.
+        std::string valueStr = tgfmt(*defaultValue, &ctx).str();
+        ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
+        os << " && !(" << tgfmt(getParam().getComparator(), &ctx) << ")";
+      }
+      return os << ")";
+    };
+  }
+
 private:
   bool shouldBeQualifiedFlag = false;
   AttrOrTypeParameter param;
@@ -65,6 +82,12 @@ class ParameterElement
 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
 
+/// raw_ostream doesn't have an overload for stream functors. Declare one here.
+template <typename StreamFunctor>
+static raw_ostream &operator<<(raw_ostream &os, StreamFunctor &&fcn) {
+  return fcn(os);
+}
+
 /// Base class for a directive that contains references to multiple variables.
 template <DirectiveElement::Kind DirectiveKind>
 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
@@ -231,6 +254,7 @@ class DefFormat {
 void DefFormat::genParser(MethodBody &os) {
   FmtContext ctx;
   ctx.addSubst("_parser", "odsParser");
+  ctx.addSubst("_ctx", "odsParser.getContext()");
   if (isa<AttrDef>(def))
     ctx.addSubst("_type", "odsType");
   os.indent();
@@ -274,11 +298,16 @@ void DefFormat::genParser(MethodBody &os) {
                 def.getCppClassName());
   }
   for (const AttrOrTypeParameter &param : params) {
-    if (param.isOptional())
-      os << formatv(",\n    _result_{0}.getValueOr({1}())", param.getName(),
-                    param.getCppStorageType());
-    else
+    if (param.isOptional()) {
+      os << formatv(",\n    _result_{0}.getValueOr(", param.getName());
+      if (Optional<StringRef> defaultValue = param.getDefaultValue())
+        os << tgfmt(*defaultValue, &ctx);
+      else
+        os << param.getCppStorageType() << "()";
+      os << ")";
+    } else {
       os << formatv(",\n    *_result_{0}", param.getName());
+    }
   }
   os << ");";
 }
@@ -596,6 +625,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
 void DefFormat::genPrinter(MethodBody &os) {
   FmtContext ctx;
   ctx.addSubst("_printer", "odsPrinter");
+  ctx.addSubst("_ctx", "getContext()");
   os.indent();
 
   /// Generate printers.
@@ -642,9 +672,10 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
   const AttrOrTypeParameter &param = el->getParam();
   ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
 
-  // Guard the printer on the presence of optional parameters.
+  // Guard the printer on the presence of optional parameters and that they
+  // aren't equal to their default values (if they have one).
   if (el->isOptional() && !skipGuard) {
-    os << tgfmt("if ($_self) {\n", &ctx);
+    os << "if (" << el->genPrintGuard(ctx) << ") {\n";
     os.indent();
   }
 
@@ -665,23 +696,27 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
     os.unindent() << "}\n";
 }
 
+/// Generate code to guard printing on the presence of any optional parameters.
+template <typename ParameterRange>
+static void guardOnAny(FmtContext &ctx, MethodBody &os,
+                       ParameterRange &&params) {
+  os << "if (";
+  llvm::interleave(
+      params, os,
+      [&](ParameterElement *param) { os << param->genPrintGuard(ctx); },
+      " || ");
+  os << ") {\n";
+  os.indent();
+}
+
 void DefFormat::genCommaSeparatedPrinter(
     ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
     function_ref<void(ParameterElement *)> extra) {
   // Emit a space if necessary, but only if the struct is present.
   if (shouldEmitSpace || !lastWasPunctuation) {
     bool allOptional = llvm::all_of(params, paramIsOptional);
-    if (allOptional) {
-      os << "if (";
-      llvm::interleave(
-          params, os,
-          [&](ParameterElement *param) {
-            os << getParameterAccessorName(param->getName()) << "()";
-          },
-          " || ");
-      os << ") {\n";
-      os.indent();
-    }
+    if (allOptional)
+      guardOnAny(ctx, os, params);
     os << tgfmt("$_printer << ' ';\n", &ctx);
     if (allOptional)
       os.unindent() << "}\n";
@@ -692,8 +727,7 @@ void DefFormat::genCommaSeparatedPrinter(
   os.indent() << "bool _firstPrinted = true;\n";
   for (ParameterElement *param : params) {
     if (param->isOptional()) {
-      os << tgfmt("if ($_self()) {\n",
-                  &ctx.withSelf(getParameterAccessorName(param->getName())));
+      os << "if (" << param->genPrintGuard(ctx) << ") {\n";
       os.indent();
     }
     os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
@@ -724,26 +758,14 @@ void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
 
 void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                         MethodBody &os) {
-  // Emit the check on whether the group should be printed.
-  const auto guardOn = [&](auto params) {
-    os << "if (";
-    llvm::interleave(
-        params, os,
-        [&](ParameterElement *el) {
-          os << getParameterAccessorName(el->getName()) << "()";
-        },
-        " || ");
-    os << ") {\n";
-    os.indent();
-  };
   FormatElement *anchor = el->getAnchor();
   if (auto *param = dyn_cast<ParameterElement>(anchor)) {
-    guardOn(llvm::makeArrayRef(param));
+    guardOnAny(ctx, os, llvm::makeArrayRef(param));
   } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
-    guardOn(params->getParams());
+    guardOnAny(ctx, os, params->getParams());
   } else {
-    auto *strct = dyn_cast<StructDirective>(anchor);
-    guardOn(strct->getParams());
+    auto *strct = cast<StructDirective>(anchor);
+    guardOnAny(ctx, os, strct->getParams());
   }
   // Generate the printer for the contained elements.
   {


        


More information about the Mlir-commits mailing list