[Mlir-commits] [mlir] [MLIR] Fix parsed integers exceeding the expected bitwidth (PR #119971)

Ivan R. Ivanov llvmlistbot at llvm.org
Sat Dec 14 06:34:53 PST 2024


https://github.com/ivanradanov created https://github.com/llvm/llvm-project/pull/119971

When the AsmParser used to parse integers for MLIR operations was used
with an APInt instead of a concrete C++ integer type, it silently
extended the bitwidth of the APInt when the parsed value was close to or
overflowed the pre-provided bitwidth. This resulted in MLIR attributes
with erroneous integer types in some cases.

This patch fixes this issue by requiring the bitwidth to be
provided explicitly when parsing into an APInt.


>From ab1ce208c3d77f6cdf90b9ca94f29c0865b5b5c1 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 14 Dec 2024 23:26:03 +0900
Subject: [PATCH] [MLIR] Fix parsed integers exceeding the expected bitwidth

When the AsmParser used to parse integers for MLIR operations was used
with an APInt instead of a concrete C++ integer type, it silently
extended the bitwidth of the APInt when the parsed value was close to or
overflowed the pre-provided bitwidth. This resulted in MLIR attributes
with erroneous integer types in some cases.

This patch fixes this issue by requiring the bitwidth to be
provided explicitly when parsing into an APInt.
---
 mlir/include/mlir/IR/OpImplementation.h       | 27 ++++++++++++++++++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      |  9 ++-----
 .../Polynomial/IR/PolynomialAttributes.cpp    |  2 +-
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 16 +++++++++++
 4 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..7149635c830720 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/SMLoc.h"
 #include <optional>
+#include <type_traits>
 
 namespace mlir {
 class AsmParsedResourceEntry;
@@ -710,12 +711,36 @@ class AsmParser {
   virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
                                  APFloat &result) = 0;
 
+  /// Parse an integer value from the stream.
+  ParseResult parseInteger(APInt &result, unsigned bitWidth,
+                           bool isSignedOrSignless) {
+    auto loc = getCurrentLocation();
+    APInt apintResult;
+    OptionalParseResult parseResult = parseOptionalInteger(apintResult);
+    if (!parseResult.has_value() || failed(*parseResult))
+      return emitError(loc, "expected integer value");
+
+    // Unlike the parseOptionalInteger used below for integral types, the
+    // virtual APInt overload does not check whether the parsed integer fits in
+    // the width we want or whether its signedness matches the requested one.
+    // Check here.
+    if (!isSignedOrSignless && apintResult.isNegative())
+      return emitError(loc, "negative integer when unsigned expected");
+    APInt sextOrTrunc = apintResult.sextOrTrunc(bitWidth);
+    if (sextOrTrunc.sextOrTrunc(apintResult.getBitWidth()) != apintResult)
+      return emitError(loc, "integer value too large");
+
+    result = sextOrTrunc;
+    return success();
+  }
+
   /// Parse an integer value from the stream.
   template <typename IntT>
   ParseResult parseInteger(IntT &result) {
+    static_assert(std::is_integral<IntT>::value);
     auto loc = getCurrentLocation();
     OptionalParseResult parseResult = parseOptionalInteger(result);
-    if (!parseResult.has_value())
+    if (!parseResult.has_value() || failed(*parseResult))
       return emitError(loc, "expected integer value");
     return *parseResult;
   }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ee4e344674a67e..ae6e4ad0732af2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -249,14 +249,9 @@ Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) {
   unsigned bitWidth = widthType.getWidth();
   APInt lower(bitWidth, 0);
   APInt upper(bitWidth, 0);
-  if (parser.parseInteger(lower) || parser.parseComma() ||
-      parser.parseInteger(upper) || parser.parseGreater())
+  if (parser.parseInteger(lower, bitWidth, true) || parser.parseComma() ||
+      parser.parseInteger(upper, bitWidth, true) || parser.parseGreater())
     return Attribute{};
-  // For some reason, 0 is always parsed as 64-bits, fix that if needed.
-  if (lower.isZero())
-    lower = lower.sextOrTrunc(bitWidth);
-  if (upper.isZero())
-    upper = upper.sextOrTrunc(bitWidth);
   return parser.getChecked<ConstantRangeAttr>(loc, parser.getContext(), lower,
                                               upper);
 }
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index cd7789a2e9531c..63ebbabfcefed3 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -83,7 +83,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
 
     // If there's a **, then the integer exponent is required.
     APInt parsedExponent(apintBitWidth, 0);
-    if (failed(parser.parseInteger(parsedExponent))) {
+    if (failed(parser.parseInteger(parsedExponent, apintBitWidth, true))) {
       parser.emitError(parser.getCurrentLocation(),
                        "found invalid integer exponent");
       return failure();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..0a1438f923a5a0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -520,3 +520,19 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.k
 llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
   llvm.return
 }
+
+// -----
+
+func.func @nvvm_special_regs() {
+  // CHECK: nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483647> : i32
+  %0 = nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483647> : i32
+  func.return
+}
+
+// -----
+
+func.func @nvvm_special_regs() {
+  // expected-error @below {{integer value too large}}
+  %0 = nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483648> : i32
+  func.return
+}



More information about the Mlir-commits mailing list