[Mlir-commits] [mlir] [mlir][ods] Do not print default-valued properties when the value is equal to the default (PR #87970)

Beal Wang llvmlistbot at llvm.org
Mon Apr 8 01:01:47 PDT 2024


https://github.com/bealwang created https://github.com/llvm/llvm-project/pull/87970

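This patch makes the `tblgen`-generated `printProperties()` function skip printing a `DefaultValuedAttr` property whose value equals its default. When every property is elided, the surrounding `<...>` is omitted as well, so `genericParseProperties()` now accepts a missing leading `<`.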

From 4ef902b1a83afc4201d45dc18d8753dcf2ea71c2 Mon Sep 17 00:00:00 2001
From: Biao Wang <biaow at nvidia.com>
Date: Mon, 8 Apr 2024 15:50:58 +0800
Subject: [PATCH] [mlir][ods] Do not print default-valued properties when the
 value is equal to the default

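This patch makes the `tblgen`-generated printProperties() function skip
printing a `DefaultValuedAttr` property whose value equals its default.
When every property is elided, the surrounding `<...>` is omitted too,
so genericParseProperties() now treats the leading `<` as optional.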
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  3 +-
 mlir/include/mlir/IR/OpDefinition.h           | 31 ++++++++++++-------
 mlir/lib/IR/Operation.cpp                     | 30 ++++++++++++++----
 mlir/test/IR/properties.mlir                  |  5 +++
 mlir/test/lib/Dialect/Test/TestOps.td         | 17 +++++++---
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        | 23 +++++++++++++-
 6 files changed, 85 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 93c56ad05b432c..b8ebd1a40c6073 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -27,7 +27,8 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 
   code extraBaseClassDeclaration = [{
     void printProperties(::mlir::MLIRContext *ctx,
-            ::mlir::OpAsmPrinter &p, const Properties &prop) {
+            ::mlir::OpAsmPrinter &p, const Properties &prop,
+            ::mlir::ArrayRef<::llvm::StringRef> elidedProps) { // NOTE(review): elidedProps is not consulted below; the whole property attribute is printed. Confirm intended.
       Attribute propAttr = getPropertiesAsAttr(ctx, prop);
       if (propAttr)
         p << "<" << propAttr << ">";
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c177ae3594d11f..24a54c18da701e 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -226,8 +226,10 @@ class OpState {
   static ParseResult genericParseProperties(OpAsmParser &parser,
                                             Attribute &result);
 
-  /// Print the properties as a Attribute.
-  static void genericPrintProperties(OpAsmPrinter &p, Attribute properties);
+  /// Print the properties as an Attribute, omitting entries whose names are
+  /// listed in 'elidedProps'.
+  static void genericPrintProperties(OpAsmPrinter &p, Attribute properties,
+                                     ArrayRef<StringRef> elidedProps = {});
 
   /// Print an operation name, eliding the dialect prefix if necessary.
   static void printOpName(Operation *op, OpAsmPrinter &p,
@@ -1805,10 +1807,13 @@ class Op : public OpState, public Traits<ConcreteType>... {
   template <typename T>
   using detect_has_print = llvm::is_detected<has_print, T>;
 
-  /// Trait to check if printProperties(OpAsmPrinter, T) exist
+  /// Trait to check if printProperties(OpAsmPrinter, T, ArrayRef<StringRef>)
+  /// exists. A two-argument printProperties(OpAsmPrinter, T) is not detected.
   template <typename T, typename... Args>
-  using has_print_properties = decltype(printProperties(
-      std::declval<OpAsmPrinter &>(), std::declval<T>()));
+  using has_print_properties =
+      decltype(printProperties(std::declval<OpAsmPrinter &>(),
+                               std::declval<T>(),
+                               std::declval<ArrayRef<StringRef>>()));
   template <typename T>
   using detect_has_print_properties =
       llvm::is_detected<has_print_properties, T>;
@@ -1974,16 +1979,18 @@ class Op : public OpState, public Traits<ConcreteType>... {
   static void populateDefaultProperties(OperationName opName,
                                         InferredProperties<T> &properties) {}
 
-  /// Print the operation properties. Unless overridden, this method will try to
-  /// dispatch to a `printProperties` free-function if it exists, and otherwise
-  /// by converting the properties to an Attribute.
+  /// Print the operation properties, omitting those named in 'elidedProps'.
+  /// Unless overridden, this method dispatches (with 'elidedProps') to a
+  /// `printProperties` free-function if it exists, and otherwise converts the
+  /// properties to an Attribute and prints it via genericPrintProperties.
   template <typename T>
   static void printProperties(MLIRContext *ctx, OpAsmPrinter &p,
-                              const T &properties) {
+                              const T &properties,
+                              ArrayRef<StringRef> elidedProps = {}) {
     if constexpr (detect_has_print_properties<T>::value)
-      return printProperties(p, properties);
-    genericPrintProperties(p,
-                           ConcreteType::getPropertiesAsAttr(ctx, properties));
+      return printProperties(p, properties, elidedProps);
+    genericPrintProperties(
+        p, ConcreteType::getPropertiesAsAttr(ctx, properties), elidedProps);
   }
 
   /// Parser the properties. Unless overridden, this method will print by
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index ca5ff9f72e3e29..db903d540761b7 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -790,15 +790,33 @@ void OpState::printOpName(Operation *op, OpAsmPrinter &p,
 /// Parse properties as a Attribute.
 ParseResult OpState::genericParseProperties(OpAsmParser &parser,
                                             Attribute &result) {
-  if (parser.parseLess() || parser.parseAttribute(result) ||
-      parser.parseGreater())
-    return failure();
+  if (succeeded(parser.parseOptionalLess())) { // No '<' if all props elided.
+    if (parser.parseAttribute(result) || parser.parseGreater())
+      return failure();
+  }
   return success();
 }
 
-/// Print the properties as a Attribute.
-void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) {
-  p << "<" << properties << ">";
+/// Print the properties as an Attribute, skipping dictionary entries named in
+/// 'elidedProps'; nothing is printed when every entry is skipped.
+void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties,
+                                     ArrayRef<StringRef> elidedProps) {
+  auto dictAttr = dyn_cast_or_null<DictionaryAttr>(properties);
+  if (dictAttr && !elidedProps.empty()) {
+    ArrayRef<NamedAttribute> attrs = dictAttr.getValue();
+    llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedProps.begin(),
+                                                  elidedProps.end());
+    bool atLeastOneAttr = llvm::any_of(attrs, [&](NamedAttribute attr) {
+      return !elidedAttrsSet.contains(attr.getName().strref());
+    });
+    if (atLeastOneAttr) {
+      p << "<";
+      p.printOptionalAttrDict(attrs, elidedProps);
+      p << ">";
+    }
+  } else {
+    p << "<" << properties << ">";
+  }
 }
 
 /// Emit an error about fatal conditions with this operation, reporting up to
diff --git a/mlir/test/IR/properties.mlir b/mlir/test/IR/properties.mlir
index 3c4bd57859ef91..c9d3956f6de3f4 100644
--- a/mlir/test/IR/properties.mlir
+++ b/mlir/test/IR/properties.mlir
@@ -33,3 +33,13 @@ test.using_property_in_custom [1, 4, 20]
 // GENERIC-SAME: second = 4
 // GENERIC-SAME: }>
 test.using_property_ref_in_custom 1 + 4 = 5
+
+// CHECK:   test.with_default_valued_properties {{$}}
+// GENERIC: "test.with_default_valued_properties"()
+// GENERIC-SAME:  <{a = 0 : i32}> : () -> ()
+test.with_default_valued_properties <{a = 0 : i32}>
+
+// CHECK:   test.with_default_valued_properties <{a = 1 : i32}>
+// GENERIC: "test.with_default_valued_properties"()
+// GENERIC-SAME:  <{a = 1 : i32}> : () -> ()
+test.with_default_valued_properties <{a = 1 : i32}>
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e6c3601d08dad0..edca05fde5a524 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2909,7 +2909,8 @@ def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> {
   );
   let extraClassDeclaration = [{
     void printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p,
-                         const Properties &prop);
+                         const Properties &prop,
+                         ::mlir::ArrayRef<::llvm::StringRef> elidedProps);
     static ::mlir::ParseResult  parseProperties(::mlir::OpAsmParser &parser,
                                      ::mlir::OperationState &result);
     static ::mlir::LogicalResult readFromMlirBytecode(
@@ -2938,7 +2939,8 @@ def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> {
       writer.writeVarInt(prop.value);
     }
     void TestOpWithNiceProperties::printProperties(::mlir::MLIRContext *ctx,
-            ::mlir::OpAsmPrinter &p, const Properties &prop) {
+            ::mlir::OpAsmPrinter &p, const Properties &prop,
+            ::mlir::ArrayRef<::llvm::StringRef> elidedProps) { // elidedProps is unused: prop.prop is always printed.
       customPrintProperties(p, prop.prop);
     }
     ::mlir::ParseResult TestOpWithNiceProperties::parseProperties(
@@ -2971,7 +2973,8 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> {
   );
   let extraClassDeclaration = [{
     void printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p,
-                         const Properties &prop);
+                         const Properties &prop,
+                         ::mlir::ArrayRef<::llvm::StringRef> elidedProps);
     static ::mlir::ParseResult  parseProperties(::mlir::OpAsmParser &parser,
                                      ::mlir::OperationState &result);
     static ::mlir::LogicalResult readFromMlirBytecode(
@@ -2983,7 +2986,8 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> {
   }];
   let extraClassDefinition = [{
     void TestOpWithVersionedProperties::printProperties(::mlir::MLIRContext *ctx,
-            ::mlir::OpAsmPrinter &p, const Properties &prop) {
+            ::mlir::OpAsmPrinter &p, const Properties &prop,
+            ::mlir::ArrayRef<::llvm::StringRef> elidedProps) { // elidedProps is unused: prop.prop is always printed.
       customPrintProperties(p, prop.prop);
     }
     ::mlir::ParseResult TestOpWithVersionedProperties::parseProperties(
@@ -2997,6 +3001,11 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> {
   }];
 }
 
+def TestOpWithDefaultValuedProperties : TEST_Op<"with_default_valued_properties"> {
+  let assemblyFormat = "prop-dict attr-dict";
+  let arguments = (ins DefaultValuedAttr<I32Attr, "0">:$a); // 'a' is elided from the custom (prop-dict) output when it equals 0.
+}
+
 //===----------------------------------------------------------------------===//
 // Test Dataflow
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index c8e0476d45b9a3..5963b5e689da74 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1775,9 +1775,30 @@ const char *enumAttrBeginPrinterCode = R"(
 /// Generate the printer for the 'prop-dict' directive.
 static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
                                MethodBody &body) {
+  body << "  ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n";
+  // Emit a check per non-derived attribute with a default value that elides
+  // it when it equals the (parenthesized) default-value expression.
+  for (const NamedAttribute &namedAttr : op.getAttributes()) {
+    const Attribute &attr = namedAttr.attr;
+    if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
+      StringRef name = namedAttr.name;
+      FmtContext fctx;
+      fctx.withBuilder("odsBuilder");
+      std::string defaultValue = std::string(
+          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+      body << "  {\n";
+      body << "    ::mlir::Builder odsBuilder(getContext());\n";
+      body << "    ::mlir::Attribute attr = " << op.getGetterName(name)
+           << "Attr();\n";
+      body << "    if (attr && attr == (" << defaultValue << "))\n";
+      body << "      elidedProps.push_back(\"" << name << "\");\n";
+      body << "  }\n";
+    }
+  }
+
   body << "  _odsPrinter << \" \";\n"
        << "  printProperties(this->getContext(), _odsPrinter, "
-          "getProperties());\n";
+          "getProperties(), elidedProps);\n";
 }
 
 /// Generate the printer for the 'attr-dict' directive.



More information about the Mlir-commits mailing list