[Mlir-commits] [mlir] add IntegerLikeTypeInterface to enable out-of-tree uses of int attribute parsers (PR #137061)
Jeremy Kun
llvmlistbot at llvm.org
Wed Apr 23 14:02:06 PDT 2025
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/137061
>From 5e5816982024ae26eabc2f72e58d6e8b9fb5ce9c Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Wed, 23 Apr 2025 13:45:58 -0700
Subject: [PATCH] add IntegerLikeTypeInterface to enable out-of-tree uses of
int attribute parsers
---
mlir/include/mlir/IR/BuiltinAttributes.h | 4 +-
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 38 +++++++++++++++++++
mlir/include/mlir/IR/BuiltinTypes.td | 6 ++-
mlir/lib/AsmParser/AttributeParser.cpp | 16 ++++----
mlir/lib/IR/AsmPrinter.cpp | 2 +-
mlir/lib/IR/AttributeDetail.h | 5 ++-
mlir/lib/IR/BuiltinAttributes.cpp | 35 +++++++++--------
mlir/lib/IR/BuiltinTypes.cpp | 10 +++++
8 files changed, 85 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c07ade606a775..005316a737dff 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -548,7 +548,7 @@ class DenseElementsAttr : public Attribute {
std::enable_if_t<std::is_same<T, APInt>::value>;
template <typename T, typename = APIntValueTemplateCheckT<T>>
FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const {
- if (!getElementType().isIntOrIndex())
+ if (!llvm::isa<IntegerLikeTypeInterface>(getElementType()))
return failure();
return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
raw_int_end());
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..3d459f006093a 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -257,4 +257,42 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}
+def IntegerLikeTypeInterface : TypeInterface<"IntegerLikeTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This type interface is for types that behave like integers. It provides
+ the API that allows MLIR utilities to treat them the same way as MLIR
+ treats integer types in settings like parsing and printing.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the number of bits used to store values of this type.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getStorageBitWidth",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if this type is signed. Defaults to false (not signed).
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isSigned",
+ /*args=*/(ins),
+ /*defaultImplementation=*/"return false;"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if this type is signless. Defaults to true.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isSignless",
+ /*args=*/(ins),
+ /*defaultImplementation=*/"return true;"
+ >
+ ];
+}
+
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..6eb2ec333351a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -466,7 +466,8 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
//===----------------------------------------------------------------------===//
def Builtin_Index : Builtin_Type<"Index", "index",
- [VectorElementTypeInterface]> {
+ [VectorElementTypeInterface,
+ DeclareTypeInterfaceMethods<IntegerLikeTypeInterface, ["getStorageBitWidth"]>]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -497,7 +498,8 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface]> {
+ [VectorElementTypeInterface,
+ DeclareTypeInterfaceMethods<IntegerLikeTypeInterface, ["getStorageBitWidth"]>]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 2474e88373e04..a3252cf1964ee 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
@@ -366,8 +367,12 @@ static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
return std::nullopt;
// Extend or truncate the bitwidth to the right size.
- unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
- : type.getIntOrFloatBitWidth();
+ unsigned width;
+ if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(type)) {
+ width = intLikeType.getStorageBitWidth();
+ } else {
+ width = type.getIntOrFloatBitWidth();
+ }
if (width > result.getBitWidth()) {
result = result.zext(width);
@@ -425,10 +430,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return FloatAttr::get(floatType, *result);
}
- if (!isa<IntegerType, IndexType>(type))
+ if (!isa<IntegerLikeTypeInterface>(type))
return emitError(loc, "integer literal not valid for specified type"),
nullptr;
if (isNegative && type.isUnsignedInteger()) {
emitError(loc,
"negative integer literal not valid for unsigned integer type");
@@ -584,7 +585,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
}
- // Handle integer and index types.
- if (eltType.isIntOrIndex()) {
+ // Handle integer-like types, including the builtin integer and index types.
+ if (isa<IntegerLikeTypeInterface>(eltType)) {
std::vector<APInt> intValues;
if (failed(getIntAttrElements(loc, eltType, intValues)))
return nullptr;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5b5ec841917e7..d74bf5d975f0b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2656,7 +2656,7 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
os << ")";
});
}
- } else if (elementType.isIntOrIndex()) {
+ } else if (isa<IntegerLikeTypeInterface>(elementType)) {
auto valueIt = attr.value_begin<APInt>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printDenseIntElement(*(valueIt + index), os, elementType);
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 26d40ac3a38f6..96b269e3b1363 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -37,8 +37,9 @@ inline size_t getDenseElementBitWidth(Type eltType) {
// Align the width for complex to 8 to make storage and interpretation easier.
if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
- if (eltType.isIndex())
- return IndexType::kInternalStorageBitWidth;
+ if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(eltType))
+ return intLikeType.getStorageBitWidth();
+
return eltType.getIntOrFloatBitWidth();
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e9af1f77a379e..7bc3d9a59a37e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
@@ -379,22 +380,18 @@ APInt IntegerAttr::getAPSInt() const {
LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APInt value) {
- if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
- if (integerType.getWidth() != value.getBitWidth())
- return emitError() << "integer type bit width (" << integerType.getWidth()
- << ") doesn't match value bit width ("
- << value.getBitWidth() << ")";
- return success();
- }
- if (llvm::isa<IndexType>(type)) {
- if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
- return emitError()
- << "value bit width (" << value.getBitWidth()
- << ") doesn't match index type internal storage bit width ("
- << IndexType::kInternalStorageBitWidth << ")";
- return success();
- }
- return emitError() << "expected integer or index type";
+ // Any type implementing IntegerLikeTypeInterface (including the builtin
+ // integer and index types) may back an IntegerAttr.
+ auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(type);
+ if (!intLikeType)
+ return emitError() << "expected integer-like type";
+
+ unsigned width = intLikeType.getStorageBitWidth();
+ if (width != value.getBitWidth())
+ return emitError() << "integer-like type bit width (" << width
+ << ") doesn't match value bit width ("
+ << value.getBitWidth() << ")";
+ return success();
}
BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
@@ -1019,7 +1018,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
/// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APInt> values) {
- assert(type.getElementType().isIntOrIndex());
+ assert(isa<IntegerLikeTypeInterface>(type.getElementType()));
assert(hasSameNumElementsOrSplat(type, values));
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
@@ -1130,11 +1129,11 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
if (type.isIndex())
return true;
- auto intType = llvm::dyn_cast<IntegerType>(type);
+ auto intType = llvm::dyn_cast<IntegerLikeTypeInterface>(type);
if (!intType) {
LLVM_DEBUG(llvm::dbgs()
- << "expected integer type when isInt is true, but found " << type
- << "\n");
+ << "expected integer-like type when isInt is true, but found "
+ << type << "\n");
return false;
}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..c38e33dba7e5d 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -59,6 +59,14 @@ LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// Index Type
+//===----------------------------------------------------------------------===//
+
+unsigned IndexType::getStorageBitWidth() const {
+ return kInternalStorageBitWidth; // Index width is platform-dependent.
+}
+
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
@@ -86,6 +94,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
+unsigned IntegerType::getStorageBitWidth() const { return getWidth(); }
+
//===----------------------------------------------------------------------===//
// Float Types
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list