[Mlir-commits] [mlir] de5a81b - [mlir] Update several usages of IntegerType to properly handle unsignedness.
River Riddle
llvmlistbot at llvm.org
Mon Mar 2 09:22:59 PST 2020
Author: River Riddle
Date: 2020-03-02T09:19:26-08:00
New Revision: de5a81b1023e95a06f0e40b8ef9cdfc2e38b6223
URL: https://github.com/llvm/llvm-project/commit/de5a81b1023e95a06f0e40b8ef9cdfc2e38b6223
DIFF: https://github.com/llvm/llvm-project/commit/de5a81b1023e95a06f0e40b8ef9cdfc2e38b6223.diff
LOG: [mlir] Update several usages of IntegerType to properly handle unsignedness.
Summary: For example, DenseElementsAttr currently does not properly round-trip unsigned integer values.
Differential Revision: https://reviews.llvm.org/D75374
Added:
Modified:
mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/Types.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/test/IR/parser.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6321e88c9c10..d9979b8467ee 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -93,9 +93,8 @@ struct constant_int_op_binder {
return false;
auto type = op->getResult(0).getType();
- if (type.isSignlessIntOrIndex()) {
+ if (type.isa<IntegerType>() || type.isa<IndexType>())
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
- }
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 25c0238946a9..d431d4ebabf4 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -339,6 +339,30 @@ def I16 : I<16>;
def I32 : I<32>;
def I64 : I<64>;
+// Unsigned integer types.
+// Any unsigned integer type irrespective of its width.
+def AnyUnsignedInteger : Type<
+ CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;
+
+// Unsigned integer type of a specific width.
+class UI<int width>
+ : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
+ width # "-bit unsigned integer">,
+ BuildableType<"$_builder.getIntegerType(" # width #
+ ", /*isSigned=*/false)"> {
+ int bitwidth = width;
+}
+
+class UnsignedIntOfWidths<list<int> widths> :
+ AnyTypeOf<!foreach(w, widths, UI<w>),
+ StrJoinInt<widths, "/">.result # "-bit unsigned integer">;
+
+def UI1 : UI<1>;
+def UI8 : UI<8>;
+def UI16 : UI<16>;
+def UI32 : UI<32>;
+def UI64 : UI<64>;
+
// Floating point types.
// Any float type irrespective of its width.
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 9bb9a8c06234..cd5ba07b689d 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -328,8 +328,9 @@ class TensorType : public ShapedType {
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
- return type.isSignlessIntOrFloat() || type.isa<ComplexType>() ||
- type.isa<VectorType>() || type.isa<OpaqueType>() ||
+ return type.isa<ComplexType>() || type.isa<FloatType>() ||
+ type.isa<IntegerType>() || type.isa<OpaqueType>() ||
+ type.isa<VectorType>() ||
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 40f1d4818769..eccc90cdae0c 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,6 +169,9 @@ class Type {
/// Return true of this is a signless integer or a float type.
bool isSignlessIntOrFloat();
+ /// Return true of this is an integer(of any signedness) or a float type.
+ bool isIntOrFloat();
+
/// Print the current type.
void print(raw_ostream &os);
void dump();
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index b76c0c0770a3..14635a144735 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -314,7 +314,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
- if (elementType.isSignlessIntOrFloat()) {
+ if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
@@ -358,7 +358,7 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
if (!memRefType.hasStaticShape())
return None;
auto elementType = memRefType.getElementType();
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>())
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
return None;
uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 140f533e0b15..ac2648846b24 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1372,17 +1372,18 @@ void ModulePrinter::printAttribute(Attribute attr,
/// Print the integer element of the given DenseElementsAttr at 'index'.
static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
- unsigned index) {
+ unsigned index, bool isSigned) {
APInt value = *std::next(attr.int_value_begin(), index);
if (value.getBitWidth() == 1)
os << (value.getBoolValue() ? "true" : "false");
else
- value.print(os, /*isSigned=*/true);
+ value.print(os, isSigned);
}
/// Print the float element of the given DenseElementsAttr at 'index'.
static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
- unsigned index) {
+ unsigned index, bool isSigned) {
+ assert(isSigned && "floating point values are always signed");
APFloat value = *std::next(attr.float_value_begin(), index);
printFloatValue(value, os);
}
@@ -1392,6 +1393,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
+ bool isSigned = !type.getElementType().isUnsignedInteger();
// The function used to print elements of this attribute.
auto printEltFn = type.getElementType().isa<IntegerType>()
@@ -1400,7 +1402,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
// Special case for 0-d and splat tensors.
if (attr.isSplat()) {
- printEltFn(attr, os, 0);
+ printEltFn(attr, os, 0, isSigned);
return;
}
@@ -1452,7 +1454,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
- printEltFn(attr, os, idx);
+ printEltFn(attr, os, idx, isSigned);
bumpCounter();
}
while (openBrackets-- > 0)
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 5beb12a59940..4526d7dc10be 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -608,7 +608,7 @@ DenseElementsAttr::FloatElementIterator::FloatElementIterator(
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
- assert(type.getElementType().isSignlessIntOrFloat() &&
+ assert(type.getElementType().isIntOrFloat() &&
"expected int or float element type");
assert(hasSameElementsOrSplat(type, values));
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 30d5bbcc7b3c..774f80a46de3 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -84,6 +84,8 @@ bool Type::isSignlessIntOrFloat() {
return isSignlessInteger() || isa<FloatType>();
}
+bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
@@ -147,13 +149,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
}
unsigned Type::getIntOrFloatBitWidth() {
- assert(isSignlessIntOrFloat() && "only ints and floats have a bitwidth");
- if (auto intType = dyn_cast<IntegerType>()) {
+ assert(isIntOrFloat() && "only integers and floats have a bitwidth");
+ if (auto intType = dyn_cast<IntegerType>())
return intType.getWidth();
- }
-
- auto floatType = cast<FloatType>();
- return floatType.getWidth();
+ return cast<FloatType>().getWidth();
}
//===----------------------------------------------------------------------===//
@@ -202,7 +201,7 @@ int64_t ShapedType::getSizeInBits() const {
"cannot get the bit size of an aggregate with a dynamic shape");
auto elementType = getElementType();
- if (elementType.isSignlessIntOrFloat())
+ if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements();
// Tensors can have vectors and other tensors as elements, other shaped types
@@ -373,7 +372,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
@@ -451,7 +450,7 @@ LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitError(loc, "invalid memref element type");
return success();
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 668fb694d8fd..661bddf8107a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1102,7 +1102,7 @@ Type Parser::parseMemRefType() {
return nullptr;
// Check that memref is formed from allowed types.
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitError(typeLoc, "invalid memref element type"), nullptr;
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index ef1af5d71aa8..bcb0c16ba77b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -869,7 +869,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
- if (elementType.isSignlessIntOrFloat()) {
+ if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bec1fbd4aca6..3baf0642e8b0 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -616,6 +616,9 @@ func @splattensorattr() -> () {
// CHECK: "splatBoolTensor"() {bar = dense<false> : tensor<i1>} : () -> ()
"splatBoolTensor"(){bar = dense<false> : tensor<i1>} : () -> ()
+ // CHECK: "splatUIntTensor"() {bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+ "splatUIntTensor"(){bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+
// CHECK: "splatIntTensor"() {bar = dense<5> : tensor<2x1x4xi32>} : () -> ()
"splatIntTensor"(){bar = dense<5> : tensor<2x1x4xi32>} : () -> ()
More information about the Mlir-commits mailing list