[Mlir-commits] [mlir] [MLIR] Fix parsed integers exceeding the expected bitwidth (PR #119971)

Ivan R. Ivanov llvmlistbot at llvm.org
Sat Dec 14 09:00:51 PST 2024


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/119971

>From ab1ce208c3d77f6cdf90b9ca94f29c0865b5b5c1 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 14 Dec 2024 23:26:03 +0900
Subject: [PATCH 1/2] [MLIR] Fix parsed integers exceeding the expected
 bitwidth

When AsmParser::parseInteger, which MLIR operations use to parse integers,
was called with an APInt instead of a concrete C++ integer type, it
silently changed the bitwidth of the APInt when the parsed value was close
to or exceeded the pre-provided bitwidth. This resulted in MLIR attributes
with incorrect integer types in some cases.

This patch fixes the issue by requiring the expected bitwidth to be
provided explicitly when parsing into an APInt; values that do not fit in
that bitwidth are now rejected with an error.
---
 mlir/include/mlir/IR/OpImplementation.h       | 27 ++++++++++++++++++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      |  9 ++-----
 .../Polynomial/IR/PolynomialAttributes.cpp    |  2 +-
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 16 +++++++++++
 4 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..7149635c830720 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/SMLoc.h"
 #include <optional>
+#include <type_traits>
 
 namespace mlir {
 class AsmParsedResourceEntry;
@@ -710,12 +711,36 @@ class AsmParser {
   virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
                                  APFloat &result) = 0;
 
+  /// Parse an integer value from the stream.
+  ParseResult parseInteger(APInt &result, unsigned bitWidth,
+                           bool isSignedOrSignless) {
+    auto loc = getCurrentLocation();
+    APInt apintResult;
+    OptionalParseResult parseResult = parseOptionalInteger(apintResult);
+    if (!parseResult.has_value() || failed(*parseResult))
+      return emitError(loc, "expected integer value");
+
+    // Unlike the parseOptionalInteger used below for integral types, the
+    // virtual APInt does not check for whether the parsed integer fits in the
+    // width we want or whether its signednes matches the requested one.. Check
+    // here.
+    if (!isSignedOrSignless && apintResult.isNegative())
+      return emitError(loc, "negative integer when unsigned expected");
+    APInt sextOrTrunc = apintResult.sextOrTrunc(bitWidth);
+    if (sextOrTrunc.sextOrTrunc(apintResult.getBitWidth()) != apintResult)
+      return emitError(loc, "integer value too large");
+
+    result = sextOrTrunc;
+    return success();
+  }
+
   /// Parse an integer value from the stream.
   template <typename IntT>
   ParseResult parseInteger(IntT &result) {
+    static_assert(std::is_integral<IntT>::value);
     auto loc = getCurrentLocation();
     OptionalParseResult parseResult = parseOptionalInteger(result);
-    if (!parseResult.has_value())
+    if (!parseResult.has_value() || failed(*parseResult))
       return emitError(loc, "expected integer value");
     return *parseResult;
   }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ee4e344674a67e..ae6e4ad0732af2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -249,14 +249,9 @@ Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) {
   unsigned bitWidth = widthType.getWidth();
   APInt lower(bitWidth, 0);
   APInt upper(bitWidth, 0);
-  if (parser.parseInteger(lower) || parser.parseComma() ||
-      parser.parseInteger(upper) || parser.parseGreater())
+  if (parser.parseInteger(lower, bitWidth, true) || parser.parseComma() ||
+      parser.parseInteger(upper, bitWidth, true) || parser.parseGreater())
     return Attribute{};
-  // For some reason, 0 is always parsed as 64-bits, fix that if needed.
-  if (lower.isZero())
-    lower = lower.sextOrTrunc(bitWidth);
-  if (upper.isZero())
-    upper = upper.sextOrTrunc(bitWidth);
   return parser.getChecked<ConstantRangeAttr>(loc, parser.getContext(), lower,
                                               upper);
 }
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index cd7789a2e9531c..63ebbabfcefed3 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -83,7 +83,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
 
     // If there's a **, then the integer exponent is required.
     APInt parsedExponent(apintBitWidth, 0);
-    if (failed(parser.parseInteger(parsedExponent))) {
+    if (failed(parser.parseInteger(parsedExponent, apintBitWidth, true))) {
       parser.emitError(parser.getCurrentLocation(),
                        "found invalid integer exponent");
       return failure();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..0a1438f923a5a0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -520,3 +520,19 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.k
 llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
   llvm.return
 }
+
+// -----
+
+func.func @nvvm_special_regs() {
+  // CHECK: nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483647> : i32
+  %0 = nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483647> : i32
+  func.return
+}
+
+// -----
+
+func.func @nvvm_special_regs() {
+  // expected-error @below {{integer value too large}}
+  %0 = nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483648> : i32
+  func.return
+}

>From 37005c86fa5a927d9917723e0617161f1e04c610 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 15 Dec 2024 02:00:38 +0900
Subject: [PATCH 2/2] Address review comments: pass IntegerType signedness
 to parseInteger

---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h    |  3 +++
 mlir/include/mlir/IR/OpImplementation.h                 | 34 ++++++++++++++++++++++++----------
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp                |  5 +++--
 mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp |  3 ++-
 4 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 3f206cd1e545a2..eba772fc15fe07 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
 #define MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
@@ -28,6 +29,8 @@ namespace polynomial {
 /// coefficients. This may be relaxed in the future, but it seems unlikely one
 /// would want to specify 128-bit polynomials statically in the source code.
 constexpr unsigned apintBitWidth = 64;
+constexpr IntegerType::SignednessSemantics parsingSignedness =
+    IntegerType::Signless;
 
 template <class Derived, typename CoefficientType>
 class MonomialBase {
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 7149635c830720..1288765da0bd15 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -711,26 +711,40 @@ class AsmParser {
   virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
                                  APFloat &result) = 0;
 
-  /// Parse an integer value from the stream.
+  /// Parse an integer value from the stream into `result`, which is set to an
+  /// APInt of exactly `bitWidth` bits. Emits an error and fails if the value
+  /// is negative while `signedness` is Unsigned, or if it does not fit in
+  /// `bitWidth` bits: as an unsigned value for Unsigned, and as a signed
+  /// (two's complement) value for Signed and Signless.
   ParseResult parseInteger(APInt &result, unsigned bitWidth,
-                           bool isSignedOrSignless) {
+                           IntegerType::SignednessSemantics signedness) {
     auto loc = getCurrentLocation();
     APInt apintResult;
     OptionalParseResult parseResult = parseOptionalInteger(apintResult);
     if (!parseResult.has_value() || failed(*parseResult))
       return emitError(loc, "expected integer value");
 
-    // Unlike the parseOptionalInteger used below for integral types, the
-    // virtual APInt does not check for whether the parsed integer fits in the
-    // width we want or whether its signednes matches the requested one.. Check
-    // here.
-    if (!isSignedOrSignless && apintResult.isNegative())
+    // Unlike the parseOptionalInteger overload for integral types used below,
+    // the virtual APInt overload checks neither whether the parsed integer
+    // fits in the requested width nor whether its sign matches the requested
+    // signedness, so both are checked here.
+    bool isUnsigned = signedness == IntegerType::Unsigned;
+    if (isUnsigned && apintResult.isNegative())
       return emitError(loc, "negative integer when unsigned expected");
-    APInt sextOrTrunc = apintResult.sextOrTrunc(bitWidth);
-    if (sextOrTrunc.sextOrTrunc(apintResult.getBitWidth()) != apintResult)
+
+    // The value fits iff resizing it to `bitWidth` and back is lossless.
+    // Unsigned values (known to be non-negative at this point) are resized
+    // with zero extension so the whole [0, 2^bitWidth) range is accepted;
+    // signed and signless values are resized with sign extension.
+    unsigned parsedWidth = apintResult.getBitWidth();
+    APInt resized = isUnsigned ? apintResult.zextOrTrunc(bitWidth)
+                               : apintResult.sextOrTrunc(bitWidth);
+    APInt roundTripped = isUnsigned ? resized.zextOrTrunc(parsedWidth)
+                                    : resized.sextOrTrunc(parsedWidth);
+    if (roundTripped != apintResult)
       return emitError(loc, "integer value too large");
 
-    result = sextOrTrunc;
+    result = resized;
     return success();
   }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ae6e4ad0732af2..931b4a7b67ca5d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -247,10 +247,11 @@ Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) {
     return Attribute{};
   }
   unsigned bitWidth = widthType.getWidth();
+  IntegerType::SignednessSemantics signedness = widthType.getSignedness();
   APInt lower(bitWidth, 0);
   APInt upper(bitWidth, 0);
-  if (parser.parseInteger(lower, bitWidth, true) || parser.parseComma() ||
-      parser.parseInteger(upper, bitWidth, true) || parser.parseGreater())
+  if (parser.parseInteger(lower, bitWidth, signedness) || parser.parseComma() ||
+      parser.parseInteger(upper, bitWidth, signedness) || parser.parseGreater())
     return Attribute{};
   return parser.getChecked<ConstantRangeAttr>(loc, parser.getContext(), lower,
                                               upper);
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 63ebbabfcefed3..ecde9bc363794a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -83,7 +83,8 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
 
     // If there's a **, then the integer exponent is required.
     APInt parsedExponent(apintBitWidth, 0);
-    if (failed(parser.parseInteger(parsedExponent, apintBitWidth, true))) {
+    if (failed(parser.parseInteger(parsedExponent, apintBitWidth,
+                                   parsingSignedness))) {
       parser.emitError(parser.getCurrentLocation(),
                        "found invalid integer exponent");
       return failure();



More information about the Mlir-commits mailing list