[Mlir-commits] [mlir] [mlir][IR] Improve attribute printing with serial comma (PR #87217)

Jakub Kuderski llvmlistbot at llvm.org
Sun Mar 31 21:25:20 PDT 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/87217

Make attributes more readable by using serial comma (a.k.a. Oxford comma). This applies to sequences of length 3 and longer and reduces ambiguity.

The behavior is opt-in and enabled by a new flag `--mlir-print-serial-comma` with a convenience alias `--mlir-print-oxford-comma`.  Parsing support will be added in a follow-up PR.

For more context, see the RFC thread: TODO

>From 8b6e1ef6a3d07ea992bc2405aad9df315e39b43b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 22:24:38 -0400
Subject: [PATCH 1/2] [mlir][IR] Improve attribute printing with serial comma

Make attributes more readable by using serial comma (a.k.a. Oxford comma).
This applies to sequences of length 3 and longer.

For now, this behavior is opt-in with the
`--mlir-print-serial-comma` flag, but we plan to enable it by
default in the coming release (LLVM 19).
---
 mlir/include/mlir/IR/OperationSupport.h  |  11 ++
 mlir/lib/IR/AsmPrinter.cpp               | 127 ++++++++++++++++-------
 mlir/test/IR/print-use-serial-comma.mlir |  19 ++++
 3 files changed, 119 insertions(+), 38 deletions(-)
 create mode 100644 mlir/test/IR/print-use-serial-comma.mlir

diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 90e63ff8fcb38f..486403a9d6d74c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1166,6 +1166,10 @@ class OpPrintingFlags {
   /// Print users of values as comments.
   OpPrintingFlags &printValueUsers();
 
+  /// Use serial comma (a.k.a. Oxford comma) when printing sequences of 3 or
+  /// more elements.
+  OpPrintingFlags &useSerialComma();
+
   /// Return if the given ElementsAttr should be elided.
   bool shouldElideElementsAttr(ElementsAttr attr) const;
 
@@ -1196,6 +1200,10 @@ class OpPrintingFlags {
   /// Return if the printer should print users of values.
   bool shouldPrintValueUsers() const;
 
+  /// Return if the printer should use serial comma when printing sequences of 3
+  /// or more elements.
+  bool shouldUseSerialComma() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -1222,6 +1230,9 @@ class OpPrintingFlags {
 
   /// Print users of values.
   bool printValueUsersFlag : 1;
+
+  /// Print sequences of 3 or more elements using serial comma.
+  bool printSerialComma : 1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..d260fb078fed71 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -43,10 +43,12 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Endian.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/Threading.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cstdlib>
 #include <type_traits>
 
 #include <optional>
@@ -136,6 +138,15 @@ OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
 // OpPrintingFlags
 //===----------------------------------------------------------------------===//
 
+static bool environmentRequiresSerialComma() {
+  StringRef currentLanguage = std::getenv("LANG");
+  if (currentLanguage.starts_with("en_US") ||
+      currentLanguage.starts_with("en_CA"))
+    return true;
+  llvm_unreachable(
+      "Unhandle corner case. This is very unlikely to happen in practice.");
+}
+
 namespace {
 /// This struct contains command line options that can be used to initialize
 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
@@ -189,6 +200,12 @@ struct AsmPrinterOptions {
       "mlir-print-value-users", llvm::cl::init(false),
       llvm::cl::desc(
           "Print users of operation results and block arguments as a comment")};
+
+  llvm::cl::opt<bool> printSerialComma{
+      "mlir-print-serial-comma",
+      llvm::cl::init(false && environmentRequiresSerialComma()),
+      llvm::cl::desc(
+          "Use serial comma when printing sequences of length 3 or more.")};
 };
 } // namespace
 
@@ -206,7 +223,7 @@ OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
       printGenericOpFormFlag(false), skipRegionsFlag(false),
       assumeVerifiedFlag(false), printLocalScope(false),
-      printValueUsersFlag(false) {
+      printValueUsersFlag(false), printSerialComma(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -221,6 +238,7 @@ OpPrintingFlags::OpPrintingFlags()
   printLocalScope = clOptions->printLocalScopeOpt;
   skipRegionsFlag = clOptions->skipRegionsOpt;
   printValueUsersFlag = clOptions->printValueUsers;
+  printSerialComma = clOptions->printSerialComma;
 }
 
 /// Enable the elision of large elements attributes, by printing a '...'
@@ -280,6 +298,12 @@ OpPrintingFlags &OpPrintingFlags::printValueUsers() {
   return *this;
 }
 
+/// Use serial comma when printing sequences with 3 or more elements.
+OpPrintingFlags &OpPrintingFlags::useSerialComma() {
+  printSerialComma = true;
+  return *this;
+}
+
 /// Return if the given ElementsAttr should be elided.
 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
   return elementsAttrElementLimit &&
@@ -328,6 +352,9 @@ bool OpPrintingFlags::shouldPrintValueUsers() const {
   return printValueUsersFlag;
 }
 
+/// Return if the printer should use serial comma.
+bool OpPrintingFlags::shouldUseSerialComma() const { return printSerialComma; }
+
 /// Returns true if an ElementsAttr with the given number of elements should be
 /// printed with hex.
 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
@@ -377,8 +404,15 @@ class AsmPrinter::Impl {
   raw_ostream &getStream() { return os; }
 
   template <typename Container, typename UnaryFunctor>
-  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
-    llvm::interleaveComma(c, os, eachFn);
+  inline void interleaveComma(const Container &c, UnaryFunctor eachFn,
+                              bool useSerialComma = false) const {
+    size_t numElements = llvm::range_size(c);
+    if (!useSerialComma || numElements < 3)
+      return llvm::interleaveComma(c, os, eachFn);
+
+    llvm::interleaveComma(llvm::drop_end(c), os, eachFn);
+    os << ", and ";
+    llvm::interleaveComma(llvm::drop_begin(c, numElements - 1), os, eachFn);
   }
 
   /// This enum describes the different kinds of elision for the type of an
@@ -2228,8 +2262,10 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
     return;
   } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
     os << '{';
-    interleaveComma(dictAttr.getValue(),
-                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
+    interleaveComma(
+        dictAttr.getValue(),
+        [&](NamedAttribute attr) { printNamedAttribute(attr); },
+        printerFlags.shouldUseSerialComma());
     os << '}';
 
   } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
@@ -2264,9 +2300,10 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
 
   } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
     os << '[';
-    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
-      printAttribute(attr, AttrTypeElision::May);
-    });
+    interleaveComma(
+        arrayAttr.getValue(),
+        [&](Attribute attr) { printAttribute(attr, AttrTypeElision::May); },
+        printerFlags.shouldUseSerialComma());
     os << ']';
 
   } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
@@ -2370,9 +2407,10 @@ static void printDenseIntElement(const APInt &value, raw_ostream &os,
     value.print(os, !type.isUnsignedInteger());
 }
 
-static void
-printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
-                           function_ref<void(unsigned)> printEltFn) {
+static void printDenseElementsAttrImpl(bool isSplat, ShapedType type,
+                                       raw_ostream &os,
+                                       function_ref<void(unsigned)> printEltFn,
+                                       bool useSerialComma) {
   // Special case for 0-d and splat tensors.
   if (isSplat)
     return printEltFn(0);
@@ -2407,9 +2445,11 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
       }
   };
 
+  useSerialComma &= numElements >= 3;
   for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
-    if (idx != 0)
-      os << ", ";
+    if (idx != 0) {
+      os << (useSerialComma && (idx + 1 == e) ? ", and " : ", ");
+    }
     while (openBrackets++ < rank)
       os << '[';
     openBrackets = rank;
@@ -2461,36 +2501,46 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
     // and hence was replaced.
     if (llvm::isa<IntegerType>(complexElementType)) {
       auto valueIt = attr.value_begin<std::complex<APInt>>();
-      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-        auto complexValue = *(valueIt + index);
-        os << "(";
-        printDenseIntElement(complexValue.real(), os, complexElementType);
-        os << ",";
-        printDenseIntElement(complexValue.imag(), os, complexElementType);
-        os << ")";
-      });
+      printDenseElementsAttrImpl(
+          attr.isSplat(), type, os,
+          [&](unsigned index) {
+            auto complexValue = *(valueIt + index);
+            os << "(";
+            printDenseIntElement(complexValue.real(), os, complexElementType);
+            os << ",";
+            printDenseIntElement(complexValue.imag(), os, complexElementType);
+            os << ")";
+          },
+          printerFlags.shouldUseSerialComma());
     } else {
       auto valueIt = attr.value_begin<std::complex<APFloat>>();
-      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-        auto complexValue = *(valueIt + index);
-        os << "(";
-        printFloatValue(complexValue.real(), os);
-        os << ",";
-        printFloatValue(complexValue.imag(), os);
-        os << ")";
-      });
+      printDenseElementsAttrImpl(
+          attr.isSplat(), type, os,
+          [&](unsigned index) {
+            auto complexValue = *(valueIt + index);
+            os << "(";
+            printFloatValue(complexValue.real(), os);
+            os << ",";
+            printFloatValue(complexValue.imag(), os);
+            os << ")";
+          },
+          printerFlags.shouldUseSerialComma());
     }
   } else if (elementType.isIntOrIndex()) {
     auto valueIt = attr.value_begin<APInt>();
-    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-      printDenseIntElement(*(valueIt + index), os, elementType);
-    });
+    printDenseElementsAttrImpl(
+        attr.isSplat(), type, os,
+        [&](unsigned index) {
+          printDenseIntElement(*(valueIt + index), os, elementType);
+        },
+        printerFlags.shouldUseSerialComma());
   } else {
     assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
     auto valueIt = attr.value_begin<APFloat>();
-    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-      printFloatValue(*(valueIt + index), os);
-    });
+    printDenseElementsAttrImpl(
+        attr.isSplat(), type, os,
+        [&](unsigned index) { printFloatValue(*(valueIt + index), os); },
+        printerFlags.shouldUseSerialComma());
   }
 }
 
@@ -2498,7 +2548,8 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
     DenseStringElementsAttr attr) {
   ArrayRef<StringRef> data = attr.getRawStringData();
   auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
-  printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
+  printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn,
+                             printerFlags.shouldUseSerialComma());
 }
 
 void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
@@ -2522,8 +2573,8 @@ void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
       printFloatValue(fltVal, getStream());
     }
   };
-  llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
-                        printElementAt);
+  interleaveComma(llvm::seq<unsigned>(0, attr.size()), printElementAt,
+                  printerFlags.shouldUseSerialComma());
 }
 
 void AsmPrinter::Impl::printType(Type type) {
diff --git a/mlir/test/IR/print-use-serial-comma.mlir b/mlir/test/IR/print-use-serial-comma.mlir
new file mode 100644
index 00000000000000..5601e0bfdf5c1b
--- /dev/null
+++ b/mlir/test/IR/print-use-serial-comma.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s --mlir-print-serial-comma | FileCheck %s
+
+// CHECK: foo.dense_attr = dense<[1, 2, and 3]> : tensor<3xi32>
+"test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
+
+// Nested attributes not supported, we should use 1-d vectors from the LLVM dialect anyway.
+// CHECK{LITERAL}: foo.dense_attr = dense<[[1, 2, 3], [4, 5, 6], [7, 8, and 9]]> : tensor<3x3xi32>
+"test.nested_dense_attr"() {foo.dense_attr = dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>} : () -> ()
+
+// Two elements, serial comma not necessary.
+// CHECK: dense<[1, 2]> : tensor<2xi32>
+"test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> ()
+
+// CHECK{LITERAL}: sparse<[[0, 0, and 5]], -2.000000e+00> : vector<1x1x10xf16>
+"test.sparse_attr"() {foo.sparse_attr = sparse<[[0, 0, 5]],  -2.0> : vector<1x1x10xf16>} : () -> ()
+
+// One unique element, do not use serial comma.
+// CHECK: dense<1> : tensor<3xi32>
+"test.dense_splat"() {foo.dense_attr = dense<1> : tensor<3xi32>} : () -> ()

>From 5983efb9ce5c8ecc5f942ed34702bdda4e40d6fc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 22:35:00 -0400
Subject: [PATCH 2/2] Add alias flag for oxford comma

---
 mlir/lib/IR/AsmPrinter.cpp               | 7 ++++++-
 mlir/test/IR/print-use-serial-comma.mlir | 3 +++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d260fb078fed71..045102d156ddb2 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -144,7 +144,7 @@ static bool environmentRequiresSerialComma() {
       currentLanguage.starts_with("en_CA"))
     return true;
   llvm_unreachable(
-      "Unhandle corner case. This is very unlikely to happen in practice.");
+      "Unhandled corner case. This is very unlikely to happen in practice.");
 }
 
 namespace {
@@ -216,6 +216,11 @@ static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
 void mlir::registerAsmPrinterCLOptions() {
   // Make sure that the options struct has been initialized.
   *clOptions;
+
+  static llvm::cl::alias printOxfordComma(
+      "mlir-print-oxford-comma",
+      llvm::cl::desc("Alias for --mlir-print-serial-comma"),
+      llvm::cl::aliasopt(clOptions->printSerialComma));
 }
 
 /// Initialize the printing flags with default supplied by the cl::opts above.
diff --git a/mlir/test/IR/print-use-serial-comma.mlir b/mlir/test/IR/print-use-serial-comma.mlir
index 5601e0bfdf5c1b..908e478705be4b 100644
--- a/mlir/test/IR/print-use-serial-comma.mlir
+++ b/mlir/test/IR/print-use-serial-comma.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt %s --mlir-print-serial-comma | FileCheck %s
 
+// Check that the alias flag is also recognized.
+// RUN: mlir-opt %s --mlir-print-oxford-comma | FileCheck %s
+
 // CHECK: foo.dense_attr = dense<[1, 2, and 3]> : tensor<3xi32>
 "test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
 



More information about the Mlir-commits mailing list