[Mlir-commits] [mlir] [MLIR] Fix parsed integers exceeding the expected bitwidth (PR #119971)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Dec 14 06:35:30 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Ivan R. Ivanov (ivanradanov)
<details>
<summary>Changes</summary>
When the AsmParser used to parse integers for MLIR operations was called
with an APInt instead of a concrete C++ integer type, it silently
extended the bitwidth of the APInt when the parsed value was close to or
exceeded the pre-provided bitwidth. This resulted in MLIR attributes
with erroneous integer types in some cases.
This patch fixes the issue by requiring the target bitwidth (and
signedness) to be provided explicitly when parsing into an APInt.
---
Full diff: https://github.com/llvm/llvm-project/pull/119971.diff
4 Files Affected:
- (modified) mlir/include/mlir/IR/OpImplementation.h (+26-1)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp (+2-7)
- (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp (+1-1)
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+16)
``````````diff
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..7149635c830720 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -19,6 +19,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
#include <optional>
+#include <type_traits>
namespace mlir {
class AsmParsedResourceEntry;
@@ -710,12 +711,36 @@ class AsmParser {
virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
APFloat &result) = 0;
+ /// Parse an integer literal from the stream into an APInt of exactly
+ /// `bitWidth` bits.
+ ///
+ /// If `isSignedOrSignless` is true, the literal must be representable as a
+ /// signed `bitWidth`-bit value. Otherwise it must be non-negative and
+ /// representable as an unsigned `bitWidth`-bit value. On violation an error
+ /// is emitted at the literal's location and failure is returned.
+ ParseResult parseInteger(APInt &result, unsigned bitWidth,
+ bool isSignedOrSignless) {
+ auto loc = getCurrentLocation();
+ APInt apintResult;
+ OptionalParseResult parseResult = parseOptionalInteger(apintResult);
+ if (!parseResult.has_value())
+ return emitError(loc, "expected integer value");
+ // parseOptionalInteger has already reported its own diagnostic; do not
+ // stack a second, misleading "expected integer value" on top of it.
+ if (failed(*parseResult))
+ return failure();
+
+ // Unlike the integral-type overload, the APInt overload of
+ // parseOptionalInteger yields a value of whatever width the literal needed,
+ // so check here that it fits `bitWidth` with the requested signedness.
+ if (!isSignedOrSignless && apintResult.isNegative())
+ return emitError(loc, "negative integer when unsigned expected");
+ // Unsigned targets must round-trip through zero-extension: sign-extending
+ // would wrongly reject values with the top bit set (e.g. 255 for 8 bits).
+ APInt fitted = isSignedOrSignless ? apintResult.sextOrTrunc(bitWidth)
+ : apintResult.zextOrTrunc(bitWidth);
+ APInt roundTrip = isSignedOrSignless
+ ? fitted.sextOrTrunc(apintResult.getBitWidth())
+ : fitted.zextOrTrunc(apintResult.getBitWidth());
+ if (roundTrip != apintResult)
+ return emitError(loc, "integer value too large");
+
+ result = fitted;
+ return success();
+ }
+
/// Parse an integer value from the stream.
+ /// `IntT` must be a builtin integral type; to parse into an APInt of a
+ /// specific width use `parseInteger(APInt &, unsigned, bool)` instead.
+ /// Emits "expected integer value" if no integer literal is present.
template <typename IntT>
ParseResult parseInteger(IntT &result) {
+ static_assert(std::is_integral<IntT>::value,
+ "use parseInteger(APInt &, unsigned, bool) to parse into an APInt");
auto loc = getCurrentLocation();
OptionalParseResult parseResult = parseOptionalInteger(result);
if (!parseResult.has_value())
return emitError(loc, "expected integer value");
+ // If parsing failed, parseOptionalInteger already emitted a diagnostic
+ // (e.g. "integer value too large"); just propagate the failure.
return *parseResult;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ee4e344674a67e..ae6e4ad0732af2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -249,14 +249,9 @@ Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) {
unsigned bitWidth = widthType.getWidth();
APInt lower(bitWidth, 0);
APInt upper(bitWidth, 0);
- if (parser.parseInteger(lower) || parser.parseComma() ||
- parser.parseInteger(upper) || parser.parseGreater())
+ if (parser.parseInteger(lower, bitWidth, true) || parser.parseComma() ||
+ parser.parseInteger(upper, bitWidth, true) || parser.parseGreater())
return Attribute{};
- // For some reason, 0 is always parsed as 64-bits, fix that if needed.
- if (lower.isZero())
- lower = lower.sextOrTrunc(bitWidth);
- if (upper.isZero())
- upper = upper.sextOrTrunc(bitWidth);
return parser.getChecked<ConstantRangeAttr>(loc, parser.getContext(), lower,
upper);
}
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index cd7789a2e9531c..63ebbabfcefed3 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -83,7 +83,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
// If there's a **, then the integer exponent is required.
APInt parsedExponent(apintBitWidth, 0);
- if (failed(parser.parseInteger(parsedExponent))) {
+ if (failed(parser.parseInteger(parsedExponent, apintBitWidth, true))) {
parser.emitError(parser.getCurrentLocation(),
"found invalid integer exponent");
return failure();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..0a1438f923a5a0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -520,3 +520,19 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.k
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
llvm.return
}
+
+// -----
+
+// The upper bound INT32_MAX must be accepted and printed back as an i32 range,
+// i.e. parsing must not silently widen the bound's APInt.
+func.func @nvvm_special_regs() {
+ // CHECK: nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483647> : i32
+ %0 = nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483647> : i32
+ func.return
+}
+
+// -----
+
+// A bound that does not fit in a signed i32 must be rejected while parsing.
+func.func @nvvm_special_regs() {
+ // expected-error @below {{integer value too large}}
+ %0 = nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2147483648> : i32
+ func.return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/119971
More information about the Mlir-commits
mailing list