[Mlir-commits] [mlir] [MLIR] Make compatible with APInt ctor assertion (PR #110466)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 30 01:36:37 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

This fixes all the places in MLIR that hit the new assertion added in #<!-- -->106524, in preparation for enabling it by default. The assertion fires when the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this patch is mostly NFC (no functional change).

This is just the MLIR changes split off from https://github.com/llvm/llvm-project/pull/80309.

---
Full diff: https://github.com/llvm/llvm-project/pull/110466.diff


9 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+3-1) 
- (modified) mlir/include/mlir/IR/OpImplementation.h (+2-1) 
- (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+1-1) 
- (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+1-1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+1-1) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+1-1) 
- (modified) mlir/lib/IR/Builders.cpp (+12-4) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+4-2) 
- (modified) mlir/unittests/Dialect/SPIRV/SerializationTest.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index f0d41754001400..530ba7d2f11e5c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
         return $_get(type.getContext(), type, apValue);
       }
 
+      // TODO: Avoid implicit trunc?
       IntegerType intTy = ::llvm::cast<IntegerType>(type);
-      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
+      APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
+                    /*implicitTrunc=*/true);
       return $_get(type.getContext(), type, apValue);
     }]>
   ];
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e2472eea8a3714..606a56c7fd55b5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -749,7 +749,8 @@ class AsmParser {
     // zero for non-negated integers.
     result =
         (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
-    if (APInt(uintResult.getBitWidth(), result) != uintResult)
+    if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
+              /*implicitTrunc=*/true) != uintResult)
       return emitError(loc, "integer value too large");
     return success();
   }
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 50e57682a2dc8d..593dbaa6c6545a 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     Type eTy = shapedTy.getElementType();
-    APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
+    APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
     return DenseIntElementsAttr::get(shapedTy, valueInt);
   }
 
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 98b429de1fd85c..edd7f607f24f4d 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..00f81db1dd795c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
     int64_t value = 0;
     if (failed(parser.parseInteger(value)))
       return failure();
-    values.push_back(APInt(bitWidth, value));
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
 
     Block *destination;
     SmallVector<OpAsmParser::UnresolvedOperand> operands;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90bf5df67b03ba..a5f987206db11c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1073,7 +1073,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
     if (parser.parseInteger(value))
       return failure();
     shapeTmp++;
-    values.push_back(APInt(32, value));
+    values.push_back(APInt(32, value, /*isSigned=*/true));
     return success();
   };
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7aed415343e551..7a7f360c53eb2f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -234,7 +234,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
 }
 
 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
-  return IntegerAttr::get(getIntegerType(32), APInt(32, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int32_t.
+  return IntegerAttr::get(getIntegerType(32),
+                          APInt(32, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
@@ -252,14 +255,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
 }
 
 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
-  return IntegerAttr::get(getIntegerType(8), APInt(8, value));
+  // The APInt always uses isSigned=true here because we accept the value
+  // as int8_t.
+  return IntegerAttr::get(getIntegerType(8),
+                          APInt(8, value, /*isSigned=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
   if (type.isIndex())
     return IntegerAttr::get(type, APInt(64, value));
-  return IntegerAttr::get(
-      type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
+  // TODO: Avoid implicit trunc?
+  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
+                                      type.isSignedInteger(),
+                                      /*implicitTrunc=*/true));
 }
 
 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 38293f7106a05a..236e0ec100a0de 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1284,9 +1284,11 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
         uint32_t word1;
         uint32_t word2;
       } words = {operands[2], operands[3]};
-      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
+      value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     } else if (bitwidth <= 32) {
-      value = APInt(bitwidth, operands[2], /*isSigned=*/true);
+      value = APInt(bitwidth, operands[2], /*isSigned=*/true,
+                    /*implicitTrunc=*/true);
     }
 
     auto attr = opBuilder.getIntegerAttr(intType, value);
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 9d2f690ed898af..ef89c1645d373f 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
       IntegerType::get(&context, 16, IntegerType::Signless);
   auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
   // Check the bit extension of same value under different signedness semantics.
-  APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
+  APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
                             signlessInt16Type.getSignedness());
   APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
                           signedInt16Type.getSignedness());

``````````

</details>


https://github.com/llvm/llvm-project/pull/110466


More information about the Mlir-commits mailing list