[Mlir-commits] [mlir] add IntegerLikeTypeInterface to enable out-of-tree uses of int attribute parsers (PR #137061)

Jeremy Kun llvmlistbot at llvm.org
Wed Apr 23 14:02:06 PDT 2025


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/137061

>From 5e5816982024ae26eabc2f72e58d6e8b9fb5ce9c Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Wed, 23 Apr 2025 13:45:58 -0700
Subject: [PATCH] add IntegerLikeTypeInterface to enable out-of-tree uses of
 int attribute parsers

---
 mlir/include/mlir/IR/BuiltinAttributes.h      |  4 +-
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 38 +++++++++++++++++++
 mlir/include/mlir/IR/BuiltinTypes.td          |  6 ++-
 mlir/lib/AsmParser/AttributeParser.cpp        | 16 ++++----
 mlir/lib/IR/AsmPrinter.cpp                    |  2 +-
 mlir/lib/IR/AttributeDetail.h                 |  5 ++-
 mlir/lib/IR/BuiltinAttributes.cpp             | 35 +++++++++--------
 mlir/lib/IR/BuiltinTypes.cpp                  | 10 +++++
 8 files changed, 85 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c07ade606a775..005316a737dff 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -548,7 +548,9 @@ class DenseElementsAttr : public Attribute {
       std::enable_if_t<std::is_same<T, APInt>::value>;
   template <typename T, typename = APIntValueTemplateCheckT<T>>
   FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const {
-    if (!getElementType().isIntOrIndex())
+    // Builtin integer and index types implement IntegerLikeTypeInterface, so
+    // this also accepts out-of-tree integer-like element types.
+    if (!llvm::isa<IntegerLikeTypeInterface>(getElementType()))
       return failure();
     return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
                                                    raw_int_end());
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..3d459f006093a 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -257,4 +257,42 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+def IntegerLikeTypeInterface : TypeInterface<"IntegerLikeTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This type interface is for types that behave like integers. It provides
+    the API that allows MLIR utilities to treat them the same way as MLIR
+    treats integer types in settings like parsing and printing.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the storage bit width for this type.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getStorageBitWidth",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this type is signed. Defaults to false.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSigned",
+      /*args=*/(ins),
+      /*defaultImplementation=*/"return false;"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this type is signless. Defaults to true.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSignless",
+      /*args=*/(ins),
+      /*defaultImplementation=*/"return true;"
+    >,
+  ];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..6eb2ec333351a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -466,7 +466,9 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_Index : Builtin_Type<"Index", "index",
-    [VectorElementTypeInterface]> {
+    [VectorElementTypeInterface,
+     DeclareTypeInterfaceMethods<IntegerLikeTypeInterface,
+                                 ["getStorageBitWidth"]>]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -497,7 +499,9 @@ def Builtin_Index : Builtin_Type<"Index", "index",
 //===----------------------------------------------------------------------===//
 
 def Builtin_Integer : Builtin_Type<"Integer", "integer",
-    [VectorElementTypeInterface]> {
+    [VectorElementTypeInterface,
+     DeclareTypeInterfaceMethods<IntegerLikeTypeInterface,
+                                 ["getStorageBitWidth"]>]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 2474e88373e04..a3252cf1964ee 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
@@ -366,8 +367,9 @@ static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
     return std::nullopt;
 
   // Extend or truncate the bitwidth to the right size.
-  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
-                                  : type.getIntOrFloatBitWidth();
+  auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(type);
+  unsigned width = intLikeType ? intLikeType.getStorageBitWidth()
+                               : type.getIntOrFloatBitWidth();
 
   if (width > result.getBitWidth()) {
     result = result.zext(width);
@@ -425,10 +427,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
     return FloatAttr::get(floatType, *result);
   }
 
-  if (!isa<IntegerType, IndexType>(type))
+  if (!isa<IntegerLikeTypeInterface>(type))
     return emitError(loc, "integer literal not valid for specified type"),
            nullptr;
 
   if (isNegative && type.isUnsignedInteger()) {
     emitError(loc,
               "negative integer literal not valid for unsigned integer type");
@@ -584,7 +586,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
   }
 
-  // Handle integer and index types.
-  if (eltType.isIntOrIndex()) {
+  // Handle integer-like types, including the builtin integer and index types.
+  if (isa<IntegerLikeTypeInterface>(eltType)) {
     std::vector<APInt> intValues;
     if (failed(getIntAttrElements(loc, eltType, intValues)))
       return nullptr;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5b5ec841917e7..d74bf5d975f0b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2656,7 +2656,9 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
         os << ")";
       });
     }
-  } else if (elementType.isIntOrIndex()) {
+  } else if (isa<IntegerLikeTypeInterface>(elementType)) {
+    // Handles builtin integer and index types as well as out-of-tree types
+    // implementing IntegerLikeTypeInterface.
     auto valueIt = attr.value_begin<APInt>();
     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
       printDenseIntElement(*(valueIt + index), os, elementType);
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 26d40ac3a38f6..96b269e3b1363 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -37,8 +37,9 @@ inline size_t getDenseElementBitWidth(Type eltType) {
   // Align the width for complex to 8 to make storage and interpretation easier.
   if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
     return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
-  if (eltType.isIndex())
-    return IndexType::kInternalStorageBitWidth;
+  // Integer-like types (including builtin index) define their storage width.
+  if (auto intLikeType = llvm::dyn_cast<IntegerLikeTypeInterface>(eltType))
+    return intLikeType.getStorageBitWidth();
   return eltType.getIntOrFloatBitWidth();
 }
 
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e9af1f77a379e..7bc3d9a59a37e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
 #include "AttributeDetail.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
@@ -379,22 +380,16 @@ APSInt IntegerAttr::getAPSInt() const {
 
 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                   Type type, APInt value) {
-  if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
-    if (integerType.getWidth() != value.getBitWidth())
-      return emitError() << "integer type bit width (" << integerType.getWidth()
-                         << ") doesn't match value bit width ("
-                         << value.getBitWidth() << ")";
-    return success();
-  }
-  if (llvm::isa<IndexType>(type)) {
-    if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
-      return emitError()
-             << "value bit width (" << value.getBitWidth()
-             << ") doesn't match index type internal storage bit width ("
-             << IndexType::kInternalStorageBitWidth << ")";
-    return success();
-  }
-  return emitError() << "expected integer or index type";
+  auto intLikeType = llvm::dyn_cast<IntegerLikeTypeInterface>(type);
+  if (!intLikeType)
+    return emitError() << "expected integer-like type";
+
+  unsigned width = intLikeType.getStorageBitWidth();
+  if (width != value.getBitWidth())
+    return emitError() << "integer-like type bit width (" << width
+                       << ") doesn't match value bit width ("
+                       << value.getBitWidth() << ")";
+  return success();
 }
 
 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
@@ -1019,7 +1014,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 /// element type of 'type'.
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APInt> values) {
-  assert(type.getElementType().isIntOrIndex());
+  assert(isa<IntegerLikeTypeInterface>(type.getElementType()));
   assert(hasSameNumElementsOrSplat(type, values));
   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
@@ -1130,11 +1125,11 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
   if (type.isIndex())
     return true;
 
-  auto intType = llvm::dyn_cast<IntegerType>(type);
+  auto intType = llvm::dyn_cast<IntegerLikeTypeInterface>(type);
   if (!intType) {
     LLVM_DEBUG(llvm::dbgs()
-               << "expected integer type when isInt is true, but found " << type
-               << "\n");
+               << "expected integer-like type when isInt is true, but found "
+               << type << "\n");
     return false;
   }
 
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..c38e33dba7e5d 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -59,6 +59,15 @@ LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Index Type
+//===----------------------------------------------------------------------===//
+
+/// Index values are stored using IndexType::kInternalStorageBitWidth bits.
+unsigned IndexType::getStorageBitWidth() const {
+  return kInternalStorageBitWidth;
+}
+
 //===----------------------------------------------------------------------===//
 // Integer Type
 //===----------------------------------------------------------------------===//
@@ -86,6 +95,9 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
 }
 
+/// An integer type's storage width is exactly its declared bit width.
+unsigned IntegerType::getStorageBitWidth() const { return getWidth(); }
+
 //===----------------------------------------------------------------------===//
 // Float Types
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list