[Mlir-commits] [mlir] [mlir][IR] Improve attribute printing with serial comma (PR #87217)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 31 21:25:47 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Make printed attributes more readable by using the serial comma (a.k.a. the Oxford comma). This applies to sequences of three or more elements and reduces ambiguity.
The behavior is opt-in and enabled by a new flag `--mlir-print-serial-comma` with a convenience alias `--mlir-print-oxford-comma`. Parsing support will be added in a follow-up PR.
For more context, see the RFC thread (link not yet added — TODO).
---
Full diff: https://github.com/llvm/llvm-project/pull/87217.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/OperationSupport.h (+11)
- (modified) mlir/lib/IR/AsmPrinter.cpp (+94-38)
- (added) mlir/test/IR/print-use-serial-comma.mlir (+22)
``````````diff
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 90e63ff8fcb38f..486403a9d6d74c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1166,6 +1166,10 @@ class OpPrintingFlags {
/// Print users of values as comments.
OpPrintingFlags &printValueUsers();
+ /// Use serial comma (a.k.a. Oxford comma) when printing sequences of 3 or more
+ /// elements. Returns `*this` so that calls can be chained.
+ OpPrintingFlags &useSerialComma();
+
/// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const;
@@ -1196,6 +1200,10 @@ class OpPrintingFlags {
/// Return if the printer should print users of values.
bool shouldPrintValueUsers() const;
+ /// Return if the printer should use serial comma when printing sequences of 3
+ /// or more elements.
+ bool shouldUseSerialComma() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1222,6 +1230,9 @@ class OpPrintingFlags {
/// Print users of values.
bool printValueUsersFlag : 1;
+
+ /// Print sequences of 3 or more elements using serial comma.
+ bool printSerialComma : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..045102d156ddb2 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -43,10 +43,12 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
+#include <cstdlib>
#include <type_traits>
#include <optional>
@@ -136,6 +138,15 @@ OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
// OpPrintingFlags
//===----------------------------------------------------------------------===//
+/// Returns true if the `LANG` environment variable starts with "en_US" or
+/// "en_CA". Any other value, or an unset `LANG`, yields false. These are
+/// ordinary inputs, so they must not be treated as unreachable: doing so would
+/// abort in debug builds and be undefined behavior in release builds.
+static bool environmentRequiresSerialComma() {
+  // std::getenv returns a null pointer when the variable is not set.
+  const char *lang = std::getenv("LANG");
+  if (!lang)
+    return false;
+  StringRef currentLanguage(lang);
+  return currentLanguage.starts_with("en_US") ||
+         currentLanguage.starts_with("en_CA");
+}
+
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
@@ -189,6 +200,12 @@ struct AsmPrinterOptions {
"mlir-print-value-users", llvm::cl::init(false),
llvm::cl::desc(
"Print users of operation results and block arguments as a comment")};
+
+ /// Opt-in flag `--mlir-print-serial-comma` (also reachable through the
+ /// `--mlir-print-oxford-comma` alias registered in
+ /// registerAsmPrinterCLOptions()).
+ // NOTE(review): `false && ...` short-circuits, so
+ // environmentRequiresSerialComma() is never evaluated and the default is
+ // always false; the LANG environment variable is not consulted here.
+ llvm::cl::opt<bool> printSerialComma{
+ "mlir-print-serial-comma",
+ llvm::cl::init(false && environmentRequiresSerialComma()),
+ llvm::cl::desc(
+ "Use serial comma when printing sequences of length 3 or more.")};
};
} // namespace
@@ -199,6 +216,11 @@ static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
void mlir::registerAsmPrinterCLOptions() {
// Make sure that the options struct has been initialized.
*clOptions;
+
+ // Expose `--mlir-print-oxford-comma` as an alias of the
+ // `--mlir-print-serial-comma` option. Because the alias is a function-local
+ // static, it is constructed only once even if this function is called again.
+ static llvm::cl::alias printOxfordComma(
+ "mlir-print-oxford-comma",
+ llvm::cl::desc("Alias for --mlir-print-serial-comma"),
+ llvm::cl::aliasopt(clOptions->printSerialComma));
}
/// Initialize the printing flags with default supplied by the cl::opts above.
@@ -206,7 +228,7 @@ OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
- printValueUsersFlag(false) {
+ printValueUsersFlag(false), printSerialComma(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -221,6 +243,7 @@ OpPrintingFlags::OpPrintingFlags()
printLocalScope = clOptions->printLocalScopeOpt;
skipRegionsFlag = clOptions->skipRegionsOpt;
printValueUsersFlag = clOptions->printValueUsers;
+ printSerialComma = clOptions->printSerialComma;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -280,6 +303,12 @@ OpPrintingFlags &OpPrintingFlags::printValueUsers() {
return *this;
}
+/// Opt in to the serial (Oxford) comma for printed sequences of 3 or more
+/// elements. Returns `*this` so that flag setters can be chained.
+OpPrintingFlags &OpPrintingFlags::useSerialComma() {
+  this->printSerialComma = true;
+  return *this;
+}
+
/// Return if the given ElementsAttr should be elided.
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
return elementsAttrElementLimit &&
@@ -328,6 +357,9 @@ bool OpPrintingFlags::shouldPrintValueUsers() const {
return printValueUsersFlag;
}
+/// Reports whether sequences of 3 or more elements are printed with a serial
+/// comma, as enabled by useSerialComma() or the command-line option.
+bool OpPrintingFlags::shouldUseSerialComma() const {
+  return this->printSerialComma;
+}
+
/// Returns true if an ElementsAttr with the given number of elements should be
/// printed with hex.
static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
@@ -377,8 +409,15 @@ class AsmPrinter::Impl {
raw_ostream &getStream() { return os; }
+  /// Print the elements of `c` separated by ", ", calling `eachFn` on each
+  /// one. When `useSerialComma` is set and `c` has 3 or more elements, the
+  /// final separator is ", and " instead.
template <typename Container, typename UnaryFunctor>
- inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
- llvm::interleaveComma(c, os, eachFn);
+  inline void interleaveComma(const Container &c, UnaryFunctor eachFn,
+                              bool useSerialComma = false) const {
+    const size_t count = llvm::range_size(c);
+    if (useSerialComma && count >= 3) {
+      // Everything but the last element, then the serial separator, then the
+      // last element on its own.
+      llvm::interleaveComma(llvm::drop_end(c), os, eachFn);
+      os << ", and ";
+      llvm::interleaveComma(llvm::drop_begin(c, count - 1), os, eachFn);
+      return;
+    }
+    llvm::interleaveComma(c, os, eachFn);
}
/// This enum describes the different kinds of elision for the type of an
@@ -2228,8 +2267,10 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
return;
} else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
os << '{';
- interleaveComma(dictAttr.getValue(),
- [&](NamedAttribute attr) { printNamedAttribute(attr); });
+ interleaveComma(
+ dictAttr.getValue(),
+ [&](NamedAttribute attr) { printNamedAttribute(attr); },
+ printerFlags.shouldUseSerialComma());
os << '}';
} else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
@@ -2264,9 +2305,10 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
} else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
os << '[';
- interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
- printAttribute(attr, AttrTypeElision::May);
- });
+ interleaveComma(
+ arrayAttr.getValue(),
+ [&](Attribute attr) { printAttribute(attr, AttrTypeElision::May); },
+ printerFlags.shouldUseSerialComma());
os << ']';
} else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
@@ -2370,9 +2412,10 @@ static void printDenseIntElement(const APInt &value, raw_ostream &os,
value.print(os, !type.isUnsignedInteger());
}
-static void
-printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
- function_ref<void(unsigned)> printEltFn) {
+static void printDenseElementsAttrImpl(bool isSplat, ShapedType type,
+ raw_ostream &os,
+ function_ref<void(unsigned)> printEltFn,
+ bool useSerialComma) {
// Special case for 0-d and splat tensors.
if (isSplat)
return printEltFn(0);
@@ -2407,9 +2450,11 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
}
};
+ useSerialComma &= numElements >= 3;
for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
- if (idx != 0)
- os << ", ";
+ if (idx != 0) {
+ os << (useSerialComma && (idx + 1 == e) ? ", and " : ", ");
+ }
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
@@ -2461,36 +2506,46 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
// and hence was replaced.
if (llvm::isa<IntegerType>(complexElementType)) {
auto valueIt = attr.value_begin<std::complex<APInt>>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- auto complexValue = *(valueIt + index);
- os << "(";
- printDenseIntElement(complexValue.real(), os, complexElementType);
- os << ",";
- printDenseIntElement(complexValue.imag(), os, complexElementType);
- os << ")";
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) {
+ auto complexValue = *(valueIt + index);
+ os << "(";
+ printDenseIntElement(complexValue.real(), os, complexElementType);
+ os << ",";
+ printDenseIntElement(complexValue.imag(), os, complexElementType);
+ os << ")";
+ },
+ printerFlags.shouldUseSerialComma());
} else {
auto valueIt = attr.value_begin<std::complex<APFloat>>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- auto complexValue = *(valueIt + index);
- os << "(";
- printFloatValue(complexValue.real(), os);
- os << ",";
- printFloatValue(complexValue.imag(), os);
- os << ")";
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) {
+ auto complexValue = *(valueIt + index);
+ os << "(";
+ printFloatValue(complexValue.real(), os);
+ os << ",";
+ printFloatValue(complexValue.imag(), os);
+ os << ")";
+ },
+ printerFlags.shouldUseSerialComma());
}
} else if (elementType.isIntOrIndex()) {
auto valueIt = attr.value_begin<APInt>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printDenseIntElement(*(valueIt + index), os, elementType);
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) {
+ printDenseIntElement(*(valueIt + index), os, elementType);
+ },
+ printerFlags.shouldUseSerialComma());
} else {
assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
auto valueIt = attr.value_begin<APFloat>();
- printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printFloatValue(*(valueIt + index), os);
- });
+ printDenseElementsAttrImpl(
+ attr.isSplat(), type, os,
+ [&](unsigned index) { printFloatValue(*(valueIt + index), os); },
+ printerFlags.shouldUseSerialComma());
}
}
@@ -2498,7 +2553,8 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
DenseStringElementsAttr attr) {
ArrayRef<StringRef> data = attr.getRawStringData();
auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
- printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
+ printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn,
+ printerFlags.shouldUseSerialComma());
}
void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
@@ -2522,8 +2578,8 @@ void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
printFloatValue(fltVal, getStream());
}
};
- llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
- printElementAt);
+ interleaveComma(llvm::seq<unsigned>(0, attr.size()), printElementAt,
+ printerFlags.shouldUseSerialComma());
}
void AsmPrinter::Impl::printType(Type type) {
diff --git a/mlir/test/IR/print-use-serial-comma.mlir b/mlir/test/IR/print-use-serial-comma.mlir
new file mode 100644
index 00000000000000..908e478705be4b
--- /dev/null
+++ b/mlir/test/IR/print-use-serial-comma.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --mlir-print-serial-comma | FileCheck %s
+
+// Check that the alias flag is also recognized.
+// RUN: mlir-opt %s --mlir-print-oxford-comma | FileCheck %s
+
+// CHECK: foo.dense_attr = dense<[1, 2, and 3]> : tensor<3xi32>
+"test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
+
+// Nested attributes are not supported: the serial comma is only inserted
+// before the last element of the whole attribute, not within each inner row.
+// We should use 1-d vectors from the LLVM dialect anyway.
+// CHECK{LITERAL}: foo.dense_attr = dense<[[1, 2, 3], [4, 5, 6], [7, 8, and 9]]> : tensor<3x3xi32>
+"test.nested_dense_attr"() {foo.dense_attr = dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>} : () -> ()
+
+// Two elements, serial comma not necessary.
+// CHECK: dense<[1, 2]> : tensor<2xi32>
+"test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> ()
+
+// Sparse attribute indices also get a serial comma.
+// CHECK{LITERAL}: sparse<[[0, 0, and 5]], -2.000000e+00> : vector<1x1x10xf16>
+"test.sparse_attr"() {foo.sparse_attr = sparse<[[0, 0, 5]], -2.0> : vector<1x1x10xf16>} : () -> ()
+
+// One unique element, do not use serial comma.
+// CHECK: dense<1> : tensor<3xi32>
+"test.dense_splat"() {foo.dense_attr = dense<1> : tensor<3xi32>} : () -> ()
``````````
</details>
https://github.com/llvm/llvm-project/pull/87217
More information about the Mlir-commits mailing list