[Mlir-commits] [mlir] Add AsmParser::parseDecimalInteger. (PR #96255)
Jacques Pienaar
llvmlistbot at llvm.org
Tue Jul 23 17:51:38 PDT 2024
https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/96255
>From 22303bb7acce74718a5e191e2ffd038288c28e8b Mon Sep 17 00:00:00 2001
From: Stephen Goglin <sdgoglin at google.com>
Date: Thu, 20 Jun 2024 14:03:54 -0700
Subject: [PATCH 1/3] Add AsmParser::parseDecimalInteger.
An attribute parser needs to parse lists of possibly negative integers
separated by 'x'. parseInteger cannot be used for this because it also
accepts hexadecimal format (so "0x5" is lexed as a single hex literal), and
parseIntegerInDimensionList does not allow negative values.
---
mlir/include/mlir/IR/OpImplementation.h | 35 ++++++++++++++++---
mlir/lib/AsmParser/AsmParserImpl.h | 5 +++
mlir/lib/AsmParser/Parser.cpp | 46 +++++++++++++++++++++++++
mlir/lib/AsmParser/Parser.h | 3 ++
4 files changed, 84 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index fa435cb3155ed..ae412c7227f8e 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -714,16 +714,27 @@ class AsmParser {
return *parseResult;
}
+ /// Parse a decimal integer value from the stream.
+ template <typename IntT>
+ ParseResult parseDecimalInteger(IntT &result) {
+ auto loc = getCurrentLocation();
+ OptionalParseResult parseResult = parseOptionalDecimalInteger(result);
+ if (!parseResult.has_value())
+ return emitError(loc, "expected decimal integer value");
+ return *parseResult;
+ }
+
/// Parse an optional integer value from the stream.
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
+ virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
- template <typename IntT>
- OptionalParseResult parseOptionalInteger(IntT &result) {
+ private:
+ template <typename IntT, typename ParseFn>
+ OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
+ ParseFn &&parseFn) {
auto loc = getCurrentLocation();
-
- // Parse the unsigned variant.
APInt uintResult;
- OptionalParseResult parseResult = parseOptionalInteger(uintResult);
+ OptionalParseResult parseResult = parseFn(uintResult);
if (!parseResult.has_value() || failed(*parseResult))
return parseResult;
@@ -737,6 +748,20 @@ class AsmParser {
return success();
}
+ public:
+ template <typename IntT>
+ OptionalParseResult parseOptionalInteger(IntT &result) {
+ return parseOptionalIntegerAndCheck(
+ result, [&](APInt &result) { return parseOptionalInteger(result); });
+ }
+
+ template <typename IntT>
+ OptionalParseResult parseOptionalDecimalInteger(IntT &result) {
+ return parseOptionalIntegerAndCheck(result, [&](APInt &result) {
+ return parseOptionalDecimalInteger(result);
+ });
+ }
+
/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList.
enum class Delimiter {
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 8f22be80865bf..b12687833e3fd 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -322,6 +322,11 @@ class AsmParserImpl : public BaseT {
return parser.parseOptionalInteger(result);
}
+ /// Parse an optional integer value only in decimal format from the stream.
+ OptionalParseResult parseOptionalDecimalInteger(APInt &result) override {
+ return parser.parseOptionalDecimalInteger(result);
+ }
+
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 1b8b4bac1821e..48aabb0163464 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -42,6 +42,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Alignment.h"
@@ -308,6 +309,51 @@ OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
return success();
}
+namespace {
+bool isBinOrHexOrOctIndicator(char c) {
+ return (llvm::toLower(c) == 'x' || llvm::toLower(c) == 'b' ||
+ llvm::isDigit(c));
+}
+} // namespace
+
+/// Parse an optional integer value only in decimal format from the stream.
+OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
+ Token curToken = getToken();
+ if (curToken.isNot(Token::integer, Token::minus)) {
+ return std::nullopt;
+ }
+
+ bool negative = consumeIf(Token::minus);
+ Token curTok = getToken();
+ if (parseToken(Token::integer, "expected integer value")) {
+ return failure();
+ }
+
+ StringRef spelling = curTok.getSpelling();
+ // If the integer is in bin, hex, or oct format, return only the 0 and reset
+ // the lex pointer.
+ if (spelling[0] == '0' && spelling.size() > 1 &&
+ isBinOrHexOrOctIndicator(spelling[1])) {
+ result = 0;
+ state.lex.resetPointer(spelling.data() + 1);
+ consumeToken();
+ return success();
+ }
+
+ if (spelling.getAsInteger(10, result))
+ return emitError(curTok.getLoc(), "integer value too large");
+
+ // Make sure we have a zero at the top so we return the right signedness.
+ if (result.isNegative())
+ result = result.zext(result.getBitWidth() + 1);
+
+ // Process the negative sign if present.
+ if (negative)
+ result.negate();
+
+ return success();
+}
+
/// Parse a floating point value from an integer literal token.
ParseResult Parser::parseFloatFromIntegerLiteral(
std::optional<APFloat> &result, const Token &tok, bool isNegative,
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index b959e67b8e258..4caab499e1a0e 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -144,6 +144,9 @@ class Parser {
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result);
+ /// Parse an optional integer value only in decimal format from the stream.
+ OptionalParseResult parseOptionalDecimalInteger(APInt &result);
+
/// Parse a floating point value from an integer literal token.
ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
const Token &tok, bool isNegative,
>From fb5746521e211a0f6a40f8ef27d68b163270a751 Mon Sep 17 00:00:00 2001
From: Stephen Goglin <sdgoglin at google.com>
Date: Wed, 10 Jul 2024 15:52:57 -0700
Subject: [PATCH 2/3] Adding test attribute and test
---
mlir/lib/AsmParser/Parser.cpp | 14 ++++---------
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 9 ++++++++
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 21 +++++++++++++++++++
.../mlir-tblgen/testdialect-attrdefs.mlir | 15 ++++++++++++-
4 files changed, 48 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 48aabb0163464..657ffd88aa260 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -309,13 +309,6 @@ OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
return success();
}
-namespace {
-bool isBinOrHexOrOctIndicator(char c) {
- return (llvm::toLower(c) == 'x' || llvm::toLower(c) == 'b' ||
- llvm::isDigit(c));
-}
-} // namespace
-
/// Parse an optional integer value only in decimal format from the stream.
OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
Token curToken = getToken();
@@ -330,10 +323,11 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
}
StringRef spelling = curTok.getSpelling();
- // If the integer is in bin, hex, or oct format, return only the 0 and reset
- // the lex pointer.
+ // If the integer is in hexadecimal, return only the 0. The lexer has already
+ // moved past the entire hexadecimal-encoded integer, so we reset the lex
+ // pointer to just past the 0 we actually want to consume.
if (spelling[0] == '0' && spelling.size() > 1 &&
- isBinOrHexOrOctIndicator(spelling[1])) {
+ llvm::toLower(spelling[1]) == 'x') {
result = 0;
state.lex.resetPointer(spelling.data() + 1);
consumeToken();
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 12635e107bd42..d3de0f673cfb8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -80,6 +80,15 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
let mnemonic = "attr_with_trait";
}
+// An attribute of decimal formatted integer.
+def TestDecimalIntegerAttr : Test_Attr<"TestDecimalInteger"> {
+ let mnemonic = "decimal_integer";
+
+ let parameters = (ins "int64_t":$value);
+
+ let hasCustomAssemblyFormat = 1;
+}
+
// Test support for ElementsAttrInterface.
def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> {
let mnemonic = "i64_elements";
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index d7e40d35238d9..bcdafd0567c65 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -13,9 +13,11 @@
#include "TestAttributes.h"
#include "TestDialect.h"
+#include "TestTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
@@ -64,6 +66,25 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
// CompoundAAttr
//===----------------------------------------------------------------------===//
+
+Attribute TestDecimalIntegerAttr::parse(AsmParser &parser, Type type) {
+ if (parser.parseLess()){
+ return Attribute();
+ }
+ uint64_t intVal;
+ if (failed(*parser.parseOptionalDecimalInteger(intVal))) {
+ return Attribute();
+ }
+ if (parser.parseGreater()) {
+ return Attribute();
+ }
+ return get(parser.getContext(), intVal);
+}
+
+void TestDecimalIntegerAttr::print(AsmPrinter &printer) const {
+ printer << "<" << getValue() << ">";
+}
+
Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
SmallVector<uint64_t> elements;
if (parser.parseLess() || parser.parseLSquare())
diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
index ee92ea06a208c..d3179886b3390 100644
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func private @compoundA()
// CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
@@ -19,3 +19,16 @@ func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qu
func.func private @overriddenAttr() attributes {
foo = #test.override_builder<5>
}
+
+// CHECK-LABEL: @decimalInteger
+// CHECK-SAME: foo = #test.decimal_integer<5>
+func.func private @decimalInteger() attributes {
+ foo = #test.decimal_integer<5>
+}
+
+// -----
+
+func.func private @hexdecimalInteger() attributes {
+// expected-error @below {{expected '>'}}
+ foo = #test.decimal_integer<0x5>
+}
>From 918c6c630ca16048267237e3c01f90e9e8adb374 Mon Sep 17 00:00:00 2001
From: Stephen Goglin <sdgoglin at google.com>
Date: Wed, 17 Jul 2024 15:24:07 -0700
Subject: [PATCH 3/3] Updating test attribute and test to model list of
integers
---
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 8 ++---
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 36 +++++++++++++------
.../mlir-tblgen/testdialect-attrdefs.mlir | 24 +++++++++----
3 files changed, 48 insertions(+), 20 deletions(-)
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index d3de0f673cfb8..cf3457546a78a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -80,11 +80,11 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
let mnemonic = "attr_with_trait";
}
-// An attribute of decimal formatted integer.
-def TestDecimalIntegerAttr : Test_Attr<"TestDecimalInteger"> {
- let mnemonic = "decimal_integer";
+// An attribute holding a list of decimal integers, formatted like a shape.
+def TestDecimalShapeAttr : Test_Attr<"TestDecimalShape"> {
+ let mnemonic = "decimal_shape";
- let parameters = (ins "int64_t":$value);
+ let parameters = (ins ArrayRefParameter<"int64_t">:$shape);
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index bcdafd0567c65..268c2b25e54aa 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -14,6 +14,7 @@
#include "TestAttributes.h"
#include "TestDialect.h"
#include "TestTypes.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
@@ -67,22 +68,37 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
//===----------------------------------------------------------------------===//
-Attribute TestDecimalIntegerAttr::parse(AsmParser &parser, Type type) {
+Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess()){
return Attribute();
}
- uint64_t intVal;
- if (failed(*parser.parseOptionalDecimalInteger(intVal))) {
- return Attribute();
- }
- if (parser.parseGreater()) {
- return Attribute();
+ SmallVector<int64_t> shape;
+ if (parser.parseOptionalGreater()) {
+ auto parseDecimal = [&]() {
+ shape.emplace_back();
+ auto parseResult = parser.parseOptionalDecimalInteger(shape.back());
+ if (!parseResult.has_value() || failed(*parseResult)) {
+ parser.emitError(parser.getCurrentLocation()) << "expected an integer";
+ return failure();
+ }
+ return success();
+ };
+ if (failed(parseDecimal())) {
+ return Attribute();
+ }
+ while (failed(parser.parseOptionalGreater())) {
+ if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) {
+ return Attribute();
+ }
+ }
}
- return get(parser.getContext(), intVal);
+ return get(parser.getContext(), shape);
}
-void TestDecimalIntegerAttr::print(AsmPrinter &printer) const {
- printer << "<" << getValue() << ">";
+void TestDecimalShapeAttr::print(AsmPrinter &printer) const {
+ printer << "<";
+ llvm::interleave(getShape(), printer, "x");
+ printer << ">";
}
Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
index d3179886b3390..89ad3594eebd8 100644
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -20,15 +20,27 @@ func.func private @overriddenAttr() attributes {
foo = #test.override_builder<5>
}
-// CHECK-LABEL: @decimalInteger
-// CHECK-SAME: foo = #test.decimal_integer<5>
-func.func private @decimalInteger() attributes {
- foo = #test.decimal_integer<5>
+// CHECK-LABEL: @decimalIntegerShapeEmpty
+// CHECK-SAME: foo = #test.decimal_shape<>
+func.func private @decimalIntegerShapeEmpty() attributes {
+ foo = #test.decimal_shape<>
+}
+
+// CHECK-LABEL: @decimalIntegerShape
+// CHECK-SAME: foo = #test.decimal_shape<5>
+func.func private @decimalIntegerShape() attributes {
+ foo = #test.decimal_shape<5>
+}
+
+// CHECK-LABEL: @decimalIntegerShapeMultiple
+// CHECK-SAME: foo = #test.decimal_shape<0x3x7>
+func.func private @decimalIntegerShapeMultiple() attributes {
+ foo = #test.decimal_shape<0x3x7>
}
// -----
func.func private @hexdecimalInteger() attributes {
-// expected-error @below {{expected '>'}}
- foo = #test.decimal_integer<0x5>
+// expected-error @below {{expected an integer}}
+ sdg = #test.decimal_shape<1x0xb>
}
More information about the Mlir-commits
mailing list