[Mlir-commits] [mlir] 35b6852 - [mlir] Add a signedness semantics bit to IntegerType
Lei Zhang
llvmlistbot at llvm.org
Fri Feb 21 06:17:23 PST 2020
Author: Lei Zhang
Date: 2020-02-21T09:16:54-05:00
New Revision: 35b685270b410f6a1351c2a527021f22330c25b9
URL: https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9
DIFF: https://github.com/llvm/llvm-project/commit/35b685270b410f6a1351c2a527021f22330c25b9.diff
LOG: [mlir] Add a signedness semantics bit to IntegerType
Thus far IntegerType has been signless: a value of IntegerType does
not have a sign intrinsically and it's up to the specific operation
to decide how to interpret those bits. For example, std.addi does
two's complement arithmetic, while std.divis and std.diviu treat the
leading bit as a sign bit and as part of the magnitude, respectively.
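To make the signless model concrete, the following hypothetical C++ sketch (plain C++, not MLIR API) feeds the same 32-bit pattern to stand-in functions named after `std.addi`, `std.divis`, and `std.diviu`. Only the operation decides whether the top bit is a sign bit. The function names and the use of `uint32_t` as raw, signless storage are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-ins for std.divis / std.diviu: both receive the same
// raw, signless 32-bit pattern; each op picks its own interpretation.

// Signed division: reinterpret the bits as a two's complement value.
// Caller must avoid rhs == 0 and the INT32_MIN / -1 overflow case.
uint32_t divis(uint32_t lhs, uint32_t rhs) {
  int32_t a, b;
  std::memcpy(&a, &lhs, sizeof a);
  std::memcpy(&b, &rhs, sizeof b);
  int32_t q = a / b; // truncates toward zero
  uint32_t out;
  std::memcpy(&out, &q, sizeof out);
  return out;
}

// Unsigned division: treat every bit, including the top one, as magnitude.
uint32_t diviu(uint32_t lhs, uint32_t rhs) { return lhs / rhs; }

// Wrap-around addition yields the same bits under either interpretation,
// which is why a single signless addi is enough.
uint32_t addi(uint32_t lhs, uint32_t rhs) { return lhs + rhs; }
```

With the pattern `0xFFFFFFFC` (which is -4 when read as signed), `divis(x, 2)` yields `0xFFFFFFFE` (-2), whereas `diviu(x, 2)` yields `0x7FFFFFFE`.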
This design choice was made some time ago when we didn't have many
dialects and dialects were more rigid. Today we have much more
extensible infrastructure, and different dialects may want different
modelling of integer signedness. So while we can say we want
signless integers in the standard dialect, we cannot dictate that for
others. Requiring each dialect to model signedness semantics
with its own set of custom types would duplicate the functionality
everywhere, considering the fundamental role integer types play.
This CL extends the IntegerType with a signedness semantics bit.
This gives each dialect the option to opt into signedness semantics
if that's what it wants, and helps code sharing. The parser is
modified to recognize `si[1-9][0-9]*` and `ui[1-9][0-9]*` as
signed and unsigned integer types, respectively, while the
original `i[1-9][0-9]*` continues to mean no indication of
signedness semantics. Existing dialects are not affected (yet),
as this is an opt-in feature.
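A minimal sketch of the keyword classification described above, written as a standalone helper rather than MLIR's actual lexer (the real change touches `Lexer.cpp`, `Token.cpp`, and `TokenKinds.def`, whose bodies are not shown here). `classifyIntegerType`, `ParsedIntType`, and `Signedness` are hypothetical names.

```cpp
#include <cassert>
#include <optional>
#include <string_view>

enum class Signedness { Signless, Signed, Unsigned };

struct ParsedIntType {
  Signedness signedness;
  unsigned width;
};

// Hypothetical helper: accepts `si[1-9][0-9]*`, `ui[1-9][0-9]*`, and
// `i[1-9][0-9]*`; returns std::nullopt for anything else
// (e.g. `i0`, `i01`, `si`, `index`, `f32`).
std::optional<ParsedIntType> classifyIntegerType(std::string_view spelling) {
  Signedness signedness = Signedness::Signless;
  if (spelling.substr(0, 2) == "si") {
    signedness = Signedness::Signed;
    spelling.remove_prefix(2);
  } else if (spelling.substr(0, 2) == "ui") {
    signedness = Signedness::Unsigned;
    spelling.remove_prefix(2);
  } else if (spelling.substr(0, 1) == "i") {
    spelling.remove_prefix(1);
  } else {
    return std::nullopt;
  }
  // The width must be non-empty and must not start with '0'.
  if (spelling.empty() || spelling.front() < '1' || spelling.front() > '9')
    return std::nullopt;
  unsigned width = 0;
  for (char c : spelling) {
    if (c < '0' || c > '9')
      return std::nullopt;
    width = width * 10 + static_cast<unsigned>(c - '0'); // no overflow check in this sketch
  }
  return ParsedIntType{signedness, width};
}
```

Checking the two-letter prefixes before the bare `i` keeps `si8`/`ui8` from being misread, and the leading-digit check rejects `i0` and `index` in the same step.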
More discussions can be found at:
https://groups.google.com/a/tensorflow.org/d/msg/mlir/XmkV8HOPWpo/7O4X0Nb_AQAJ
Differential Revision: https://reviews.llvm.org/D72533
Added:
Modified:
mlir/docs/LangRef.md
mlir/docs/Rationale.md
mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/Types.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp
mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/IR/TypeDetail.h
mlir/lib/Parser/Lexer.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Token.cpp
mlir/lib/Parser/Token.h
mlir/lib/Parser/TokenKinds.def
mlir/lib/Transforms/LoopFusion.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/test/mlir-tblgen/op-attribute.td
mlir/test/mlir-tblgen/predicate.td
Removed:
################################################################################
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 1a9ef99492da..61a37b6d447d 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -745,11 +745,16 @@ Syntax:
```
// Sized integers like i1, i4, i8, i16, i32.
-integer-type ::= `i` [1-9][0-9]*
+signed-integer-type ::= `si` [1-9][0-9]*
+unsigned-integer-type ::= `ui` [1-9][0-9]*
+signless-integer-type ::= `i` [1-9][0-9]*
+integer-type ::= signed-integer-type |
+ unsigned-integer-type |
+ signless-integer-type
```
-MLIR supports arbitrary precision integer types. Integer types are signless, but
-have a designated width.
+MLIR supports arbitrary precision integer types. Integer types have a designated
+width and may have signedness semantics.
**Rationale:** low precision integers (like `i2`, `i4` etc) are useful for
low-precision inference chips, and arbitrary precision integers are useful for
diff --git a/mlir/docs/Rationale.md b/mlir/docs/Rationale.md
index 763442dce063..ff475fd849e2 100644
--- a/mlir/docs/Rationale.md
+++ b/mlir/docs/Rationale.md
@@ -244,13 +244,22 @@ introduced.
The bit width is not defined for dialect-specific types at MLIR level. Dialects
are free to define their own quantities for type sizes.
-### Signless types
+### Integer signedness semantics
Integers in the builtin MLIR type system have a bitwidth (note that the `index`
-type has a symbolic width equal to the machine word size), but they do not have
-an intrinsic sign. This means that the "standard ops" operation set has things
-like `addi` and `muli` which do two's complement arithmetic, but some other
-operations get a sign, e.g. `divis` vs `diviu`.
+type has a symbolic width equal to the machine word size), and they *may*
+additionally have signedness semantics. The purpose is to satisfy the needs of
+different dialects, which can model different levels of abstractions. Certain
+abstraction, especially closer to source language, might want to differentiate
+signedness with integer types; while others, especially closer to machine
+instruction, might want signless integers. Instead of forcing each abstraction
+to adopt the same integer modelling or develop its own one in house, Integer
+types provides this as an option to help code reuse and consistency.
+
+For the standard dialect, the choice is to have signless integer types. An
+integer value does not have an intrinsic sign, and it's up to the specific op
+for interpretation. For example, ops like `addi` and `muli` do two's complement
+arithmetic, but some other operations get a sign, e.g. `divis` vs `diviu`.
LLVM uses the [same design](http://llvm.org/docs/LangRef.html#integer-type),
which was introduced in a revamp rolled out
diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
index b562559920b2..95c2b4c3934f 100644
--- a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
+++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
@@ -91,10 +91,10 @@ def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameOperandsAndResu
Element-wise equivalent to:
r = std::min(clamp_max, std::max(e, clamp_min))
}];
- let arguments = (ins IntegerLike:$operand,
+ let arguments = (ins SignlessIntegerLike:$operand,
APIntAttr:$clamp_min,
APIntAttr:$clamp_max);
- let results = (outs IntegerLike);
+ let results = (outs SignlessIntegerLike);
}
def fxpmath_ConvertISOp :
@@ -106,8 +106,8 @@ def fxpmath_ConvertISOp :
Similar to an element-wise static_cast in C++, from a one signed integer
element type to another.
}];
- let arguments = (ins IntegerLike:$operand);
- let results = (outs IntegerLike);
+ let arguments = (ins SignlessIntegerLike:$operand);
+ let results = (outs SignlessIntegerLike);
}
def fxpmath_ConvertISToFOp :
@@ -120,7 +120,7 @@ def fxpmath_ConvertISToFOp :
element type to a floating point element type, rounding to the nearest
floating point value.
}];
- let arguments = (ins IntegerLike:$operand);
+ let arguments = (ins SignlessIntegerLike:$operand);
let results = (outs FloatLike);
}
@@ -134,8 +134,8 @@ def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
implementation.
}];
- let arguments = (ins IntegerLike:$a, APIntAttr:$b);
- let results = (outs IntegerLike);
+ let arguments = (ins SignlessIntegerLike:$a, APIntAttr:$b);
+ let results = (outs SignlessIntegerLike);
}
def fxpmath_RoundingDivideByPotISOp :
@@ -148,8 +148,8 @@ def fxpmath_RoundingDivideByPotISOp :
Also known as a rounding arithmetic right shift. See
gemmlowp::RoundingDivideByPOT for a reference implementation.
}];
- let arguments = (ins IntegerLike:$operand, APIntAttr:$exponent);
- let results = (outs IntegerLike:$res);
+ let arguments = (ins SignlessIntegerLike:$operand, APIntAttr:$exponent);
+ let results = (outs SignlessIntegerLike:$res);
let verifier = [{
auto verifyExponent = exponent().getSExtValue();
if (verifyExponent < 0 || verifyExponent > 31) {
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 5923e5258dd9..a979ac912daa 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -19,7 +19,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
// Type constraint accepting standard integers, indices and wrapped LLVM integer
// types.
def IntLikeOrLLVMInt : TypeConstraint<
- Or<[AnyInteger.predicate, Index.predicate, LLVMInt.predicate]>,
+ Or<[AnySignlessInteger.predicate, Index.predicate, LLVMInt.predicate]>,
"integer, index or LLVM dialect equivalent">;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 09a85a1026fb..1de56ef61440 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -313,7 +313,7 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
let arguments = (ins AnyStridedMemRef:$output,
- AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value);
+ AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
let extraClassDeclaration = libraryCallName # [{
// Defined in C++ for now.
// TODO(ntv): auto-generate.
diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
index 7f2bfb9a76ef..cd2e85fd985d 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
@@ -36,7 +36,7 @@ def quant_RealPrimitiveType :
// A primitive type that can represent a storage value. This is either an
// integer or quantized type.
def quant_StoragePrimitiveType :
- Type<Or<[AnyInteger.predicate, quant_QuantizedType.predicate]>,
+ Type<Or<[AnySignlessInteger.predicate, quant_QuantizedType.predicate]>,
"quantized storage primitive (integer or quantized type)">;
// A primitive or container of RealPrimitiveType.
diff --git a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
index b299d69a72d3..23d900b7367e 100644
--- a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -66,7 +66,7 @@ class UniformQuantizedValueConverter {
static_cast<double>(uniformType.getStorageTypeMax()),
uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
assert(uniformType.getExpressedType().isa<FloatType>());
- assert(uniformType.getStorageType().isa<IntegerType>());
+ assert(uniformType.getStorageType().isSignlessInteger());
}
UniformQuantizedValueConverter(double scale, double zeroPoint,
@@ -182,7 +182,7 @@ class UniformQuantizedPerAxisValueConverter {
isSigned(uniformType.isSigned()),
quantizationDim(uniformType.getQuantizedDimension()) {
assert(uniformType.getExpressedType().isa<FloatType>());
- assert(uniformType.getStorageType().isa<IntegerType>());
+ assert(uniformType.getStorageType().isSignlessInteger());
assert(scales.size() == zeroPoints.size());
}
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index b6186bd4ec76..a0b739ea6ec7 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -107,7 +107,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
// <op>i %0, %1 : i32
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticOp<mnemonic, traits>,
- Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>;
+ Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
// Base class for standard arithmetic binary operations on floats, vectors and
// tensors thereof. This operation has two operands and returns one result,
@@ -466,8 +466,8 @@ def CmpIOp : Std_Op<"cmpi",
let arguments = (ins
CmpIPredicateAttr:$predicate,
- IntegerLike:$lhs,
- IntegerLike:$rhs
+ SignlessIntegerLike:$lhs,
+ SignlessIntegerLike:$rhs
);
let results = (outs BoolLike:$result);
@@ -1070,9 +1070,10 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
%3 = select %2, %0, %1 : i32
}];
- let arguments = (ins BoolLike:$condition, IntegerOrFloatLike:$true_value,
- IntegerOrFloatLike:$false_value);
- let results = (outs IntegerOrFloatLike:$result);
+ let arguments = (ins BoolLike:$condition,
+ SignlessIntegerOrFloatLike:$true_value,
+ SignlessIntegerOrFloatLike:$false_value);
+ let results = (outs SignlessIntegerOrFloatLike:$result);
let verifier = ?;
let builders = [OpBuilder<
@@ -1109,8 +1110,8 @@ def SignExtendIOp : Std_Op<"sexti",
%5 = sexti %0 : vector<2 x i32> to vector<2 x i64>
}];
- let arguments = (ins IntegerLike:$value);
- let results = (outs IntegerLike);
+ let arguments = (ins SignlessIntegerLike:$value);
+ let results = (outs SignlessIntegerLike);
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value value, Type destType", [{
@@ -1211,7 +1212,7 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
}];
- let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat],
+ let arguments = (ins AnyTypeOf<[AnySignlessInteger, AnyFloat],
"integer or float type">:$input);
let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);
@@ -1561,8 +1562,8 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
%5 = trunci %0 : vector<2 x i32> to vector<2 x i16>
}];
- let arguments = (ins IntegerLike:$value);
- let results = (outs IntegerLike);
+ let arguments = (ins SignlessIntegerLike:$value);
+ let results = (outs SignlessIntegerLike);
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value value, Type destType", [{
@@ -1661,8 +1662,8 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]>
%5 = zexti %0 : vector<2 x i32> to vector<2 x i64>
}];
- let arguments = (ins IntegerLike:$value);
- let results = (outs IntegerLike);
+ let arguments = (ins SignlessIntegerLike:$value);
+ let results = (outs SignlessIntegerLike);
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value value, Type destType", [{
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 10837b94bfc2..ce6029a5d497 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -342,7 +342,7 @@ def Vector_ExtractElementOp :
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
"$_self.cast<ShapedType>().getElementType()">]>,
- Arguments<(ins AnyVector:$vector, AnyInteger:$position)>,
+ Arguments<(ins AnyVector:$vector, AnySignlessInteger:$position)>,
Results<(outs AnyType:$result)> {
let summary = "extractelement operation";
let description = [{
@@ -487,7 +487,8 @@ def Vector_InsertElementOp :
"result", "source",
"$_self.cast<ShapedType>().getElementType()">,
AllTypesMatch<["dest", "result"]>]>,
- Arguments<(ins AnyType:$source, AnyVector:$dest, AnyInteger:$position)>,
+ Arguments<(ins AnyType:$source, AnyVector:$dest,
+ AnySignlessInteger:$position)>,
Results<(outs AnyVector:$result)> {
let summary = "insertelement operation";
let description = [{
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index d5c063f326ce..5af6aa79e081 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -351,8 +351,16 @@ class IntegerAttr
static IntegerAttr get(Type type, const APInt &value);
APInt getValue() const;
+ /// Return the integer value as a 64-bit int. The attribute must be a signless
+ /// integer.
// TODO(jpienaar): Change callers to use getValue instead.
int64_t getInt() const;
+ /// Return the integer value as a signed 64-bit int. The attribute must be
+ /// a signed integer.
+ int64_t getSInt() const;
+ /// Return the integer value as an unsigned 64-bit int. The attribute must be
+ /// an unsigned integer.
+ uint64_t getUInt() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
@@ -688,7 +696,7 @@ class DenseElementsAttr
const char *data = reinterpret_cast<const char *>(values.data());
return getRawIntOrFloat(
type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
- /*isInt=*/std::numeric_limits<T>::is_integer);
+ std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed);
}
/// Constructs a dense integer elements attribute from a single element.
@@ -863,7 +871,8 @@ class DenseElementsAttr
std::numeric_limits<T>::is_integer) ||
llvm::is_one_of<T, float, double>::value>::type>
llvm::iterator_range<ElementIterator<T>> getValues() const {
- assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
+ assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+ std::numeric_limits<T>::is_signed));
auto rawData = getRawData().data();
bool splat = isSplat();
return {ElementIterator<T>(rawData, splat, 0),
@@ -976,12 +985,13 @@ class DenseElementsAttr
/// invariants that the templatized 'get' method cannot.
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
- int64_t dataEltSize, bool isInt);
+ int64_t dataEltSize, bool isInt,
+ bool isSigned);
- /// Check the information for a c++ data type, check if this type is valid for
+ /// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
- bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
+ bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
};
/// An attribute that represents a reference to a dense float vector or tensor
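The `IntegerAttr` declarations above document a contract: `getInt()` is for signless attributes, `getSInt()` for signed ones, and `getUInt()` for unsigned ones. The following standalone C++ sketch models that contract with a hypothetical `ToyIntegerAttr` holding a bit pattern plus a signedness tag. The assertion messages and the choice to sign-extend for `getInt()`/`getSInt()` and zero-extend for `getUInt()` are assumptions about intent, not a copy of `Attributes.cpp`.

```cpp
#include <cassert>
#include <cstdint>

enum class Signedness { Signless, Signed, Unsigned };

// Hypothetical model: `bits` holds the value in its low `width` bits
// (1 <= width <= 64); `signedness` mirrors the attribute's integer type.
struct ToyIntegerAttr {
  unsigned width;
  Signedness signedness;
  uint64_t bits;

  // Keep only the low `width` bits.
  uint64_t lowBits() const {
    return width == 64 ? bits : bits & ((uint64_t{1} << width) - 1);
  }

  // Sign-extend the low `width` bits to 64 bits.
  int64_t signExtended() const {
    uint64_t v = lowBits();
    if (width < 64 && ((v >> (width - 1)) & 1))
      v |= ~uint64_t{0} << width;
    return static_cast<int64_t>(v);
  }

  // Signless accessor; this sketch sign-extends the raw bits.
  int64_t getInt() const {
    assert(signedness == Signedness::Signless && "must be signless integer");
    return signExtended();
  }
  int64_t getSInt() const {
    assert(signedness == Signedness::Signed && "must be signed integer");
    return signExtended();
  }
  uint64_t getUInt() const {
    assert(signedness == Signedness::Unsigned && "must be unsigned integer");
    return lowBits();
  }
};
```

The same 8-bit pattern `0xFF` reads as -1 through `getSInt()` on a signed attribute and as 255 through `getUInt()` on an unsigned one; calling the mismatched accessor trips the assertion instead of silently picking an interpretation.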
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 11e1d0f14296..ec4bcf75547b 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -70,6 +70,7 @@ class Builder {
IntegerType getI1Type();
IntegerType getIntegerType(unsigned width);
+ IntegerType getIntegerType(unsigned width, bool isSigned);
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
TupleType getTupleType(ArrayRef<Type> elementTypes);
NoneType getNoneType();
@@ -111,6 +112,10 @@ class Builder {
IntegerAttr getI32IntegerAttr(int32_t value);
IntegerAttr getI64IntegerAttr(int64_t value);
+ /// Signed and unsigned integer attribute getters.
+ IntegerAttr getSI32IntegerAttr(int32_t value);
+ IntegerAttr getUI32IntegerAttr(uint32_t value);
+
DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index ab8086b60199..6321e88c9c10 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -93,7 +93,7 @@ struct constant_int_op_binder {
return false;
auto type = op->getResult(0).getType();
- if (type.isIntOrIndex()) {
+ if (type.isSignlessIntOrIndex()) {
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
}
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 2f02a885242b..3dba6a09c5a7 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -310,7 +310,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
// Integer types.
// Any integer type irrespective of its width.
-def AnyInteger : Type<CPred<"$_self.isa<IntegerType>()">, "integer">;
+def AnySignlessInteger : Type<
+ CPred<"$_self.isSignlessInteger()">, "integer">;
// Index type.
def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">,
@@ -318,7 +319,7 @@ def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">,
// Integer type of a specific width.
class I<int width>
- : Type<CPred<"$_self.isInteger(" # width # ")">,
+ : Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
width # "-bit integer">,
BuildableType<"$_builder.getIntegerType(" # width # ")"> {
int bitwidth = width;
@@ -586,8 +587,10 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
// Type constraint for integer-like types: integers, indices, vectors of
// integers, tensors of integers.
-def IntegerLike : TypeConstraint<Or<[AnyInteger.predicate, Index.predicate,
- VectorOf<[AnyInteger]>.predicate, TensorOf<[AnyInteger]>.predicate]>,
+def SignlessIntegerLike : TypeConstraint<Or<[
+ AnySignlessInteger.predicate, Index.predicate,
+ VectorOf<[AnySignlessInteger]>.predicate,
+ TensorOf<[AnySignlessInteger]>.predicate]>,
"integer-like">;
// Type constraint for float-like types: floats, vectors or tensors thereof.
@@ -596,8 +599,8 @@ def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
"floating-point-like">;
// Type constraint for integer-like or float-like types.
-def IntegerOrFloatLike : TypeConstraint<Or<[IntegerLike.predicate,
- FloatLike.predicate]>,
+def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
+ SignlessIntegerLike.predicate, FloatLike.predicate]>,
"integer-like or floating-point-like">;
@@ -725,7 +728,7 @@ class IntegerAttrBase<I attrValType, string descr> :
attrValType, "IntegerAttr",
And<[CPred<"$_self.isa<IntegerAttr>()">,
CPred<"$_self.cast<IntegerAttr>().getType()."
- "isInteger(" # attrValType.bitwidth # ")">]>,
+ "isSignlessInteger(" # attrValType.bitwidth # ")">]>,
descr> {
let returnType = [{ APInt }];
}
@@ -1031,7 +1034,7 @@ def ElementsAttr : ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">,
class IntElementsAttr<int width> : ElementsAttrBase<
CPred<"$_self.isa<DenseIntElementsAttr>() &&"
"$_self.cast<DenseIntElementsAttr>().getType()."
- "getElementType().isInteger(" # width # ")">,
+ "getElementType().isSignlessInteger(" # width # ")">,
width # "-bit integer elements attribute"> {
let storageType = [{ DenseIntElementsAttr }];
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 2c2523e11cde..87ffb8110427 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -365,7 +365,7 @@ LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
-LogicalResult verifyOperandsAreIntegerLike(Operation *op);
+LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
LogicalResult verifySameTypeOperands(Operation *op);
LogicalResult verifyZeroResult(Operation *op);
LogicalResult verifyOneResult(Operation *op);
@@ -378,7 +378,7 @@ LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
-LogicalResult verifyResultsAreIntegerLike(Operation *op);
+LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
LogicalResult verifyIsTerminator(Operation *op);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
@@ -725,14 +725,14 @@ class ResultsAreFloatLike
}
};
-/// This class verifies that any results of the specified op have an integer or
-/// index type, a vector thereof, or a tensor thereof.
+/// This class verifies that any results of the specified op have a signless
+/// integer or index type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
-class ResultsAreIntegerLike
- : public TraitBase<ConcreteType, ResultsAreIntegerLike> {
+class ResultsAreSignlessIntegerLike
+ : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
public:
static LogicalResult verifyTrait(Operation *op) {
- return impl::verifyResultsAreIntegerLike(op);
+ return impl::verifyResultsAreSignlessIntegerLike(op);
}
};
@@ -767,14 +767,14 @@ class OperandsAreFloatLike
}
};
-/// This class verifies that all operands of the specified op have an integer or
-/// index type, a vector thereof, or a tensor thereof.
+/// This class verifies that all operands of the specified op have a signless
+/// integer or index type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
-class OperandsAreIntegerLike
- : public TraitBase<ConcreteType, OperandsAreIntegerLike> {
+class OperandsAreSignlessIntegerLike
+ : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
public:
static LogicalResult verifyTrait(Operation *op) {
- return impl::verifyOperandsAreIntegerLike(op);
+ return impl::verifyOperandsAreSignlessIntegerLike(op);
}
};
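The renamed traits check that every operand (or result) is a signless integer or `index`, or a vector or tensor of such, mirroring the `SignlessIntegerLike` constraint in `OpBase.td`. Below is a standalone sketch of that predicate. `ToyType`, `isSignlessIntegerLike`, and `operandsAreSignlessIntegerLike` are hypothetical stand-ins for `mlir::Type` and the real `impl::verify...` helpers. Unwrapping exactly one level of vector/tensor before checking the element type is an assumption about how those helpers behave.

```cpp
#include <cassert>
#include <memory>
#include <vector>

enum class Signedness { Signless, Signed, Unsigned };

// Hypothetical, heavily simplified type model (not mlir::Type).
struct ToyType {
  enum Kind { Integer, Index, Float, Vector, Tensor } kind;
  unsigned width = 0;                           // Integer / Float only
  Signedness signedness = Signedness::Signless; // Integer only
  std::shared_ptr<ToyType> element;             // Vector / Tensor only
};

bool isSignlessIntOrIndex(const ToyType &t) {
  return t.kind == ToyType::Index ||
         (t.kind == ToyType::Integer && t.signedness == Signedness::Signless);
}

// Unwrap one level of vector/tensor, then check the element type.
bool isSignlessIntegerLike(const ToyType &t) {
  if ((t.kind == ToyType::Vector || t.kind == ToyType::Tensor) && t.element)
    return isSignlessIntOrIndex(*t.element);
  return isSignlessIntOrIndex(t);
}

// Analogue of the operand trait: every operand type must pass.
bool operandsAreSignlessIntegerLike(const std::vector<ToyType> &operandTypes) {
  for (const ToyType &t : operandTypes)
    if (!isSignlessIntegerLike(t))
      return false;
  return true;
}
```

Under this model, an op with operands `(i32, index, vector<4xi32>)` passes, while any `si32`/`ui32` operand, or a tensor of them, is rejected. That is why standard-dialect ops such as `addi` keep accepting only signless integers after this change.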
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 8339f5e26943..b2f8a7a109fc 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -503,7 +503,7 @@ class Operation final
SmallVectorImpl<OpFoldResult> &results);
/// Returns if the operation was registered with a particular trait, e.g.
- /// hasTrait<OperandsAreIntegerLike>().
+ /// hasTrait<OperandsAreSignlessIntegerLike>().
template <template <typename T> class Trait> bool hasTrait() {
auto *absOp = getAbstractOperation();
return absOp ? absOp->hasTrait<Trait>() : false;
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 7356c7be75b1..9bb9a8c06234 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -84,24 +84,57 @@ class IntegerType
public:
using Base::Base;
+ /// Signedness semantics.
+ enum SignednessSemantics {
+ Signless, /// No signedness semantics
+ Signed, /// Signed integer
+ Unsigned, /// Unsigned integer
+ };
+
/// Get or create a new IntegerType of the given width within the context.
- /// Assume the width is within the allowed range and assert on failures.
- /// Use getChecked to handle failures gracefully.
+ /// The created IntegerType is signless (i.e., no signedness semantics).
+ /// Assume the width is within the allowed range and assert on failures. Use
+ /// getChecked to handle failures gracefully.
static IntegerType get(unsigned width, MLIRContext *context);
+ /// Get or create a new IntegerType of the given width within the context.
+ /// The created IntegerType has signedness semantics as indicated via
+ /// `signedness`. Assume the width is within the allowed range and assert on
+ /// failures. Use getChecked to handle failures gracefully.
+ static IntegerType get(unsigned width, SignednessSemantics signedness,
+ MLIRContext *context);
+
/// Get or create a new IntegerType of the given width within the context,
- /// defined at the given, potentially unknown, location. If the width is
+ /// defined at the given, potentially unknown, location. The created
+ /// IntegerType is signless (i.e., no signedness semantics). If the width is
/// outside the allowed range, emit errors and return a null type.
- static IntegerType getChecked(unsigned width, MLIRContext *context,
+ static IntegerType getChecked(unsigned width, Location location);
+
+ /// Get or create a new IntegerType of the given width within the context,
+ /// defined at the given, potentially unknown, location. The created
+ /// IntegerType has signedness semantics as indicated via `signedness`. If the
+ /// width is outside the allowed range, emit errors and return a null type.
+ static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
Location location);
/// Verify the construction of an integer type.
- static LogicalResult verifyConstructionInvariants(Location loc,
- unsigned width);
+ static LogicalResult
+ verifyConstructionInvariants(Location loc, unsigned width,
+ SignednessSemantics signedness);
/// Return the bitwidth of this integer type.
unsigned getWidth() const;
+ /// Return the signedness semantics of this integer type.
+ SignednessSemantics getSignedness() const;
+
+ /// Return true if this is a signless integer type.
+ bool isSignless() const { return getSignedness() == Signless; }
+ /// Return true if this is a signed integer type.
+ bool isSigned() const { return getSignedness() == Signed; }
+ /// Return true if this is an unsigned integer type.
+ bool isUnsigned() const { return getSignedness() == Unsigned; }
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
@@ -274,7 +307,9 @@ class VectorType
/// Returns true of the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
- static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
+ static bool isValidElementType(Type t) {
+ return t.isa<IntegerType>() || t.isa<FloatType>();
+ }
ArrayRef<int64_t> getShape() const;
@@ -293,7 +328,7 @@ class TensorType : public ShapedType {
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
- return type.isIntOrFloat() || type.isa<ComplexType>() ||
+ return type.isSignlessIntOrFloat() || type.isa<ComplexType>() ||
type.isa<VectorType>() || type.isa<OpaqueType>() ||
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
}
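With the new `SignednessSemantics` enum, signedness becomes part of an integer type's identity: `i32`, `si32`, and `ui32` are three distinct types with the same width. The sketch below models this with a hypothetical value type, `ToyIntegerType`. Its equality compares both width and signedness, and its printer uses the `i`/`si`/`ui` spellings from the commit message. It does not reproduce MLIR's context-based type uniquing.

```cpp
#include <cassert>
#include <string>

// Mirrors the three values of IntegerType::SignednessSemantics in the diff.
enum class SignednessSemantics { Signless, Signed, Unsigned };

// Hypothetical value type: an integer type is identified by (width, signedness).
struct ToyIntegerType {
  unsigned width;
  SignednessSemantics signedness;

  bool isSignless() const { return signedness == SignednessSemantics::Signless; }
  bool isSigned() const { return signedness == SignednessSemantics::Signed; }
  bool isUnsigned() const { return signedness == SignednessSemantics::Unsigned; }

  // Two integer types are equal only if width AND signedness match.
  friend bool operator==(const ToyIntegerType &a, const ToyIntegerType &b) {
    return a.width == b.width && a.signedness == b.signedness;
  }
  friend bool operator!=(const ToyIntegerType &a, const ToyIntegerType &b) {
    return !(a == b);
  }

  // Print using the textual forms accepted by the parser: i<N>, si<N>, ui<N>.
  std::string str() const {
    const char *prefix = isSigned() ? "si" : isUnsigned() ? "ui" : "i";
    return prefix + std::to_string(width);
  }
};
```

Folding signedness into equality means any code that compares types, including type-constraint predicates, distinguishes `si32` from `i32` automatically. This is why the diff replaces checks like `isInteger(32)` with `isSignlessInteger(32)` where only signless integers are intended.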
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index abbc282b35d1..ef4b0511cb07 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -146,17 +146,27 @@ class Type {
/// Return true if this is an integer type with the specified width.
bool isInteger(unsigned width);
+ /// Return true if this is a signless integer type (with the specified width).
+ bool isSignlessInteger();
+ bool isSignlessInteger(unsigned width);
+ /// Return true if this is a signed integer type (with the specified width).
+ bool isSignedInteger();
+ bool isSignedInteger(unsigned width);
+ /// Return true if this is an unsigned integer type (with the specified
+ /// width).
+ bool isUnsignedInteger();
+ bool isUnsignedInteger(unsigned width);
/// Return the bit width of an integer or a float type, assert failure on
/// other types.
unsigned getIntOrFloatBitWidth();
- /// Return true if this is an integer or index type.
- bool isIntOrIndex();
- /// Return true if this is an integer, index, or float type.
- bool isIntOrIndexOrFloat();
- /// Return true of this is an integer or a float type.
- bool isIntOrFloat();
+ /// Return true if this is a signless integer or index type.
+ bool isSignlessIntOrIndex();
+ /// Return true if this is a signless integer, index, or float type.
+ bool isSignlessIntOrIndexOrFloat();
+ /// Return true if this is a signless integer or a float type.
+ bool isSignlessIntOrFloat();
/// Print the current type.
void print(raw_ostream &os);
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 5bcc02c16c83..e62e4eb360cc 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -314,7 +314,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
- if (elementType.isIntOrFloat()) {
+ if (elementType.isSignlessIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
@@ -358,7 +358,7 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
if (!memRefType.hasStaticShape())
return None;
auto elementType = memRefType.getElementType();
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+ if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>())
return None;
uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index df0a0535e7ff..bb7e18762a6d 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -45,7 +45,7 @@ struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
Value operand = operands.front();
// TODO(csigg): Generalize to other types of accumulation.
- assert(op->getOperand(0).getType().isIntOrFloat());
+ assert(op->getOperand(0).getType().isSignlessIntOrFloat());
// Create the reduction using an accumulator factory.
AccumulatorFactory factory =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f2048f1ba530..4d061e825163 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -289,7 +289,7 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = typeConverter.convertType(eltType);
- if (eltType.isInteger(32) || eltType.isInteger(64)) {
+ if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
@@ -989,9 +989,9 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
Type eltType = vectorType ? vectorType.getElementType() : printType;
int64_t rank = vectorType ? vectorType.getRank() : 0;
Operation *printer;
- if (eltType.isInteger(32))
+ if (eltType.isSignlessInteger(32))
printer = getPrintI32(op);
- else if (eltType.isInteger(64))
+ else if (eltType.isSignlessInteger(64))
printer = getPrintI64(op);
else if (eltType.isF32())
printer = getPrintFloat(op);
@@ -1111,7 +1111,7 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
auto loc = op.getLoc();
auto elemType = dstType.getElementType();
- assert(elemType.isIntOrIndexOrFloat());
+ assert(elemType.isSignlessIntOrIndexOrFloat());
Value zero = rewriter.create<ConstantOp>(loc, elemType,
rewriter.getZeroAttr(elemType));
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
diff --git a/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp b/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp
index 88f363302f84..301d3fded858 100644
--- a/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp
@@ -150,13 +150,13 @@ static ValueHandle createBinaryHandle(
(void)thatType;
if (thisType.isIndex()) {
return createBinaryIndexHandle(lhs, rhs, affCombiner);
- } else if (thisType.isa<IntegerType>()) {
+ } else if (thisType.isSignlessInteger()) {
return createBinaryHandle<IOp>(lhs, rhs);
} else if (thisType.isa<FloatType>()) {
return createBinaryHandle<FOp>(lhs, rhs);
} else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
auto aggregateType = thisType.cast<ShapedType>();
- if (aggregateType.getElementType().isa<IntegerType>())
+ if (aggregateType.getElementType().isSignlessInteger())
return createBinaryHandle<IOp>(lhs, rhs);
else if (aggregateType.getElementType().isa<FloatType>())
return createBinaryHandle<FOp>(lhs, rhs);
@@ -223,7 +223,7 @@ static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
(void)lhsType;
(void)rhsType;
assert(lhsType == rhsType && "cannot mix types in operators");
- assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) &&
+ assert((lhsType.isa<IndexType>() || lhsType.isSignlessInteger()) &&
"only integer comparisons are supported");
auto op = ScopedContext::getBuilder().create<CmpIOp>(
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
index 0fa359024fe8..1f52156533f2 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -172,7 +172,7 @@ inline Type castElementType(Type t, Type newElementType) {
.setElementType(newElementType);
}
}
- assert(t.isIntOrFloat());
+ assert(t.isSignlessIntOrFloat());
return newElementType;
}
@@ -180,13 +180,13 @@ inline Type castElementType(Type t, Type newElementType) {
/// be a scalar primitive or a shaped type).
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
if (auto st = t.dyn_cast<ShapedType>()) {
- assert(st.getElementType().isa<IntegerType>());
+ assert(st.getElementType().isSignlessInteger());
return DenseElementsAttr::get(st,
IntegerAttr::get(st.getElementType(), value));
}
auto integerType = t.cast<IntegerType>();
- assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
+ assert(t.isSignlessInteger() && "integer broadcast must be of integer type");
return IntegerAttr::get(integerType, value);
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc45cbe7fef..dc83cbc70471 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -158,7 +158,7 @@ static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
return shuffleOp.emitOpError()
<< "requires the same type for value operand and result";
}
- if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
+ if (!type.isSignlessIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
return shuffleOp.emitOpError()
<< "requires value operand type to be f32 or i32";
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a2fe01847edb..6720837cae7f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1022,7 +1022,7 @@ static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
interleave(
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
appendMangledType(ss, vec.getElementType());
- } else if (t.isIntOrIndexOrFloat()) {
+ } else if (t.isSignlessIntOrIndexOrFloat()) {
ss << t;
} else {
llvm_unreachable("Invalid type for linalg library name mangling");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index f226496e7fdd..e7a462d1a5df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -218,7 +218,7 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
OpBuilder b(op);
for (auto it : op.getInputsAndOutputBuffers())
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
- if (sv.getType().getElementType().isIntOrFloat())
+ if (sv.getType().getElementType().isSignlessIntOrFloat())
subViews.insert(sv);
if (!subViews.empty()) {
promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 29a4a679fb2f..d445a09a2f93 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -36,8 +36,8 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
Block &block = region.front();
if (block.getNumArguments() != 2 ||
- !block.getArgument(0).getType().isIntOrFloat() ||
- !block.getArgument(1).getType().isIntOrFloat())
+ !block.getArgument(0).getType().isSignlessIntOrFloat() ||
+ !block.getArgument(1).getType().isSignlessIntOrFloat())
return llvm::None;
auto &ops = block.getOperations();
@@ -97,7 +97,7 @@ mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
res.push_back(op->getOperand(numViews + i));
auto t = res.back().getType();
(void)t;
- assert((t.isIntOrIndexOrFloat() || t.isa<VectorType>()) &&
+ assert((t.isSignlessIntOrIndexOrFloat() || t.isa<VectorType>()) &&
"expected scalar or vector type");
}
return res;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6e97c3f58a66..89c243bc09ca 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -392,11 +392,11 @@ static unsigned getBitWidth(Type type) {
// TODO: Make sure no caller relies on the actual pointer width value.
return 64;
}
- if (type.isIntOrFloat()) {
+ if (type.isSignlessIntOrFloat()) {
return type.getIntOrFloatBitWidth();
}
if (auto vectorType = type.dyn_cast<VectorType>()) {
- assert(vectorType.getElementType().isIntOrFloat());
+ assert(vectorType.getElementType().isSignlessIntOrFloat());
return vectorType.getNumElements() *
vectorType.getElementType().getIntOrFloatBitWidth();
}
@@ -537,7 +537,7 @@ static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
auto elementType = ptrType.getPointeeType();
- if (!elementType.isa<IntegerType>())
+ if (!elementType.isSignlessInteger())
return op->emitOpError(
"pointer operand must point to an integer value, found ")
<< elementType;
@@ -1382,7 +1382,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
numElements *= t.getNumElements();
opElemType = t.getElementType();
}
- if (!opElemType.isIntOrFloat()) {
+ if (!opElemType.isSignlessIntOrFloat()) {
return constOp.emitOpError("only support nested array result type");
}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 23fef757ed1e..21d4e7fc03ea 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1451,7 +1451,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
+ if (resultType.isSignlessInteger() || resultType.isa<FloatType>() ||
resultType.isa<VectorType>()) {
auto attr = opBuilder.getZeroAttr(resultType);
// For normal constants, we just record the attribute (and its type) for
@@ -2202,7 +2202,7 @@ LogicalResult Deserializer::processBitcast(ArrayRef<uint32_t> words) {
<< wordIndex << " of " << words.size() << " processed";
}
if (resultTypes[0] == operands[0].getType() &&
- resultTypes[0].isa<IntegerType>()) {
+ resultTypes[0].isSignlessInteger()) {
// TODO(b/130356985): This check is added to ignore error in Op verification
// due to both signed and unsigned integers mapping to the same
// type. Without this check this method is same as what is auto-generated.
diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 5c44d60b548b..ee5dcb1f7f85 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -105,7 +105,7 @@ DictionaryAttr spirv::TargetEnvAttr::getResourceLimits() {
LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
Location loc, IntegerAttr version, ArrayAttr extensions,
ArrayAttr capabilities, DictionaryAttr limits) {
- if (!version.getType().isInteger(32))
+ if (!version.getType().isSignlessInteger(32))
return emitError(loc, "expected 32-bit integer for version");
if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 7318861938a4..6cea308bc584 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -540,7 +540,7 @@ void CallIndirectOp::getCanonicalizationPatterns(
// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getCheckedI1SameShape(Type type) {
auto i1Type = IntegerType::get(1, type.getContext());
- if (type.isIntOrIndexOrFloat())
+ if (type.isSignlessIntOrIndexOrFloat())
return i1Type;
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
@@ -1077,7 +1077,7 @@ bool ConstantFloatOp::classof(Operation *op) {
/// ConstantIntOp only matches values whose result type is a signless IntegerType.
bool ConstantIntOp::classof(Operation *op) {
return ConstantOp::classof(op) &&
- op->getResult(0).getType().isa<IntegerType>();
+ op->getResult(0).getType().isSignlessInteger();
}
void ConstantIntOp::build(Builder *builder, OperationState &result,
@@ -1091,7 +1091,8 @@ void ConstantIntOp::build(Builder *builder, OperationState &result,
/// which must be an integer type.
void ConstantIntOp::build(Builder *builder, OperationState &result,
int64_t value, Type type) {
- assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
+ assert(type.isSignlessInteger() &&
+ "ConstantIntOp can only have signless integer type");
ConstantOp::build(builder, result, type,
builder->getIntegerAttr(type, value));
}
@@ -1553,8 +1554,8 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
// Index cast is applicable from index to integer and backwards.
bool IndexCastOp::areCastCompatible(Type a, Type b) {
- return (a.isIndex() && b.isa<IntegerType>()) ||
- (a.isa<IntegerType>() && b.isIndex());
+ return (a.isIndex() && b.isSignlessInteger()) ||
+ (a.isSignlessInteger() && b.isIndex());
}
OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
@@ -1894,7 +1895,7 @@ static LogicalResult verify(ReturnOp op) {
// sitofp is applicable from integer types to float types.
bool SIToFPOp::areCastCompatible(Type a, Type b) {
- return a.isa<IntegerType>() && b.isa<FloatType>();
+ return a.isSignlessInteger() && b.isa<FloatType>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 27623f113f1e..454c96733ea7 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -74,13 +74,13 @@ static LogicalResult verify(ReductionOp op) {
auto kind = op.kind();
Type eltType = op.dest().getType();
if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
- if (eltType.isF32() || eltType.isF64() || eltType.isInteger(32) ||
- eltType.isInteger(64))
+ if (eltType.isF32() || eltType.isF64() || eltType.isSignlessInteger(32) ||
+ eltType.isSignlessInteger(64))
return success();
return op.emitOpError("unsupported reduction type");
}
if (kind == "and" || kind == "or" || kind == "xor") {
- if (eltType.isInteger(32) || eltType.isInteger(64))
+ if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64))
return success();
return op.emitOpError("unsupported reduction type");
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 05e3645151ea..140f533e0b15 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1258,12 +1258,14 @@ void ModulePrinter::printAttribute(Attribute attr,
break;
case StandardAttributes::Integer: {
auto intAttr = attr.cast<IntegerAttr>();
- // Print all integer attributes as signed unless i1.
- bool isSigned = attrType.isIndex() || attrType.getIntOrFloatBitWidth() != 1;
+ // Print all signed/signless integer attributes as signed unless i1.
+ bool isSigned =
+ attrType.isIndex() || (!attrType.isUnsignedInteger() &&
+ attrType.getIntOrFloatBitWidth() != 1);
intAttr.getValue().print(os, isSigned);
// IntegerAttr elides the type if I64.
- if (typeElision == AttrTypeElision::May && attrType.isInteger(64))
+ if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
return;
break;
}
@@ -1495,6 +1497,10 @@ void ModulePrinter::printType(Type type) {
case StandardTypes::Integer: {
auto integer = type.cast<IntegerType>();
+ if (integer.isSigned())
+ os << 's';
+ else if (integer.isUnsigned())
+ os << 'u';
os << 'i' << integer.getWidth();
return;
}
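The two printer changes above can be restated without MLIR. First, an integer type is spelled with an optional `s`/`u` prefix before `i<width>`. Second, an integer attribute value is printed as signed for `index` and for every integer type except unsigned types and `i1`. In the sketch below, `printIntegerType` and `printsAsSigned` are hypothetical helper names.

```cpp
#include <cassert>
#include <string>

enum class Signedness { Signless, Signed, Unsigned };

// Spelling used by the updated type printer: optional 's'/'u' prefix,
// then 'i' and the width (si32, ui8, i64, ...).
std::string printIntegerType(unsigned width, Signedness s) {
  std::string out;
  if (s == Signedness::Signed)
    out += 's';
  else if (s == Signedness::Unsigned)
    out += 'u';
  return out + 'i' + std::to_string(width);
}

// Mirrors the updated attribute printer's decision: index is always
// printed as signed; integers are signed unless unsigned or 1 bit wide.
bool printsAsSigned(bool isIndex, unsigned width, Signedness s) {
  return isIndex || (s != Signedness::Unsigned && width != 1);
}
```

In the attribute printer, the type annotation is now elided only for signless `i64`. A value such as `1 : si64` therefore keeps its type suffix, because it cannot be confused with the default `i64`.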
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 5691184fefda..5beb12a59940 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -273,12 +273,28 @@ IntegerAttr IntegerAttr::get(Type type, int64_t value) {
return get(type, APInt(64, value));
auto intType = type.cast<IntegerType>();
- return get(type, APInt(intType.getWidth(), value));
+ return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
}
APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
-int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
+int64_t IntegerAttr::getInt() const {
+ assert((getImpl()->getType().isIndex() ||
+ getImpl()->getType().isSignlessInteger()) &&
+ "must be signless integer");
+ return getValue().getSExtValue();
+}
+
+int64_t IntegerAttr::getSInt() const {
+ assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
+ return getValue().getSExtValue();
+}
+
+uint64_t IntegerAttr::getUInt() const {
+ assert(getImpl()->getType().isUnsignedInteger() &&
+ "must be unsigned integer");
+ return getValue().getZExtValue();
+}
static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
if (type.isa<IntegerType>() || type.isa<IndexType>())
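The new `IntegerAttr` accessors split value extraction by signedness. `getInt()` and `getSInt()` sign-extend the stored bits, while `getUInt()` zero-extends them. Each accessor asserts that the attribute's type matches. The real `getInt()` also accepts `index`. The toy below models an attribute of at most 64 bits as a raw bit pattern; `ToyIntegerAttr` is hypothetical and does not use `APInt`.

```cpp
#include <cassert>
#include <cstdint>

enum class Signedness { Signless, Signed, Unsigned };

// Hypothetical model of an integer attribute: the low `width` bits of
// `bits` hold the value (1 <= width <= 64 in this toy).
struct ToyIntegerAttr {
  uint64_t bits;
  unsigned width;
  Signedness signedness;

  uint64_t masked() const {
    return width == 64 ? bits : bits & ((uint64_t(1) << width) - 1);
  }
  // Sign extension from `width` bits, like APInt::getSExtValue().
  int64_t sext() const {
    uint64_t signBit = uint64_t(1) << (width - 1);
    return static_cast<int64_t>((masked() ^ signBit) - signBit);
  }

  // Mirrors the new accessors: each one asserts the matching signedness.
  int64_t getInt() const {
    assert(signedness == Signedness::Signless && "must be signless integer");
    return sext();
  }
  int64_t getSInt() const {
    assert(signedness == Signedness::Signed && "must be signed integer");
    return sext();
  }
  uint64_t getUInt() const {
    assert(signedness == Signedness::Unsigned && "must be unsigned integer");
    return masked();  // zero extension, like APInt::getZExtValue()
  }
};
```

The same 8-bit pattern `0xFF` reads as `-1` through `getInt()` on an `i8` attribute and as `255` through `getUInt()` on a `ui8` attribute.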
@@ -592,7 +608,7 @@ DenseElementsAttr::FloatElementIterator::FloatElementIterator(
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
- assert(type.getElementType().isIntOrFloat() &&
+ assert(type.getElementType().isSignlessIntOrFloat() &&
"expected int or float element type");
assert(hasSameElementsOrSplat(type, values));
@@ -700,19 +716,28 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
data, isSplat);
}
-/// Check the information for a c++ data type, check if this type is valid for
+/// Given the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
-static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
- bool isInt) {
+static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
+ bool isSigned) {
// Make sure that the data element size is the same as the type element width.
if (getDenseElementBitwidth(type.getElementType()) !=
static_cast<size_t>(dataEltSize * CHAR_BIT))
return false;
- // Check that the element type is valid.
- return isInt ? type.getElementType().isa<IntegerType>()
- : type.getElementType().isa<FloatType>();
+ // Check that the element type is either float or integer.
+ if (!isInt)
+ return type.getElementType().isa<FloatType>();
+
+ auto intType = type.getElementType().dyn_cast<IntegerType>();
+ if (!intType)
+ return false;
+
+ // Signless integer element types accept both signed and unsigned data;
+ // otherwise the signedness of the data must match the element type.
+ if (intType.isSignless())
+ return true;
+ return intType.isSigned() ? isSigned : !isSigned;
}
/// Overload of the 'getRaw' method that asserts that the given type is of
@@ -721,8 +746,9 @@ static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
int64_t dataEltSize,
- bool isInt) {
- assert(::isValidIntOrFloat(type, dataEltSize, isInt));
+ bool isInt,
+ bool isSigned) {
+ assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
int64_t numElements = data.size() / dataEltSize;
assert(numElements == 1 || numElements == type.getNumElements());
@@ -731,9 +757,9 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
/// A method used to verify specific type invariants that the templatized 'get'
/// method cannot.
-bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
- bool isInt) const {
- return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
+bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
+ bool isSigned) const {
+ return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned);
}
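The updated `isValidIntOrFloat` adds a signedness rule for integer data. A signless element type accepts C++ data of either signedness. A signed or unsigned element type requires data of the same signedness. The standalone sketch below restates that logic over a hypothetical `ElementType` record. `isValidFor<T>` is an invented convenience wrapper. The diff does not show how the templated callers compute `isSigned`, so deriving it from `std::is_signed` is an assumption.

```cpp
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <type_traits>

enum class ElementKind { Float, Integer, Index };
enum class Signedness { Signless, Signed, Unsigned };

// Hypothetical description of a shaped type's element type.
struct ElementType {
  ElementKind kind;
  unsigned bitwidth;
  Signedness signedness;
};

// Same structure as the updated check: sizes must match, float data
// needs a float element type, and integer data must agree in signedness
// unless the element type is signless.
bool isValidIntOrFloat(ElementType elt, std::size_t dataEltSize, bool isInt,
                       bool isSigned) {
  if (elt.bitwidth != dataEltSize * CHAR_BIT)
    return false;
  if (!isInt)
    return elt.kind == ElementKind::Float;
  if (elt.kind != ElementKind::Integer)
    return false;
  if (elt.signedness == Signedness::Signless)
    return true;
  return (elt.signedness == Signedness::Signed) == isSigned;
}

// Hypothetical wrapper deriving the flags from a C++ element type T.
template <typename T>
bool isValidFor(ElementType elt) {
  return isValidIntOrFloat(elt, sizeof(T), std::is_integral<T>::value,
                           std::is_signed<T>::value);
}
```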
/// Returns if this attribute corresponds to a splat, i.e. if all element
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index e406d79da1a2..b208710ac247 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -59,6 +59,11 @@ IntegerType Builder::getIntegerType(unsigned width) {
return IntegerType::get(width, context);
}
+IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
+ return IntegerType::get(
+ width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
+}
+
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
ArrayRef<Type> results) {
return FunctionType::get(inputs, results, context);
@@ -104,6 +109,16 @@ IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
}
+IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
+ return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
+ APInt(32, value, /*isSigned=*/true));
+}
+
+IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
+ return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
+ APInt(32, value, /*isSigned=*/false));
+}
+
IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
return IntegerAttr::get(getIntegerType(16), APInt(16, value));
}
@@ -115,7 +130,8 @@ IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
if (type.isIndex())
return IntegerAttr::get(type, APInt(64, value));
- return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value));
+ return IntegerAttr::get(
+ type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
}
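Several builders now pass `isSignedInteger()` as the third argument of the `APInt` constructor. As documented for `llvm::APInt(numBits, val, isSigned)`, that flag only matters for bit widths above 64. For those widths it decides whether the upper words are filled with copies of the sign bit. For 64 bits or fewer the value is simply truncated. The sketch below imitates that rule for a 128-bit value stored as two `uint64_t` words. `widenTo128` is a hypothetical helper and does not use LLVM.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Toy model of widening a 64-bit value into a 128-bit integer stored as
// {low word, high word}. Like APInt(128, val, isSigned), the high word is
// filled with the sign bit only when isSigned is true.
std::array<uint64_t, 2> widenTo128(int64_t value, bool isSigned) {
  uint64_t low = static_cast<uint64_t>(value);
  uint64_t high = (isSigned && value < 0) ? ~uint64_t(0) : 0;
  return {low, high};
}
```

Signless and unsigned types pass `false`, just as the old code did through the constructor's default. Only signed types such as `si128` change behavior: `-1` becomes all ones rather than `2^64 - 1`.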
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 311fcc35ebdd..d38fdb00cd17 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -245,16 +245,18 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
/// Index Type.
impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
/// Integer Types.
- impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1);
- impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8);
- impl->int16Ty =
- TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 16);
- impl->int32Ty =
- TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 32);
- impl->int64Ty =
- TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 64);
- impl->int128Ty =
- TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 128);
+ impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1,
+ IntegerType::Signless);
+ impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8,
+ IntegerType::Signless);
+ impl->int16Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
+ 16, IntegerType::Signless);
+ impl->int32Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
+ 32, IntegerType::Signless);
+ impl->int64Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
+ 64, IntegerType::Signless);
+ impl->int128Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
+ 128, IntegerType::Signless);
/// None Type.
impl->noneType = TypeUniquer::get<NoneType>(this, StandardTypes::None);
@@ -489,7 +491,13 @@ IndexType IndexType::get(MLIRContext *context) {
/// Return an existing integer type instance if one is cached within the
/// context.
-static IntegerType getCachedIntegerType(unsigned width, MLIRContext *context) {
+static IntegerType
+getCachedIntegerType(unsigned width,
+ IntegerType::SignednessSemantics signedness,
+ MLIRContext *context) {
+ if (signedness != IntegerType::Signless)
+ return IntegerType();
+
switch (width) {
case 1:
return context->getImpl().int1Ty;
@@ -509,16 +517,28 @@ static IntegerType getCachedIntegerType(unsigned width, MLIRContext *context) {
}
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
- if (auto cached = getCachedIntegerType(width, context))
+ return get(width, IntegerType::Signless, context);
+}
+
+IntegerType IntegerType::get(unsigned width,
+ IntegerType::SignednessSemantics signedness,
+ MLIRContext *context) {
+ if (auto cached = getCachedIntegerType(width, signedness, context))
return cached;
- return Base::get(context, StandardTypes::Integer, width);
+ return Base::get(context, StandardTypes::Integer, width, signedness);
+}
+
+IntegerType IntegerType::getChecked(unsigned width, Location location) {
+ return getChecked(width, IntegerType::Signless, location);
}
-IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
+IntegerType IntegerType::getChecked(unsigned width,
+ SignednessSemantics signedness,
Location location) {
- if (auto cached = getCachedIntegerType(width, context))
+ if (auto cached =
+ getCachedIntegerType(width, signedness, location->getContext()))
return cached;
- return Base::getChecked(location, StandardTypes::Integer, width);
+ return Base::getChecked(location, StandardTypes::Integer, width, signedness);
}
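After this change, `getCachedIntegerType` serves only signless types from the context's pre-built fields. Signed and unsigned types always go through the general uniquer, keyed on width and signedness. The sketch below models that two-level lookup with standard containers. `ToyContext` and `IntTypeStorage` here are hypothetical and much simpler than MLIR's `TypeUniquer`.

```cpp
#include <cassert>
#include <initializer_list>
#include <map>
#include <memory>
#include <utility>

enum class Signedness { Signless, Signed, Unsigned };

struct IntTypeStorage {
  unsigned width;
  Signedness signedness;
};

// Hypothetical uniquer loosely following MLIRContext: common signless
// widths are created eagerly and looked up first; every other
// (width, signedness) pair is created on demand and uniqued by that key.
class ToyContext {
public:
  ToyContext() {
    for (unsigned w : {1u, 8u, 16u, 32u, 64u, 128u})
      cachedSignless[w] = intern(w, Signedness::Signless);
  }

  const IntTypeStorage *getIntegerType(unsigned width, Signedness s) {
    // Fast path only for signless types, as in getCachedIntegerType.
    if (s == Signedness::Signless) {
      auto it = cachedSignless.find(width);
      if (it != cachedSignless.end())
        return it->second;
    }
    return intern(width, s);
  }

private:
  const IntTypeStorage *intern(unsigned width, Signedness s) {
    auto key = std::make_pair(width, s);
    std::unique_ptr<IntTypeStorage> &slot = uniqued[key];
    if (!slot)
      slot = std::make_unique<IntTypeStorage>(IntTypeStorage{width, s});
    return slot.get();
  }

  std::map<std::pair<unsigned, Signedness>, std::unique_ptr<IntTypeStorage>>
      uniqued;
  std::map<unsigned, const IntTypeStorage *> cachedSignless;
};
```

Either path returns one pointer per distinct `(width, signedness)` pair. This is the property that lets `i32`, `si32`, and `ui32` compare unequal by identity.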
/// Get an instance of the NoneType.
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 6b19c2603f5b..49185eb159dd 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -792,10 +792,11 @@ static Type getTensorOrVectorElementType(Type type) {
return type;
}
-LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
+LogicalResult
+OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
for (auto opType : op->getOperandTypes()) {
auto type = getTensorOrVectorElementType(opType);
- if (!type.isIntOrIndex())
+ if (!type.isSignlessIntOrIndex())
return op->emitOpError() << "requires an integer or index type";
}
return success();
@@ -1006,9 +1007,10 @@ LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
return success();
}
-LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
+LogicalResult
+OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) {
for (auto resultType : op->getResultTypes())
- if (!getTensorOrVectorElementType(resultType).isIntOrIndex())
+ if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex())
return op->emitOpError() << "requires an integer or index type";
return success();
}
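The renamed trait verifiers inspect the scalar element type of each operand or result, looking through vector and tensor wrappers. They then require that element type to be `index` or a signless integer. The sketch below captures that shape with a hypothetical recursive `ToyType`. The real `getTensorOrVectorElementType` body is not shown in this diff, so the unwrapping rule used here is an approximation.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

enum class Signedness { Signless, Signed, Unsigned };

// Hypothetical type model: a scalar integer/index/float, or a
// vector/tensor wrapper around an element type.
struct ToyType {
  enum class Kind { Integer, Index, Float, Vector, Tensor } kind;
  Signedness signedness = Signedness::Signless;  // scalar integers only
  std::shared_ptr<ToyType> element;              // vectors/tensors only
};

// Look through vector/tensor wrappers to the scalar element type.
const ToyType &elementTypeOf(const ToyType &t) {
  bool shaped =
      t.kind == ToyType::Kind::Vector || t.kind == ToyType::Kind::Tensor;
  return shaped && t.element ? elementTypeOf(*t.element) : t;
}

bool isSignlessIntOrIndex(const ToyType &t) {
  return t.kind == ToyType::Kind::Index ||
         (t.kind == ToyType::Kind::Integer &&
          t.signedness == Signedness::Signless);
}

// Returns an empty string on success, otherwise the diagnostic text.
std::string
verifyOperandsAreSignlessIntegerLike(const std::vector<ToyType> &types) {
  for (const ToyType &t : types)
    if (!isSignlessIntOrIndex(elementTypeOf(t)))
      return "requires an integer or index type";
  return "";
}
```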
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index cb9386febde8..30d5bbcc7b3c 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -36,13 +36,53 @@ bool Type::isInteger(unsigned width) {
return false;
}
-bool Type::isIntOrIndex() { return isa<IndexType>() || isa<IntegerType>(); }
+bool Type::isSignlessInteger() {
+ if (auto intTy = dyn_cast<IntegerType>())
+ return intTy.isSignless();
+ return false;
+}
-bool Type::isIntOrIndexOrFloat() {
- return isa<IndexType>() || isa<IntegerType>() || isa<FloatType>();
+bool Type::isSignlessInteger(unsigned width) {
+ if (auto intTy = dyn_cast<IntegerType>())
+ return intTy.isSignless() && intTy.getWidth() == width;
+ return false;
}
-bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+bool Type::isSignedInteger() {
+ if (auto intTy = dyn_cast<IntegerType>())
+ return intTy.isSigned();
+ return false;
+}
+
+bool Type::isSignedInteger(unsigned width) {
+ if (auto intTy = dyn_cast<IntegerType>())
+ return intTy.isSigned() && intTy.getWidth() == width;
+ return false;
+}
+
+bool Type::isUnsignedInteger() {
+ if (auto intTy = dyn_cast<IntegerType>())
+ return intTy.isUnsigned();
+ return false;
+}
+
+bool Type::isUnsignedInteger(unsigned width) {
+ if (auto intTy = dyn_cast<IntegerType>())
+ return intTy.isUnsigned() && intTy.getWidth() == width;
+ return false;
+}
+
+bool Type::isSignlessIntOrIndex() {
+ return isa<IndexType>() || isSignlessInteger();
+}
+
+bool Type::isSignlessIntOrIndexOrFloat() {
+ return isa<IndexType>() || isSignlessInteger() || isa<FloatType>();
+}
+
+bool Type::isSignlessIntOrFloat() {
+ return isSignlessInteger() || isa<FloatType>();
+}
//===----------------------------------------------------------------------===//
// Integer Type
@@ -52,16 +92,23 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
constexpr unsigned IntegerType::kMaxWidth;
/// Verify the construction of an integer type.
-LogicalResult IntegerType::verifyConstructionInvariants(Location loc,
- unsigned width) {
+LogicalResult
+IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
+ SignednessSemantics signedness) {
if (width > IntegerType::kMaxWidth) {
return emitError(loc) << "integer bitwidth is limited to "
<< IntegerType::kMaxWidth << " bits";
}
+ if (width == 1 && signedness != IntegerType::Signless)
+ return emitOptionalError(loc, "cannot have signedness semantics for i1");
return success();
}
-unsigned IntegerType::getWidth() const { return getImpl()->width; }
+unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); }
+
+IntegerType::SignednessSemantics IntegerType::getSignedness() const {
+ return getImpl()->getSignedness();
+}
//===----------------------------------------------------------------------===//
// Float Type
@@ -100,7 +147,7 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
}
unsigned Type::getIntOrFloatBitWidth() {
- assert(isIntOrFloat() && "only ints and floats have a bitwidth");
+ assert((isa<IntegerType>() || isa<FloatType>()) &&
+ "only ints and floats have a bitwidth");
if (auto intType = dyn_cast<IntegerType>()) {
return intType.getWidth();
}
@@ -155,7 +202,7 @@ int64_t ShapedType::getSizeInBits() const {
"cannot get the bit size of an aggregate with a dynamic shape");
auto elementType = getElementType();
- if (elementType.isIntOrFloat())
+ if (elementType.isSignlessIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements();
// Tensors can have vectors and other tensors as elements, other shaped types
@@ -326,7 +373,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
@@ -404,7 +451,7 @@ LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitError(loc, "invalid memref element type");
return success();
@@ -619,7 +666,7 @@ ComplexType ComplexType::getChecked(Type elementType, Location location) {
/// Verify the construction of a complex type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
Type elementType) {
- if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
+ if (!elementType.isa<FloatType>() && !elementType.isSignlessInteger())
return emitError(loc, "invalid element type for complex");
return success();
}
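`IntegerType::verifyConstructionInvariants` now applies two checks: a maximum width and a ban on signed or unsigned 1-bit integers. The sketch below restates both checks as a plain function that returns an error string. `kToyMaxWidth` is a hypothetical cap chosen only for this sketch. MLIR's actual `IntegerType::kMaxWidth` value is not shown in this diff.

```cpp
#include <cassert>
#include <string>

enum class Signedness { Signless, Signed, Unsigned };

// Hypothetical width cap for this sketch only.
constexpr unsigned kToyMaxWidth = 4096;

// Returns an empty string if (width, signedness) is a valid integer type,
// otherwise an error message, following the two checks in the diff.
std::string verifyIntegerType(unsigned width, Signedness signedness) {
  if (width > kToyMaxWidth)
    return "integer bitwidth is limited to " + std::to_string(kToyMaxWidth) +
           " bits";
  if (width == 1 && signedness != Signedness::Signless)
    return "cannot have signedness semantics for i1";
  return "";
}
```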
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 84bd705aa372..72f1585be2d0 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -15,8 +15,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/TypeSupport.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/bit.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@@ -52,19 +52,49 @@ struct OpaqueTypeStorage : public TypeStorage {
/// Integer Type Storage and Uniquing.
struct IntegerTypeStorage : public TypeStorage {
- IntegerTypeStorage(unsigned width) : width(width) {}
+ IntegerTypeStorage(unsigned width,
+ IntegerType::SignednessSemantics signedness)
+ : TypeStorage(packKeyBits(width, signedness)) {}
/// The hash key used for uniquing.
- using KeyTy = unsigned;
- bool operator==(const KeyTy &key) const { return key == width; }
+ using KeyTy = std::pair<unsigned, IntegerType::SignednessSemantics>;
+
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ return llvm::hash_value(packKeyBits(key.first, key.second));
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return getSubclassData() == packKeyBits(key.first, key.second);
+ }
static IntegerTypeStorage *construct(TypeStorageAllocator &allocator,
- KeyTy bitwidth) {
+ KeyTy key) {
return new (allocator.allocate<IntegerTypeStorage>())
- IntegerTypeStorage(bitwidth);
+ IntegerTypeStorage(key.first, key.second);
}
- unsigned width;
+ struct KeyBits {
+ unsigned width : 30;
+ unsigned signedness : 2;
+ };
+
+ /// Pack the given `width` and `signedness` as a key.
+ static unsigned packKeyBits(unsigned width,
+ IntegerType::SignednessSemantics signedness) {
+ KeyBits bits{width, static_cast<unsigned>(signedness)};
+ return llvm::bit_cast<unsigned>(bits);
+ }
+
+ static KeyBits unpackKeyBits(unsigned bits) {
+ return llvm::bit_cast<KeyBits>(bits);
+ }
+
+ unsigned getWidth() { return unpackKeyBits(getSubclassData()).width; }
+
+ IntegerType::SignednessSemantics getSignedness() {
+ return static_cast<IntegerType::SignednessSemantics>(
+ unpackKeyBits(getSubclassData()).signedness);
+ }
};
/// Function Type Storage and Uniquing.
@@ -321,4 +351,5 @@ struct TupleTypeStorage final
} // namespace detail
} // namespace mlir
+
#endif // TYPEDETAIL_H_
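`IntegerTypeStorage` now uniques on a `(width, signedness)` pair, but both values are stored in the 32-bit subclass data word: 30 bits of width and 2 bits of signedness. Two types with the same width and different signedness therefore get different keys and are distinct uniqued types. The diff packs the word with a bitfield struct and `llvm::bit_cast`, so the exact bit positions are chosen by the compiler.

The sketch below uses explicit shifts and masks instead. Its layout (width in the low 30 bits, signedness in the top 2) and the enumerator numbering are assumptions of the sketch, not taken from the diff.

```cpp
#include <cassert>
#include <cstdint>

// Assumed numbering; only needs to fit in 2 bits.
enum class Signedness : uint32_t { Signless = 0, Signed = 1, Unsigned = 2 };

constexpr uint32_t kWidthBits = 30;
constexpr uint32_t kWidthMask = (uint32_t(1) << kWidthBits) - 1;

// Pack width (low 30 bits) and signedness (high 2 bits) into one word.
uint32_t packKeyBits(uint32_t width, Signedness s) {
  assert(width <= kWidthMask && "width does not fit in 30 bits");
  return (width & kWidthMask) | (static_cast<uint32_t>(s) << kWidthBits);
}

uint32_t unpackWidth(uint32_t bits) { return bits & kWidthMask; }

Signedness unpackSignedness(uint32_t bits) {
  return static_cast<Signedness>(bits >> kWidthBits);
}

// Key equality as the storage's operator== would see it: compare packed words.
bool sameKey(uint32_t w1, Signedness s1, uint32_t w2, Signedness s2) {
  return packKeyBits(w1, s1) == packKeyBits(w2, s2);
}
```

Because packing is a bijection on valid inputs, hashing and comparing the packed word is equivalent to hashing and comparing the pair. This is what lets `hashKey` and `operator==` in the diff both work on `packKeyBits`.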
diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp
index 2ac74ed9f667..ffc8a0ec4846 100644
--- a/mlir/lib/Parser/Lexer.cpp
+++ b/mlir/lib/Parser/Lexer.cpp
@@ -187,7 +187,7 @@ Token Lexer::lexAtIdentifier(const char *tokStart) {
/// Lex a bare identifier or keyword that starts with a letter.
///
/// bare-id ::= (letter|[_]) (letter|digit|[_$.])*
-/// integer-type ::= `i[1-9][0-9]*`
+/// integer-type ::= `[su]?i[1-9][0-9]*`
///
Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
// Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
@@ -198,14 +198,17 @@ Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
// Check to see if this identifier is a keyword.
StringRef spelling(tokStart, curPtr - tokStart);
- // Check for i123.
- if (tokStart[0] == 'i') {
- bool allDigits = true;
- for (auto c : spelling.drop_front())
- allDigits &= isdigit(c) != 0;
- if (allDigits && spelling.size() != 1)
- return Token(Token::inttype, spelling);
- }
+ auto isAllDigit = [](StringRef str) {
+ return llvm::all_of(str, [](char c) { return llvm::isDigit(c); });
+ };
+
+ // Check for i123, si456, ui789.
+ if ((spelling.size() > 1 && tokStart[0] == 'i' &&
+ isAllDigit(spelling.drop_front())) ||
+ ((spelling.size() > 2 && tokStart[1] == 'i' &&
+ (tokStart[0] == 's' || tokStart[0] == 'u')) &&
+ isAllDigit(spelling.drop_front(2))))
+ return Token(Token::inttype, spelling);
Token::Kind kind = llvm::StringSwitch<Token::Kind>(spelling)
#define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
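The lexer change classifies a bare identifier as an `inttype` token when it is `i` followed by digits, or `si`/`ui` followed by digits. The lexer does not reject leading zeros. `i0` and `i01` still lex as `inttype`, and `Token::getIntTypeBitwidth` rejects them afterwards. The function below is a hypothetical standalone restatement of that condition over `std::string_view`, not the lexer's actual code.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string_view>

static bool isAllDigit(std::string_view s) {
  return std::all_of(s.begin(), s.end(), [](char c) {
    return std::isdigit(static_cast<unsigned char>(c)) != 0;
  });
}

// True if `spelling` would be lexed as an inttype token: i123, si456, ui789.
bool looksLikeIntType(std::string_view spelling) {
  if (spelling.size() > 1 && spelling[0] == 'i')
    return isAllDigit(spelling.substr(1));
  if (spelling.size() > 2 && spelling[1] == 'i' &&
      (spelling[0] == 's' || spelling[0] == 'u'))
    return isAllDigit(spelling.substr(2));
  return false;
}
```

Keywords such as `index` fail the all-digits test and fall through to the keyword `StringSwitch`. Bare `i`, `si` and `ui` are too short to qualify.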
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 9962fbf8c055..bfee3339a1db 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1101,7 +1101,7 @@ Type Parser::parseMemRefType() {
return nullptr;
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
+ if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitError(typeLoc, "invalid memref element type"), nullptr;
@@ -1217,8 +1217,13 @@ Type Parser::parseNonFunctionType() {
return nullptr;
}
+ IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
+ if (Optional<bool> signedness = getToken().getIntTypeSignedness())
+ signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
+
+ auto loc = getEncodedSourceLocation(getToken().getLoc());
consumeToken(Token::inttype);
- return IntegerType::get(width.getValue(), builder.getContext());
+ return IntegerType::getChecked(width.getValue(), signSemantics, loc);
}
// float-type
@@ -1789,10 +1794,16 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
}
- if (!type.isIntOrIndex())
+ if (!type.isa<IntegerType>() && !type.isa<IndexType>())
return emitError(loc, "integer literal not valid for specified type"),
nullptr;
+ if (isNegative && type.isUnsignedInteger()) {
+ emitError(loc,
+ "negative integer literal not valid for unsigned integer type");
+ return nullptr;
+ }
+
// Parse the integer literal.
int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
APInt apInt(width, *val, isNegative);
@@ -1989,13 +2000,22 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
IntegerType eltTy) {
std::vector<APInt> intElements;
intElements.reserve(storage.size());
+ auto isUintType = type.getElementType().isUnsignedInteger();
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
+ auto tokenLoc = token.getLoc();
+
+ if (isNegative && isUintType) {
+ p.emitError(tokenLoc)
+ << "expected unsigned integer elements, but parsed negative value";
+ return nullptr;
+ }
// Check to see if floating point values were parsed.
if (token.is(Token::floatliteral)) {
- p.emitError() << "expected integer elements, but parsed floating-point";
+ p.emitError(tokenLoc)
+ << "expected integer elements, but parsed floating-point";
return nullptr;
}
@@ -2003,7 +2023,8 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
"unexpected token type");
if (token.isAny(Token::kw_true, Token::kw_false)) {
if (!eltTy.isInteger(1))
- p.emitError() << "expected i1 type for 'true' or 'false' values";
+ p.emitError(tokenLoc)
+ << "expected i1 type for 'true' or 'false' values";
APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
/*isSigned=*/false);
intElements.push_back(apInt);
@@ -2014,13 +2035,13 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
auto val = token.getUInt64IntegerValue();
if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0
: (int64_t)val.getValue() < 0)) {
- p.emitError(token.getLoc(),
- "integer constant out of range for attribute");
+ p.emitError(tokenLoc, "integer constant out of range for attribute");
return nullptr;
}
APInt apInt(eltTy.getWidth(), val.getValue(), isNegative);
if (apInt != val.getValue())
- return (p.emitError("integer constant out of range for type"), nullptr);
+ return (p.emitError(tokenLoc, "integer constant out of range for type"),
+ nullptr);
intElements.push_back(isNegative ? -apInt : apInt);
}
@@ -2085,7 +2106,7 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
ShapedType type) {
Type elementType = type.getElementType();
- if (!elementType.isIntOrFloat()) {
+ if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
p.emitError(loc) << "expected floating-point or integer element type, got "
<< elementType;
return nullptr;
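In the parser, the attribute-literal checks now accept any `IntegerType`, whatever its signedness, plus `index`. A leading `-` is then rejected when the literal's type is unsigned. Dense element literals get the same check per element, and each diagnostic is reported at the offending token's own location instead of a default location.

The sketch below models those two checks with hypothetical types (`ParsedInt`, `Diagnostic`), where `loc` is a plain offset standing in for a source location. It is an illustration, not MLIR's parser API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

enum class Signedness { Signless, Signed, Unsigned };

struct ParsedInt {
  bool isNegative;     // a leading '-' was consumed before the literal
  uint64_t magnitude;  // absolute value as lexed
  size_t loc;          // hypothetical stand-in for the token's location
};

struct Diagnostic {
  size_t loc;
  std::string message;
};

// Scalar attribute such as `-5 : ui32`.
std::optional<Diagnostic> checkScalar(const ParsedInt &v, Signedness s) {
  if (v.isNegative && s == Signedness::Unsigned)
    return Diagnostic{
        v.loc, "negative integer literal not valid for unsigned integer type"};
  return std::nullopt;
}

// Dense literal such as `dense<[5, -5]> : vector<2xui32>`; reports the first
// offending element at that element's own location.
std::optional<Diagnostic> checkDenseElements(const std::vector<ParsedInt> &elems,
                                             Signedness s) {
  for (const ParsedInt &v : elems)
    if (v.isNegative && s == Signedness::Unsigned)
      return Diagnostic{
          v.loc, "expected unsigned integer elements, but parsed negative value"};
  return std::nullopt;
}
```

Reporting at the element's location is what makes the `expected-error` in the new `invalid.mlir` tests point at the `-5` rather than at the start of the literal.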
diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp
index 0dbb5270dca4..8fe16b05fde6 100644
--- a/mlir/lib/Parser/Token.cpp
+++ b/mlir/lib/Parser/Token.cpp
@@ -57,13 +57,26 @@ Optional<double> Token::getFloatingPointValue() const {
/// For an inttype token, return its bitwidth.
Optional<unsigned> Token::getIntTypeBitwidth() const {
+ assert(getKind() == inttype);
+ unsigned bitwidthStart = (spelling[0] == 'i' ? 1 : 2);
unsigned result = 0;
- if (spelling[1] == '0' || spelling.drop_front().getAsInteger(10, result) ||
+ if (spelling[bitwidthStart] == '0' ||
+ spelling.drop_front(bitwidthStart).getAsInteger(10, result) ||
result == 0)
return None;
return result;
}
+Optional<bool> Token::getIntTypeSignedness() const {
+ assert(getKind() == inttype);
+ if (spelling[0] == 'i')
+ return llvm::None;
+ if (spelling[0] == 's')
+ return true;
+ assert(spelling[0] == 'u');
+ return false;
+}
+
/// Given a token containing a string literal, return its value, including
/// removing the quote characters and unescaping the contents of the string. The
/// lexer has already verified that this token is valid.
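`getIntTypeBitwidth` now skips a one- or two-character prefix before reading the width. The new `getIntTypeSignedness` returns a tri-state `Optional<bool>`: none means signless, `true` means signed, `false` means unsigned.

The sketch below restates both functions with `std::optional` and `std::string_view`, plus the mapping the parser applies before `IntegerType::getChecked`. Names are hypothetical, and the precondition, as in the diff's asserts, is a spelling the lexer already classified as `inttype`.

```cpp
#include <cassert>
#include <limits>
#include <optional>
#include <string_view>

std::optional<unsigned> intTypeBitwidth(std::string_view spelling) {
  std::string_view digits = spelling.substr(spelling[0] == 'i' ? 1 : 2);
  // A leading zero (including the width "0" itself) is rejected.
  if (digits.empty() || digits[0] == '0')
    return std::nullopt;
  unsigned long long result = 0;
  for (char c : digits) {
    if (c < '0' || c > '9')
      return std::nullopt;
    result = result * 10 + static_cast<unsigned>(c - '0');
    if (result > std::numeric_limits<unsigned>::max())
      return std::nullopt;  // does not fit in `unsigned`
  }
  return static_cast<unsigned>(result);
}

// std::nullopt: signless (`i`); true: signed (`si`); false: unsigned (`ui`).
std::optional<bool> intTypeSignedness(std::string_view spelling) {
  if (spelling[0] == 'i')
    return std::nullopt;
  if (spelling[0] == 's')
    return true;
  assert(spelling[0] == 'u');
  return false;
}

enum class Signedness { Signless, Signed, Unsigned };

Signedness toSemantics(std::optional<bool> signedness) {
  if (!signedness)  // tests presence of a value, not the stored bool
    return Signedness::Signless;
  return *signedness ? Signedness::Signed : Signedness::Unsigned;
}
```

The tri-state needs care at call sites. `if (Optional<bool> s = ...)` in the parser enters the branch for `ui*` too, because the condition tests whether a value is present, not whether that value is `true`.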
diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h
index 95db87c9cd42..e6fa6c70853f 100644
--- a/mlir/lib/Parser/Token.h
+++ b/mlir/lib/Parser/Token.h
@@ -74,6 +74,11 @@ class Token {
/// For an inttype token, return its bitwidth.
Optional<unsigned> getIntTypeBitwidth() const;
+ /// For an inttype token, return its signedness semantics: llvm::None means no
+ /// signedness semantics; true means signed integer type; false means unsigned
+ /// integer type.
+ Optional<bool> getIntTypeSignedness() const;
+
/// Given a hash_identifier token like #123, try to parse the number out of
/// the identifier, returning None if it is a named identifier like #x or
/// if the integer doesn't fit.
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index 95b05b5424ef..47c43f6522fb 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -52,7 +52,7 @@ TOK_IDENTIFIER(exclamation_identifier) // !foo
TOK_LITERAL(floatliteral) // 2.0
TOK_LITERAL(integer) // 42
TOK_LITERAL(string) // "foo"
-TOK_LITERAL(inttype) // i421
+TOK_LITERAL(inttype) // i4, si8, ui16
// Punctuation.
TOK_PUNCTUATION(arrow, "->")
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 4a100bc57e10..2062b379283f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -869,7 +869,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
- if (elementType.isIntOrFloat()) {
+ if (elementType.isSignlessIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 25756735df06..a8a72fd87f12 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -204,6 +204,14 @@ func @illegaltype(i0) // expected-error {{invalid integer width}}
// -----
+func @illegaltype(ui1) // expected-error {{cannot have signedness semantics for i1}}
+
+// -----
+
+func @illegaltype(si1) // expected-error {{cannot have signedness semantics for i1}}
+
+// -----
+
func @malformed_for_percent() {
affine.for i = 1 to 10 { // expected-error {{expected SSA operand}}
@@ -1206,5 +1214,21 @@ func @bool_literal_in_non_bool_tensor() {
"foo"() {bar = dense<true> : tensor<2xi16>} : () -> ()
}
+// -----
+
// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
func @bad_arrow(%arg : !unreg.ptr<(i32)->)
+
+// -----
+
+func @negative_value_in_unsigned_int_attr() {
+ // expected-error @+1 {{negative integer literal not valid for unsigned integer type}}
+ "foo"() {bar = -5 : ui32} : () -> ()
+}
+
+// -----
+
+func @negative_value_in_unsigned_vector_attr() {
+ // expected-error @+1 {{expected unsigned integer elements, but parsed negative value}}
+ "foo"() {bar = dense<[5, -5]> : vector<2xui32>} : () -> ()
+}
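The new negative tests pin down two type-level rules. A zero width is invalid, and under this change the one-bit type may only be signless, so `si1` and `ui1` are rejected with "cannot have signedness semantics for i1". In the diff these rules live in different places: the zero-width check comes from the parser's bitwidth decoding, and the i1 rule is expected through `IntegerType::getChecked`, whose verifier body is not part of this excerpt.

The hypothetical `checkIntegerType` below folds both rules into one function purely for illustration.

```cpp
#include <cassert>
#include <optional>
#include <string>

enum class Signedness { Signless, Signed, Unsigned };

// Returns a diagnostic message, or std::nullopt if the type is acceptable.
std::optional<std::string> checkIntegerType(unsigned width, Signedness s) {
  if (width == 0)
    return std::string("invalid integer width");
  if (width == 1 && s != Signedness::Signless)
    return std::string("cannot have signedness semantics for i1");
  return std::nullopt;
}
```

`i1` stays valid, and widths of two or more bits accept any signedness, matching the `si2`/`ui2` cases in `parser.mlir`.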
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 71e82b0fe090..bec1fbd4aca6 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -61,6 +61,12 @@ func @missingReturn()
// CHECK: func @int_types(i1, i2, i4, i7, i87) -> (i1, index, i19)
func @int_types(i1, i2, i4, i7, i87) -> (i1, index, i19)
+// CHECK: func @sint_types(si2, si4) -> (si7, si1023)
+func @sint_types(si2, si4) -> (si7, si1023)
+
+// CHECK: func @uint_types(ui2, ui4) -> (ui7, ui1023)
+func @uint_types(ui2, ui4) -> (ui7, ui1023)
+
// CHECK: func @vectors(vector<1xf32>, vector<2x4xf32>)
func @vectors(vector<1 x f32>, vector<2x4xf32>)
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index d1f53a2bcdb5..330b8041afdc 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -112,8 +112,10 @@ struct TestInlinerInterface : public DialectInlinerInterface {
Type resultType,
Location conversionLoc) const final {
// Only allow conversion for i16/i32 types.
- if (!(resultType.isInteger(16) || resultType.isInteger(32)) ||
- !(input.getType().isInteger(16) || input.getType().isInteger(32)))
+ if (!(resultType.isSignlessInteger(16) ||
+ resultType.isSignlessInteger(32)) ||
+ !(input.getType().isSignlessInteger(16) ||
+ input.getType().isSignlessInteger(32)))
return nullptr;
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index 534e86f85a63..f89987610c99 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -235,7 +235,7 @@ struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is I32, change the type to F32.
- if (!Type(*op->result_type_begin()).isInteger(32))
+ if (!Type(*op->result_type_begin()).isSignlessInteger(32))
return matchFailure();
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
return matchSuccess();
@@ -309,11 +309,11 @@ struct TestTypeConverter : public TypeConverter {
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
// Drop I16 types.
- if (t.isInteger(16))
+ if (t.isSignlessInteger(16))
return success();
// Convert I64 to F64.
- if (t.isInteger(64)) {
+ if (t.isSignlessInteger(64)) {
results.push_back(FloatType::getF64(t.getContext()));
return success();
}
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 31ccf59bd12d..44c9a72340f3 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -137,8 +137,8 @@ def BOp : NS_Op<"b_op", []> {
// DEF-LABEL: BOp::verify
// DEF: if (!((true)))
// DEF: if (!((tblgen_bool_attr.isa<BoolAttr>())))
-// DEF: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isInteger(32)))))
-// DEF: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isInteger(64)))))
+// DEF: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isSignlessInteger(32)))))
+// DEF: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isSignlessInteger(64)))))
// DEF: if (!(((tblgen_f32_attr.isa<FloatAttr>())) && ((tblgen_f32_attr.cast<FloatAttr>().getType().isF32()))))
// DEF: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
// DEF: if (!((tblgen_str_attr.isa<StringAttr>())))
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 0de9edce7209..8ce4913e91d3 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -91,4 +91,4 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
// CHECK-LABEL: OpK::verify
// CHECK: for (Value v : getODSOperands(0)) {
-// CHECK: if (!(((v.getType().isa<TensorType>())) && (((v.getType().cast<ShapedType>().getElementType().isF32())) || ((v.getType().cast<ShapedType>().getElementType().isInteger(32))))))
+// CHECK: if (!(((v.getType().isa<TensorType>())) && (((v.getType().cast<ShapedType>().getElementType().isF32())) || ((v.getType().cast<ShapedType>().getElementType().isSignlessInteger(32))))))