[Mlir-commits] [mlir] 1b2c787 - Add support for IndexType inside DenseIntElementsAttr.
Sean Silva
llvmlistbot at llvm.org
Thu Apr 23 17:42:51 PDT 2020
Author: Sean Silva
Date: 2020-04-23T17:42:33-07:00
New Revision: 1b2c7877a4dd3da2ba7462d65a48341f59c6cc61
URL: https://github.com/llvm/llvm-project/commit/1b2c7877a4dd3da2ba7462d65a48341f59c6cc61
DIFF: https://github.com/llvm/llvm-project/commit/1b2c7877a4dd3da2ba7462d65a48341f59c6cc61.diff
LOG: Add support for IndexType inside DenseIntElementsAttr.
This also fixes issues discovered in the parsing/printing path.
Added:
Modified:
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/Types.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/parser.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 4c5bbba0aa6a..3241a70f2675 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -76,6 +76,9 @@ class IndexType : public Type::TypeBase<IndexType, Type> {
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
+
+ /// Storage bit width used for IndexType by internal compiler data structures.
+ static constexpr unsigned kInternalStorageBitWidth = 64;
};
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index e45fa9037470..32ae13f86dc9 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,6 +169,8 @@ class Type {
/// Return true of this is a signless integer or a float type.
bool isSignlessIntOrFloat();
+ /// Return true if this is an integer (of any signedness) or an index type.
+ bool isIntOrIndex();
/// Return true if this is an integer (of any signedness) or a float type.
bool isIntOrFloat();
/// Return true if this is an integer (of any signedness), index, or float
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75ea0612bb3a..ac92b707fac5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1462,7 +1462,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
bool isSigned = !type.getElementType().isUnsignedInteger();
// The function used to print elements of this attribute.
- auto printEltFn = type.getElementType().isa<IntegerType>()
+ auto printEltFn = type.getElementType().isIntOrIndex()
? printDenseIntElement
: printDenseFloatElement;
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 8908416b3efa..0119830e1ab9 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -372,6 +372,17 @@ struct TypeAttributeStorage : public AttributeStorage {
// Elements Attributes
//===----------------------------------------------------------------------===//
+/// Return the bit width which DenseElementsAttr should use for this type.
+inline size_t getDenseElementBitWidth(Type eltType) {
+ // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+ // with double semantics.
+ if (eltType.isBF16())
+ return 64;
+ if (eltType.isIndex())
+ return IndexType::kInternalStorageBitWidth;
+ return eltType.getIntOrFloatBitWidth();
+}
+
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public AttributeStorage {
struct KeyTy {
@@ -405,7 +416,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
// same. Boolean values are packed at the bit level, and even though a splat
// is detected the rest of the bits in the first byte may differ from the
// splat value.
- if (key.type.getElementTypeBitWidth() == 1) {
+ if (key.type.getElementType().isInteger(1)) {
if (key.isSplat != isSplat)
return false;
if (isSplat)
@@ -437,15 +448,10 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
assert(numElements != 1 && "splat of 1 element should already be detected");
// Handle boolean values directly as they are packed to 1-bit.
- size_t elementWidth = ty.getElementTypeBitWidth();
- if (elementWidth == 1)
+ if (ty.getElementType().isInteger(1) == 1)
return getKeyForBoolData(ty, data, numElements);
- // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
- // with double semantics.
- if (ty.getElementType().isBF16())
- elementWidth = 64;
-
+ size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
// Non 1-bit dense elements are padded to 8-bits.
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
assert(((data.size() / storageSize) == numElements) &&
@@ -517,7 +523,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
std::memcpy(rawData, data.data(), data.size());
// If this is a boolean splat, make sure only the first bit is used.
- if (key.isSplat && key.type.getElementTypeBitWidth() == 1)
+ if (key.isSplat && key.type.getElementType().isInteger(1))
rawData[0] &= 1;
copy = ArrayRef<char>(rawData, data.size());
}
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 657df5752ade..be31d95657ff 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -275,7 +275,7 @@ IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
// This uses 64 bit APInts by default for index type.
if (type.isIndex())
- return get(type, APInt(64, value));
+ return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
auto intType = type.cast<IntegerType>();
return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
@@ -483,12 +483,6 @@ uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
// DenseElementAttr Utilities
//===----------------------------------------------------------------------===//
-static size_t getDenseElementBitwidth(Type eltType) {
- // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
- // with double semantics.
- return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
-}
-
/// Get the bitwidth of a dense element type within the buffer.
/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
@@ -592,7 +586,7 @@ DenseElementsAttr::IntElementIterator::IntElementIterator(
DenseElementsAttr attr, size_t dataIndex)
: DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
attr.getRawData().data(), attr.isSplat(), dataIndex),
- bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
+ bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
/// Accesses the raw APInt value at this iterator position.
APInt DenseElementsAttr::IntElementIterator::operator*() const {
@@ -613,12 +607,12 @@ DenseElementsAttr::FloatElementIterator::FloatElementIterator(
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
- assert(type.getElementType().isIntOrFloat() &&
- "expected int or float element type");
+ assert(type.getElementType().isIntOrIndexOrFloat() &&
+ "expected int or index or float element type");
assert(hasSameElementsOrSplat(type, values));
auto eltType = type.getElementType();
- size_t bitWidth = getDenseElementBitwidth(eltType);
+ size_t bitWidth = getDenseElementBitWidth(eltType);
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
// Compress the attribute values into a character buffer.
@@ -637,6 +631,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
break;
case StandardTypes::Integer:
+ case StandardTypes::Index:
intVal = values[i].isa<BoolAttr>()
? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
: values[i].cast<IntegerAttr>().getValue();
@@ -667,7 +662,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
/// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APInt> values) {
- assert(type.getElementType().isa<IntegerType>());
+ assert(type.getElementType().isIntOrIndex());
return getRaw(type, values);
}
@@ -701,7 +696,7 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
ArrayRef<APInt> values) {
assert(hasSameElementsOrSplat(type, values));
- size_t bitWidth = getDenseElementBitwidth(type.getElementType());
+ size_t bitWidth = getDenseElementBitWidth(type.getElementType());
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
values.size());
@@ -727,14 +722,17 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
bool isSigned) {
// Make sure that the data element size is the same as the type element width.
- if (getDenseElementBitwidth(type.getElementType()) !=
+ if (getDenseElementBitWidth(type.getElementType()) !=
static_cast<size_t>(dataEltSize * CHAR_BIT))
return false;
- // Check that the element type is either float or integer.
+ // Check that the element type is either float or integer or index.
if (!isInt)
return type.getElementType().isa<FloatType>();
+ if (type.getElementType().isIndex())
+ return true;
+
auto intType = type.getElementType().dyn_cast<IntegerType>();
if (!intType)
return false;
@@ -798,18 +796,15 @@ auto DenseElementsAttr::getBoolValues() const
/// this attribute must be of integer type.
auto DenseElementsAttr::getIntValues() const
-> llvm::iterator_range<IntElementIterator> {
- assert(getType().getElementType().isa<IntegerType>() &&
- "expected integer type");
+ assert(getType().getElementType().isIntOrIndex() && "expected integral type");
return {raw_int_begin(), raw_int_end()};
}
auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
- assert(getType().getElementType().isa<IntegerType>() &&
- "expected integer type");
+ assert(getType().getElementType().isIntOrIndex() && "expected integral type");
return raw_int_begin();
}
auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
- assert(getType().getElementType().isa<IntegerType>() &&
- "expected integer type");
+ assert(getType().getElementType().isIntOrIndex() && "expected integral type");
return raw_int_end();
}
@@ -870,7 +865,7 @@ template <typename Fn, typename Attr>
static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
Type newElementType,
llvm::SmallVectorImpl<char> &data) {
- size_t bitWidth = getDenseElementBitwidth(newElementType);
+ size_t bitWidth = getDenseElementBitWidth(newElementType);
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
ShapedType newArrayType;
@@ -937,7 +932,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
/// Method for supporting type inquiry through isa, cast and dyn_cast.
bool DenseIntElementsAttr::classof(Attribute attr) {
return attr.isa<DenseElementsAttr>() &&
- attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
+ attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 94156c358eb0..c61e1d9e2ab9 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -83,6 +83,8 @@ bool Type::isSignlessIntOrFloat() {
return isSignlessInteger() || isa<FloatType>();
}
+bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
+
bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index e8825ff0fb27..8eabe3c46f4c 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1797,7 +1797,8 @@ static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
return llvm::None;
// Extend or truncate the bitwidth to the right size.
- unsigned width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+ unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
+ : type.getIntOrFloatBitWidth();
if (width > result.getBitWidth()) {
result = result.zext(width);
} else if (width < result.getBitWidth()) {
@@ -1968,8 +1969,7 @@ class TensorLiteralParser {
}
/// Build a Dense Integer attribute for the given type.
- DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type,
- IntegerType eltTy);
+ DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
/// Build a Dense Float attribute for the given type.
DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
@@ -2044,14 +2044,17 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
// If the type is an integer, build a set of APInt values from the storage
// with the correct bitwidth.
- if (auto intTy = type.getElementType().dyn_cast<IntegerType>())
+ Type eltType = type.getElementType();
+ if (auto intTy = eltType.dyn_cast<IntegerType>())
return getIntAttr(loc, type, intTy);
+ if (auto indexTy = eltType.dyn_cast<IndexType>())
+ return getIntAttr(loc, type, indexTy);
// Otherwise, this must be a floating point type.
- auto floatTy = type.getElementType().dyn_cast<FloatType>();
+ auto floatTy = eltType.dyn_cast<FloatType>();
if (!floatTy) {
p.emitError(loc) << "expected floating-point or integer element type, got "
- << type.getElementType();
+ << eltType;
return nullptr;
}
return getFloatAttr(loc, type, floatTy);
@@ -2059,8 +2062,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
/// Build a Dense Integer attribute for the given type.
DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
- ShapedType type,
- IntegerType eltTy) {
+ ShapedType type, Type eltTy) {
std::vector<APInt> intElements;
intElements.reserve(storage.size());
auto isUintType = type.getElementType().isUnsignedInteger();
@@ -2085,11 +2087,12 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
"unexpected token type");
if (token.isAny(Token::kw_true, Token::kw_false)) {
- if (!eltTy.isInteger(1))
+ if (!eltTy.isInteger(1)) {
p.emitError(tokenLoc)
<< "expected i1 type for 'true' or 'false' values";
- APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
- /*isSigned=*/false);
+ return nullptr;
+ }
+ APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
intElements.push_back(apInt);
continue;
}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 585f7de40aaf..7a02d776c172 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -697,6 +697,11 @@ func @densetensorattr() -> () {
"intscalar"(){bar = dense<1> : tensor<i32>} : () -> ()
// CHECK: "floatscalar"() {bar = dense<5.000000e+00> : tensor<f32>} : () -> ()
"floatscalar"(){bar = dense<5.0> : tensor<f32>} : () -> ()
+
+// CHECK: "index"() {bar = dense<1> : tensor<index>} : () -> ()
+ "index"(){bar = dense<1> : tensor<index>} : () -> ()
+// CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
+ "index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
return
}
More information about the Mlir-commits
mailing list