[Mlir-commits] [mlir] Add AsmParser::parseDecimalInteger. (PR #96255)

Jacques Pienaar llvmlistbot at llvm.org
Tue Jul 23 20:12:36 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/96255

>From ea4ad1e91fabdc16cd3b700e5915746c70af4ef9 Mon Sep 17 00:00:00 2001
From: Stephen Goglin <sdgoglin at google.com>
Date: Thu, 20 Jun 2024 14:03:54 -0700
Subject: [PATCH 1/4] Add AsmParser::parseDecimalInteger.

An attribute parser needs to parse lists of possibly negative integers
separated by `x`. parseInteger cannot be used because it treats an
element such as `0x3` as a hexadecimal literal, and
parseIntegerInDimensionList does not accept negative values.
---
 mlir/include/mlir/IR/OpImplementation.h | 35 ++++++++++++++++---
 mlir/lib/AsmParser/AsmParserImpl.h      |  5 +++
 mlir/lib/AsmParser/Parser.cpp           | 46 +++++++++++++++++++++++++
 mlir/lib/AsmParser/Parser.h             |  3 ++
 4 files changed, 84 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index fa435cb3155ed..ae412c7227f8e 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -714,16 +714,27 @@ class AsmParser {
     return *parseResult;
   }
 
+  /// Parse a decimal integer value from the stream.
+  template <typename IntT>
+  ParseResult parseDecimalInteger(IntT &result) {
+    auto loc = getCurrentLocation();
+    OptionalParseResult parseResult = parseOptionalDecimalInteger(result);
+    if (!parseResult.has_value())
+      return emitError(loc, "expected decimal integer value");
+    return *parseResult;
+  }
+
   /// Parse an optional integer value from the stream.
   virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
+  virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
 
-  template <typename IntT>
-  OptionalParseResult parseOptionalInteger(IntT &result) {
+ private:
+  template <typename IntT, typename ParseFn>
+  OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
+                                                   ParseFn &&parseFn) {
     auto loc = getCurrentLocation();
-
-    // Parse the unsigned variant.
     APInt uintResult;
-    OptionalParseResult parseResult = parseOptionalInteger(uintResult);
+    OptionalParseResult parseResult = parseFn(uintResult);
     if (!parseResult.has_value() || failed(*parseResult))
       return parseResult;
 
@@ -737,6 +748,20 @@ class AsmParser {
     return success();
   }
 
+ public:
+  template <typename IntT>
+  OptionalParseResult parseOptionalInteger(IntT &result) {
+    return parseOptionalIntegerAndCheck(
+        result, [&](APInt &result) { return parseOptionalInteger(result); });
+  }
+
+  template <typename IntT>
+  OptionalParseResult parseOptionalDecimalInteger(IntT &result) {
+    return parseOptionalIntegerAndCheck(result, [&](APInt &result) {
+      return parseOptionalDecimalInteger(result);
+    });
+  }
+
   /// These are the supported delimiters around operand lists and region
   /// argument lists, used by parseOperandList.
   enum class Delimiter {
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 8f22be80865bf..b12687833e3fd 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -322,6 +322,11 @@ class AsmParserImpl : public BaseT {
     return parser.parseOptionalInteger(result);
   }
 
+  /// Parse an optional integer value, in decimal format only, from the stream.
+  OptionalParseResult parseOptionalDecimalInteger(APInt &result) override {
+    return parser.parseOptionalDecimalInteger(result);
+  }
+
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 7181f13d3c8bb..e207d2af72fc9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -41,6 +41,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Alignment.h"
@@ -307,6 +308,51 @@ OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
   return success();
 }
 
+namespace {
+bool isBinOrHexOrOctIndicator(char c) {
+  return (llvm::toLower(c) == 'x' || llvm::toLower(c) == 'b' ||
+          llvm::isDigit(c));
+}
+} // namespace
+
+/// Parse an optional integer value only in decimal format from the stream.
+OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
+  Token curToken = getToken();
+  if (curToken.isNot(Token::integer, Token::minus)) {
+    return std::nullopt;
+  }
+
+  bool negative = consumeIf(Token::minus);
+  Token curTok = getToken();
+  if (parseToken(Token::integer, "expected integer value")) {
+    return failure();
+  }
+
+  StringRef spelling = curTok.getSpelling();
+  // If the integer is in bin, hex, or oct format, return only the 0 and reset
+  // the lex pointer.
+  if (spelling[0] == '0' && spelling.size() > 1 &&
+      isBinOrHexOrOctIndicator(spelling[1])) {
+    result = 0;
+    state.lex.resetPointer(spelling.data() + 1);
+    consumeToken();
+    return success();
+  }
+
+  if (spelling.getAsInteger(10, result))
+    return emitError(curTok.getLoc(), "integer value too large");
+
+  // Make sure we have a zero at the top so we return the right signedness.
+  if (result.isNegative())
+    result = result.zext(result.getBitWidth() + 1);
+
+  // Process the negative sign if present.
+  if (negative)
+    result.negate();
+
+  return success();
+}
+
 /// Parse a floating point value from an integer literal token.
 ParseResult Parser::parseFloatFromIntegerLiteral(
     std::optional<APFloat> &result, const Token &tok, bool isNegative,
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index b959e67b8e258..4caab499e1a0e 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -144,6 +144,9 @@ class Parser {
   /// Parse an optional integer value from the stream.
   OptionalParseResult parseOptionalInteger(APInt &result);
 
+  /// Parse an optional integer value only in decimal format from the stream.
+  OptionalParseResult parseOptionalDecimalInteger(APInt &result);
+
   /// Parse a floating point value from an integer literal token.
   ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
                                            const Token &tok, bool isNegative,

>From b1d49a343529e1c8da7ec5d53e1dd1ac9a3e05f5 Mon Sep 17 00:00:00 2001
From: Stephen Goglin <sdgoglin at google.com>
Date: Wed, 10 Jul 2024 15:52:57 -0700
Subject: [PATCH 2/4] Adding test attribute and test

---
 mlir/lib/AsmParser/Parser.cpp                 | 14 ++++---------
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  9 ++++++++
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 21 +++++++++++++++++++
 .../mlir-tblgen/testdialect-attrdefs.mlir     | 15 ++++++++++++-
 4 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e207d2af72fc9..2e4c4a36d46b9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -308,13 +308,6 @@ OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
   return success();
 }
 
-namespace {
-bool isBinOrHexOrOctIndicator(char c) {
-  return (llvm::toLower(c) == 'x' || llvm::toLower(c) == 'b' ||
-          llvm::isDigit(c));
-}
-} // namespace
-
 /// Parse an optional integer value only in decimal format from the stream.
 OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
   Token curToken = getToken();
@@ -329,10 +322,11 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
   }
 
   StringRef spelling = curTok.getSpelling();
-  // If the integer is in bin, hex, or oct format, return only the 0 and reset
-  // the lex pointer.
+  // If the integer is in hexadecimal, return only the 0. The lexer has already
+  // moved past the entire hexadecimal-encoded integer, so we reset the lex
+  // pointer to just past the 0 we actually want to consume.
   if (spelling[0] == '0' && spelling.size() > 1 &&
-      isBinOrHexOrOctIndicator(spelling[1])) {
+      llvm::toLower(spelling[1]) == 'x') {
     result = 0;
     state.lex.resetPointer(spelling.data() + 1);
     consumeToken();
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index a0a1cd30ed8ae..98aae5773e2d9 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -81,6 +81,15 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
   let mnemonic = "attr_with_trait";
 }
 
+// An attribute of decimal formatted integer.
+def TestDecimalIntegerAttr : Test_Attr<"TestDecimalInteger"> {
+  let mnemonic = "decimal_integer";
+
+  let parameters = (ins "int64_t":$value);
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 // Test support for ElementsAttrInterface.
 def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> {
   let mnemonic = "i64_elements";
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index b66dfbfcf0895..9127c4d3c3ddc 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -13,9 +13,11 @@
 
 #include "TestAttributes.h"
 #include "TestDialect.h"
+#include "TestTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Hashing.h"
@@ -63,6 +65,25 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
 // CompoundAAttr
 //===----------------------------------------------------------------------===//
 
+
+Attribute TestDecimalIntegerAttr::parse(AsmParser &parser, Type type) {
+  if (parser.parseLess()){
+    return Attribute();
+  }
+  uint64_t intVal;
+  if (failed(*parser.parseOptionalDecimalInteger(intVal))) {
+    return Attribute();
+  }
+  if (parser.parseGreater()) {
+    return Attribute();
+  }
+  return get(parser.getContext(), intVal);
+}
+
+void TestDecimalIntegerAttr::print(AsmPrinter &printer) const {
+  printer << "<" << getValue() << ">";
+}
+
 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
   SmallVector<uint64_t> elements;
   if (parser.parseLess() || parser.parseLSquare())
diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
index ee92ea06a208c..d3179886b3390 100644
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: func private @compoundA()
 // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
@@ -19,3 +19,16 @@ func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qu
 func.func private @overriddenAttr() attributes {
   foo = #test.override_builder<5>
 }
+
+// CHECK-LABEL: @decimalInteger
+// CHECK-SAME: foo = #test.decimal_integer<5>
+func.func private @decimalInteger() attributes {
+  foo = #test.decimal_integer<5>
+}
+
+// -----
+
+func.func private @hexdecimalInteger() attributes {
+// expected-error @below {{expected '>'}}
+  foo = #test.decimal_integer<0x5>
+}

>From 318c01388367d54bbcebeede6fbf8ce99fbc80fc Mon Sep 17 00:00:00 2001
From: Stephen Goglin <sdgoglin at google.com>
Date: Wed, 17 Jul 2024 15:24:07 -0700
Subject: [PATCH 3/4] Updating test attribute and test to model list of
 integers

---
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  8 ++---
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 36 +++++++++++++------
 .../mlir-tblgen/testdialect-attrdefs.mlir     | 24 +++++++++----
 3 files changed, 48 insertions(+), 20 deletions(-)

diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 98aae5773e2d9..8f109f8ce5e6d 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -81,11 +81,11 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
   let mnemonic = "attr_with_trait";
 }
 
-// An attribute of decimal formatted integer.
-def TestDecimalIntegerAttr : Test_Attr<"TestDecimalInteger"> {
-  let mnemonic = "decimal_integer";
+// An attribute holding a list of decimal integers in a shape-like `AxBxC` format.
+def TestDecimalShapeAttr : Test_Attr<"TestDecimalShape"> {
+  let mnemonic = "decimal_shape";
 
-  let parameters = (ins "int64_t":$value);
+  let parameters = (ins ArrayRefParameter<"int64_t">:$shape);
 
   let hasCustomAssemblyFormat = 1;
 }
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 9127c4d3c3ddc..17ffa80ca1023 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -14,6 +14,7 @@
 #include "TestAttributes.h"
 #include "TestDialect.h"
 #include "TestTypes.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/ExtensibleDialect.h"
@@ -66,22 +67,37 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 
 
-Attribute TestDecimalIntegerAttr::parse(AsmParser &parser, Type type) {
+Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess()){
     return Attribute();
   }
-  uint64_t intVal;
-  if (failed(*parser.parseOptionalDecimalInteger(intVal))) {
-    return Attribute();
-  }
-  if (parser.parseGreater()) {
-    return Attribute();
+  SmallVector<int64_t> shape;
+  if (parser.parseOptionalGreater()) {
+    auto parseDecimal = [&]() {
+      shape.emplace_back();
+      auto parseResult = parser.parseOptionalDecimalInteger(shape.back());
+      if (!parseResult.has_value() || failed(*parseResult)) {
+        parser.emitError(parser.getCurrentLocation()) << "expected an integer";
+        return failure();
+      }
+      return success();
+    };
+    if (failed(parseDecimal())) {
+      return Attribute();
+    }
+    while (failed(parser.parseOptionalGreater())) {
+      if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) {
+        return Attribute();
+      }
+    }
   }
-  return get(parser.getContext(), intVal);
+  return get(parser.getContext(), shape);
 }
 
-void TestDecimalIntegerAttr::print(AsmPrinter &printer) const {
-  printer << "<" << getValue() << ">";
+void TestDecimalShapeAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  llvm::interleave(getShape(), printer, "x");
+  printer << ">";
 }
 
 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
index d3179886b3390..89ad3594eebd8 100644
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -20,15 +20,27 @@ func.func private @overriddenAttr() attributes {
   foo = #test.override_builder<5>
 }
 
-// CHECK-LABEL: @decimalInteger
-// CHECK-SAME: foo = #test.decimal_integer<5>
-func.func private @decimalInteger() attributes {
-  foo = #test.decimal_integer<5>
+// CHECK-LABEL: @decimalIntegerShapeEmpty
+// CHECK-SAME: foo = #test.decimal_shape<>
+func.func private @decimalIntegerShapeEmpty() attributes {
+  foo = #test.decimal_shape<>
+}
+
+// CHECK-LABEL: @decimalIntegerShape
+// CHECK-SAME: foo = #test.decimal_shape<5>
+func.func private @decimalIntegerShape() attributes {
+  foo = #test.decimal_shape<5>
+}
+
+// CHECK-LABEL: @decimalIntegerShapeMultiple
+// CHECK-SAME: foo = #test.decimal_shape<0x3x7>
+func.func private @decimalIntegerShapeMultiple() attributes {
+  foo = #test.decimal_shape<0x3x7>
 }
 
 // -----
 
 func.func private @hexdecimalInteger() attributes {
-// expected-error @below {{expected '>'}}
-  foo = #test.decimal_integer<0x5>
+// expected-error @below {{expected an integer}}
+  sdg = #test.decimal_shape<1x0xb>
 }

>From 9731617570dc47eadffeedc2e8f2a6cfa5c0d673 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Wed, 24 Jul 2024 03:12:22 +0000
Subject: [PATCH 4/4] Formatting

---
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 17ffa80ca1023..e09ea10906164 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -66,7 +66,6 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
 // CompoundAAttr
 //===----------------------------------------------------------------------===//
 
-
 Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess()){
     return Attribute();



More information about the Mlir-commits mailing list