[Mlir-commits] [mlir] [mlir][IR] Improve attribute printing with serial comma (PR #87217)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Mar 31 21:25:20 PDT 2024
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/87217
Make attributes more readable by using serial comma (a.k.a. Oxford comma). This applies to sequences of length 3 and longer and reduces ambiguity.
The behavior is opt-in and enabled by a new flag `--mlir-print-serial-comma` with a convenience alias `--mlir-print-oxford-comma`. Parsing support will be added in a follow-up PR.
For more context, see the RFC thread: TODO
>From 8b6e1ef6a3d07ea992bc2405aad9df315e39b43b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 22:24:38 -0400
Subject: [PATCH 1/2] [mlir][IR] Improve attribute printing with serial comma
Make attributes more readable by using serial comma (a.k.a. Oxford comma).
This applies to sequences of length 3 and longer.
For now, this behavior is opt-in with the
`--mlir-print-serial-comma` flag, but we plan to enable it by
default in the coming release (LLVM 19).
---
mlir/include/mlir/IR/OperationSupport.h | 11 ++
mlir/lib/IR/AsmPrinter.cpp | 127 ++++++++++++++++-------
mlir/test/IR/print-use-serial-comma.mlir | 19 ++++
3 files changed, 119 insertions(+), 38 deletions(-)
create mode 100644 mlir/test/IR/print-use-serial-comma.mlir
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 90e63ff8fcb38f..486403a9d6d74c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1166,6 +1166,10 @@ class OpPrintingFlags {
/// Print users of values as comments.
OpPrintingFlags &printValueUsers();
+ /// Use serial comma (a.k.a. Oxford comma) when printing sequences of 3 or
+ /// more elements.
+ OpPrintingFlags &useSerialComma();
+
/// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const;
@@ -1196,6 +1200,10 @@ class OpPrintingFlags {
/// Return if the printer should print users of values.
bool shouldPrintValueUsers() const;
+ /// Return if the printer should use serial comma when printing sequences of 3
+ /// or more elements.
+ bool shouldUseSerialComma() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1222,6 +1230,9 @@ class OpPrintingFlags {
/// Print users of values.
bool printValueUsersFlag : 1;
+
+ /// Print sequences of 3 or more elements using serial comma.
+ bool printSerialComma : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..d260fb078fed71 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -43,10 +43,12 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
+#include <cstdlib>
#include <type_traits>
#include <optional>
@@ -136,6 +138,15 @@ OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
// OpPrintingFlags
//===----------------------------------------------------------------------===//
+static bool environmentRequiresSerialComma() {
+ StringRef currentLanguage = std::getenv("LANG");
+ if (currentLanguage.starts_with("en_US") ||
+ currentLanguage.starts_with("en_CA"))
+ return true;
+ llvm_unreachable(
+ "Unhandle corner case. This is very unlikely to happen in practice.");
+}
+
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
@@ -189,6 +200,12 @@ struct AsmPrinterOptions {
"mlir-print-value-users", llvm::cl::init(false),
llvm::cl::desc(
"Print users of operation results and block arguments as a comment")};
+
+ llvm::cl::opt<bool> printSerialComma{
+ "mlir-print-serial-comma",
+ llvm::cl::init(false && environmentRequiresSerialComma()),
+ llvm::cl::desc(
+ "Use serial comma when printing sequences of length 3 or more.")};
};
} // namespace
@@ -206,7 +223,7 @@ OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
- printValueUsersFlag(false) {
+ printValueUsersFlag(false), printSerialComma(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -221,6 +238,7 @@ OpPrintingFlags::OpPrintingFlags()
printLocalScope = clOptions->printLocalScopeOpt;
skipRegionsFlag = clOptions->skipRegionsOpt;
printValueUsersFlag = clOptions->printValueUsers;
+ printSerialComma = clOptions->printSerialComma;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -280,6 +298,12 @@ OpPrintingFlags &OpPrintingFlags::printValueUsers() {
return *this;
}
+/// Use serial comma when printing sequences with 3 or more elements.
+OpPrintingFlags &OpPrintingFlags::useSerialComma() {
+ printSerialComma = true;
+ return *this;
+}
+
/// Return if the given ElementsAttr should be elided.
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
return elementsAttrElementLimit &&
@@ -328,6 +352,9 @@ bool OpPrintingFlags::shouldPrintValueUsers() const {
return printValueUsersFlag;
}
+/// Return if the printer should use serial comma.
+bool OpPrintingFlags::shouldUseSerialComma() const { return printSerialComma; }
+
/// Returns true if an ElementsAttr with the given number of elements should be
/// printed with hex.
static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
@@ -377,8 +404,15 @@ class AsmPrinter::Impl {
raw_ostream &getStream() { return os; }
template <typename Container, typename UnaryFunctor>
- inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
- llvm::interleaveComma(c, os, eachFn);
+ inline void interleaveComma(const Container &c, UnaryFunctor eachFn,
+ bool useSerialComma = false) const {
+ size_t numElements = llvm::range_size(c);
+ if (!useSerialComma || numElements < 3)
+ return llvm::interleaveComma(c, os, eachFn);
+
+ llvm::interleaveComma(llvm::drop_end(c), os, eachFn);
+ os << ", and ";
+ llvm::interleaveComma(llvm::drop_begin(c, numElements - 1), os, eachFn);
}
/// This enum describes the different kinds of elision for the type of an
@@ -2228,8 +2262,10 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
return;
} else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
os << '{';
- interleaveComma(dictAttr.getValue(),
- [&](NamedAttribute attr) { printNamedAttribute(attr); });
+ interleaveComma(
+ dictAttr.getValue(),
+ [&](NamedAttribute attr) { printNamedAttribute(attr); },
+ printerFlags.shouldUseSerialComma());
os << '}';
} else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
@@ -2264,9 +2300,10 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
} else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
os << '[';
- interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
- printAttribute(attr, AttrTypeElision::May);
- });
+ interleaveComma(
+ arrayAttr.getValue(),
+ [&](Attribute attr) { printAttribute(attr, AttrTypeElision::May); },
+ printerFlags.shouldUseSerialComma());
os << ']';
} else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
@@ -2370,9 +2407,10 @@ static void printDenseIntElement(const APInt &value, raw_ostream &os,
value.print(os, !type.isUnsignedInteger());
}
-static void
-printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
- function_ref<void(unsigned)> printEltFn) {
+static void printDenseElementsAttrImpl(bool isSplat, ShapedType type,
+ raw_ostream &os,
+ function_ref<void(unsigned)> printEltFn,
+ bool useSerialComma) {
// Special case for 0-d and splat tensors.
if (isSplat)
return printEltFn(0);
@@ -2407,9 +2445,11 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
}
};
+ useSerialComma &= numElements >= 3;
for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
- if (idx != 0)
- os << ", ";
+ if (idx != 0) {
+ os << (useSerialComma && (idx + 1 == e) ? ", and " : ", ");
+ }
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
@@ -2461,36 +2501,46 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
// and hence was replaced.
if (llvm::isa<IntegerType>(complexElementType)) {
auto valueIt = attr.value_begin<std::complex<APInt>>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- auto complexValue = *(valueIt + index);
- os << "(";
- printDenseIntElement(complexValue.real(), os, complexElementType);
- os << ",";
- printDenseIntElement(complexValue.imag(), os, complexElementType);
- os << ")";
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) {
+ auto complexValue = *(valueIt + index);
+ os << "(";
+ printDenseIntElement(complexValue.real(), os, complexElementType);
+ os << ",";
+ printDenseIntElement(complexValue.imag(), os, complexElementType);
+ os << ")";
+ },
+ printerFlags.shouldUseSerialComma());
} else {
auto valueIt = attr.value_begin<std::complex<APFloat>>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- auto complexValue = *(valueIt + index);
- os << "(";
- printFloatValue(complexValue.real(), os);
- os << ",";
- printFloatValue(complexValue.imag(), os);
- os << ")";
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) {
+ auto complexValue = *(valueIt + index);
+ os << "(";
+ printFloatValue(complexValue.real(), os);
+ os << ",";
+ printFloatValue(complexValue.imag(), os);
+ os << ")";
+ },
+ printerFlags.shouldUseSerialComma());
}
} else if (elementType.isIntOrIndex()) {
auto valueIt = attr.value_begin<APInt>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printDenseIntElement(*(valueIt + index), os, elementType);
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) {
+ printDenseIntElement(*(valueIt + index), os, elementType);
+ },
+ printerFlags.shouldUseSerialComma());
} else {
assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
auto valueIt = attr.value_begin<APFloat>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printFloatValue(*(valueIt + index), os);
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) { printFloatValue(*(valueIt + index), os); },
+ printerFlags.shouldUseSerialComma());
}
}
@@ -2498,7 +2548,8 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
DenseStringElementsAttr attr) {
ArrayRef<StringRef> data = attr.getRawStringData();
auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
- printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
+ printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn,
+ printerFlags.shouldUseSerialComma());
}
void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
@@ -2522,8 +2573,8 @@ void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
printFloatValue(fltVal, getStream());
}
};
- llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
- printElementAt);
+ interleaveComma(llvm::seq<unsigned>(0, attr.size()), printElementAt,
+ printerFlags.shouldUseSerialComma());
}
void AsmPrinter::Impl::printType(Type type) {
diff --git a/mlir/test/IR/print-use-serial-comma.mlir b/mlir/test/IR/print-use-serial-comma.mlir
new file mode 100644
index 00000000000000..5601e0bfdf5c1b
--- /dev/null
+++ b/mlir/test/IR/print-use-serial-comma.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s --mlir-print-serial-comma | FileCheck %s
+
+// CHECK: foo.dense_attr = dense<[1, 2, and 3]> : tensor<3xi32>
+"test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
+
+// Nested attributes not supported, we should use 1-d vectors from the LLVM dialect anyway.
+// CHECK{LITERAL}: foo.dense_attr = dense<[[1, 2, 3], [4, 5, 6], [7, 8, and 9]]> : tensor<3x3xi32>
+"test.nested_dense_attr"() {foo.dense_attr = dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>} : () -> ()
+
+// Two elements, serial comma not necessary.
+// CHECK: dense<[1, 2]> : tensor<2xi32>
+"test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> ()
+
+// CHECK{LITERAL}: sparse<[[0, 0, and 5]], -2.000000e+00> : vector<1x1x10xf16>
+"test.sparse_attr"() {foo.sparse_attr = sparse<[[0, 0, 5]], -2.0> : vector<1x1x10xf16>} : () -> ()
+
+// One unique element, do not use serial comma.
+// CHECK: dense<1> : tensor<3xi32>
+"test.dense_splat"() {foo.dense_attr = dense<1> : tensor<3xi32>} : () -> ()
>From 5983efb9ce5c8ecc5f942ed34702bdda4e40d6fc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 22:35:00 -0400
Subject: [PATCH 2/2] Add alias flag for oxford comma
---
mlir/lib/IR/AsmPrinter.cpp | 7 ++++++-
mlir/test/IR/print-use-serial-comma.mlir | 3 +++
2 files changed, 9 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d260fb078fed71..045102d156ddb2 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -144,7 +144,7 @@ static bool environmentRequiresSerialComma() {
currentLanguage.starts_with("en_CA"))
return true;
llvm_unreachable(
- "Unhandle corner case. This is very unlikely to happen in practice.");
+ "Unhandled corner case. This is very unlikely to happen in practice.");
}
namespace {
@@ -216,6 +216,11 @@ static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
void mlir::registerAsmPrinterCLOptions() {
// Make sure that the options struct has been initialized.
*clOptions;
+
+ static llvm::cl::alias printOxfordComma(
+ "mlir-print-oxford-comma",
+ llvm::cl::desc("Alias for --mlir-print-serial-comma"),
+ llvm::cl::aliasopt(clOptions->printSerialComma));
}
/// Initialize the printing flags with default supplied by the cl::opts above.
diff --git a/mlir/test/IR/print-use-serial-comma.mlir b/mlir/test/IR/print-use-serial-comma.mlir
index 5601e0bfdf5c1b..908e478705be4b 100644
--- a/mlir/test/IR/print-use-serial-comma.mlir
+++ b/mlir/test/IR/print-use-serial-comma.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt %s --mlir-print-serial-comma | FileCheck %s
+// Check that the alias flag is also recognized.
+// RUN: mlir-opt %s --mlir-print-oxford-comma | FileCheck %s
+
// CHECK: foo.dense_attr = dense<[1, 2, and 3]> : tensor<3xi32>
"test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
More information about the Mlir-commits
mailing list