[Mlir-commits] [mlir] add IntegerLikeTypeInterface to enable out-of-tree uses of int attribute parsers (PR #137061)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 23 14:01:30 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Jeremy Kun (j2kun)

<details>
<summary>Changes</summary>

This is a first draft and I'm looking for feedback.

This PR adds the IntegerLikeTypeInterface to BuiltinTypeInterfaces, and registers it on Index and Integer types.

The goal of this PR is to let ops that use out-of-tree types reuse the upstream attribute parsing infrastructure when the constant values of those types can be stored as integers or dense int elements. For example, in HEIR we have a `mod_arith` dialect whose type is effectively an integer storage type with an added modulus. This dialect has a `mod_arith.constant` op that we want to accept both scalar and dense attribute values. Because the types involved are `mod_arith` types, using the upstream dense int element (or even plain integer) attribute parsers currently means redoing a lot of upstream work just to swap out the attribute's underlying type.

Looking into it, it seems the parsers only really need to know the bit width of the underlying storage type so that they can construct the relevant `APInt`s from parsed integer literals. Otherwise the attribute type is essentially a pass-through. So I figured that if the attribute's type can advertise its storage bit width, that would suffice to make the parsers more generic.
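
As a rough illustration of that idea (a standalone sketch, not the code in this PR and not MLIR's API), the following models a type that only advertises its storage bit width, plus a generic helper that turns a parsed literal into a fixed-width bit pattern. The names `IntegerLikeType`, `BuiltinIntType`, `ModArithType`, and `buildStorageValue` are hypothetical, the modulus field is only an assumption about what a `mod_arith`-style type might carry, and the sketch is limited to widths of 1 to 64 bits, whereas `APInt` has no such limit.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the proposed interface: generic code only needs
// to ask a type for its storage bit width.
struct IntegerLikeType {
  virtual ~IntegerLikeType() = default;
  virtual unsigned getStorageBitWidth() const = 0;
};

// Models a builtin integer type such as i8 or i32.
struct BuiltinIntType : IntegerLikeType {
  explicit BuiltinIntType(unsigned width) : width(width) {}
  unsigned getStorageBitWidth() const override { return width; }
  unsigned width;
};

// Models an out-of-tree type: integer storage plus a modulus (assumed shape).
struct ModArithType : IntegerLikeType {
  ModArithType(unsigned storageWidth, uint64_t modulus)
      : storageWidth(storageWidth), modulus(modulus) {}
  unsigned getStorageBitWidth() const override { return storageWidth; }
  unsigned storageWidth;
  uint64_t modulus;
};

// Builds the two's-complement bit pattern for a literal given as a sign flag
// and a magnitude. Returns std::nullopt when the literal does not fit in the
// storage width advertised by the type.
std::optional<uint64_t> buildStorageValue(const IntegerLikeType &type,
                                          bool isNegative,
                                          uint64_t magnitude) {
  const unsigned width = type.getStorageBitWidth();
  assert(width >= 1 && width <= 64 && "sketch supports 1..64 bits only");
  const uint64_t mask =
      width == 64 ? ~uint64_t{0} : (uint64_t{1} << width) - 1;

  if (!isNegative) {
    // A non-negative literal must fit in `width` bits.
    if (magnitude > mask)
      return std::nullopt;
    return magnitude;
  }

  // A negative literal may go no lower than -2^(width-1).
  const uint64_t signBit = uint64_t{1} << (width - 1);
  if (magnitude > signBit)
    return std::nullopt;
  return (~magnitude + 1) & mask; // Two's-complement negation.
}
```

The helper never inspects the modulus or any other property of the type, which is the sense in which the type acts as a pass-through once its storage width is known.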

This proof of concept PR works with HEIR's `mod_arith` dialect as shown in https://github.com/google/heir/pull/1758.

Notably, with this change I get the option to use splatted dense attributes out of tree for free (cf. `syntax.mlir` in the linked PR). It even does nice things like range checking for attributes that don't fit within the advertised bit width of the out-of-tree type.
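
For readers unfamiliar with splats, here is a toy model of the concept: a dense attribute whose elements are all equal can keep a single stored value and answer every index from it. The `DenseInts` class and its methods are hypothetical and only illustrate the idea. MLIR's actual `DenseElementsAttr` storage is bit-packed and considerably more involved.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical model of a dense integer attribute over `numElements` slots.
// It stores either one value per slot, or a single value shared by all slots.
class DenseInts {
public:
  DenseInts(std::size_t numElements, std::vector<uint64_t> values)
      : numElements(numElements), values(std::move(values)) {
    // Either a full list of elements or exactly one splat value is allowed.
    assert(this->values.size() == numElements || this->values.size() == 1);
  }

  // True when a single stored value stands in for every element.
  bool isSplat() const { return values.size() == 1; }

  // Logical element access: a splat returns the shared value for any index.
  uint64_t at(std::size_t index) const {
    assert(index < numElements && "index out of range");
    return isSplat() ? values.front() : values[index];
  }

  std::size_t size() const { return numElements; }
  std::size_t storedCount() const { return values.size(); }

private:
  std::size_t numElements;
  std::vector<uint64_t> values;
};
```

With this model, `DenseInts(4, {7})` reports four logical elements while storing only one value.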

---
Full diff: https://github.com/llvm/llvm-project/pull/137061.diff


8 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinAttributes.h (+3-1) 
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+38) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+4-2) 
- (modified) mlir/lib/AsmParser/AttributeParser.cpp (+9-7) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+1-1) 
- (modified) mlir/lib/IR/AttributeDetail.h (+3-2) 
- (modified) mlir/lib/IR/BuiltinAttributes.cpp (+17-18) 
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+10) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 67fab7ebc13ba..5eb2c891047cd 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -548,7 +548,9 @@ class DenseElementsAttr : public Attribute {
       std::enable_if_t<std::is_same<T, APInt>::value>;
   template <typename T, typename = APIntValueTemplateCheckT<T>>
   FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const {
-    if (!getElementType().isIntOrIndex())
+    auto intLikeType =
+        llvm::dyn_cast<IntegerLikeTypeInterface>(getElementType());
+    if (!intLikeType)
       return failure();
     return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
                                                    raw_int_end());
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..3d459f006093a 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -257,4 +257,42 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+def IntegerLikeTypeInterface : TypeInterface<"IntegerLikeTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This type interface is for types that behave like integers. It provides
+    the API that allows MLIR utilities to treat them the same was as MLIR
+    treats integer types in settings like parsing and printing.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the storage bit width for this type.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getStorageBitWidth",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this type is signed.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSigned",
+      /*args=*/(ins),
+      /*defaultImplementation=*/"return true;"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this type is signless.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSignless",
+      /*args=*/(ins),
+      /*defaultImplementation=*/"return true;"
+    >,
+  ];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..6eb2ec333351a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -466,7 +466,8 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_Index : Builtin_Type<"Index", "index",
-    [VectorElementTypeInterface]> {
+    [VectorElementTypeInterface,
+     DeclareTypeInterfaceMethods<IntegerLikeTypeInterface, ["getStorageBitWidth"]>]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -497,7 +498,8 @@ def Builtin_Index : Builtin_Type<"Index", "index",
 //===----------------------------------------------------------------------===//
 
 def Builtin_Integer : Builtin_Type<"Integer", "integer",
-    [VectorElementTypeInterface]> {
+    [VectorElementTypeInterface,
+     DeclareTypeInterfaceMethods<IntegerLikeTypeInterface, ["getStorageBitWidth"]>]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 2474e88373e04..a3252cf1964ee 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
@@ -366,8 +367,12 @@ static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
     return std::nullopt;
 
   // Extend or truncate the bitwidth to the right size.
-  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
-                                  : type.getIntOrFloatBitWidth();
+  unsigned width;
+  if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(type)) {
+    width = intLikeType.getStorageBitWidth();
+  } else {
+    width = type.getIntOrFloatBitWidth();
+  }
 
   if (width > result.getBitWidth()) {
     result = result.zext(width);
@@ -425,10 +430,6 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
     return FloatAttr::get(floatType, *result);
   }
 
-  if (!isa<IntegerType, IndexType>(type))
-    return emitError(loc, "integer literal not valid for specified type"),
-           nullptr;
-
   if (isNegative && type.isUnsignedInteger()) {
     emitError(loc,
               "negative integer literal not valid for unsigned integer type");
@@ -584,7 +585,8 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
   }
 
   // Handle integer and index types.
-  if (eltType.isIntOrIndex()) {
+  auto integerLikeType = dyn_cast<IntegerLikeTypeInterface>(eltType);
+  if (integerLikeType || eltType.isIntOrIndex()) {
     std::vector<APInt> intValues;
     if (failed(getIntAttrElements(loc, eltType, intValues)))
       return nullptr;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5b5ec841917e7..d74bf5d975f0b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2656,7 +2656,7 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
         os << ")";
       });
     }
-  } else if (elementType.isIntOrIndex()) {
+  } else if (isa<IntegerLikeTypeInterface>(elementType)) {
     auto valueIt = attr.value_begin<APInt>();
     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
       printDenseIntElement(*(valueIt + index), os, elementType);
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 26d40ac3a38f6..96b269e3b1363 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -37,8 +37,9 @@ inline size_t getDenseElementBitWidth(Type eltType) {
   // Align the width for complex to 8 to make storage and interpretation easier.
   if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
     return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
-  if (eltType.isIndex())
-    return IndexType::kInternalStorageBitWidth;
+  if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(eltType))
+    return intLikeType.getStorageBitWidth();
+
   return eltType.getIntOrFloatBitWidth();
 }
 
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index daf79dc5de981..b185223085c3d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
 #include "AttributeDetail.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
@@ -379,22 +380,20 @@ APSInt IntegerAttr::getAPSInt() const {
 
 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                   Type type, APInt value) {
-  if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
-    if (integerType.getWidth() != value.getBitWidth())
-      return emitError() << "integer type bit width (" << integerType.getWidth()
-                         << ") doesn't match value bit width ("
-                         << value.getBitWidth() << ")";
-    return success();
+  unsigned width;
+  if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(type)) {
+    width = intLikeType.getStorageBitWidth();
+  } else {
+    return emitError() << "expected integer-like type";
   }
-  if (llvm::isa<IndexType>(type)) {
-    if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
-      return emitError()
-             << "value bit width (" << value.getBitWidth()
-             << ") doesn't match index type internal storage bit width ("
-             << IndexType::kInternalStorageBitWidth << ")";
-    return success();
+
+  if (width != value.getBitWidth()) {
+    return emitError() << "integer-like type bit width (" << width
+                       << ") doesn't match value bit width ("
+                       << value.getBitWidth() << ")";
   }
-  return emitError() << "expected integer or index type";
+
+  return success();
 }
 
 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
@@ -1019,7 +1018,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 /// element type of 'type'.
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APInt> values) {
-  assert(type.getElementType().isIntOrIndex());
+  assert(isa<IntegerLikeTypeInterface>(type.getElementType()));
   assert(hasSameElementsOrSplat(type, values));
   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
@@ -1130,11 +1129,11 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
   if (type.isIndex())
     return true;
 
-  auto intType = llvm::dyn_cast<IntegerType>(type);
+  auto intType = llvm::dyn_cast<IntegerLikeTypeInterface>(type);
   if (!intType) {
     LLVM_DEBUG(llvm::dbgs()
-               << "expected integer type when isInt is true, but found " << type
-               << "\n");
+               << "expected integer-like type when isInt is true, but found "
+               << type << "\n");
     return false;
   }
 
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..c38e33dba7e5d 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -59,6 +59,14 @@ LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Index Type
+//===----------------------------------------------------------------------===//
+
+unsigned IndexType::getStorageBitWidth() const {
+  return kInternalStorageBitWidth;
+}
+
 //===----------------------------------------------------------------------===//
 // Integer Type
 //===----------------------------------------------------------------------===//
@@ -86,6 +94,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
 }
 
+unsigned IntegerType::getStorageBitWidth() const { return getWidth(); }
+
 //===----------------------------------------------------------------------===//
 // Float Types
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/137061

