[Mlir-commits] [mlir] [MLIR] Add index bitwidth to the DataLayout. (PR #85927)

Tobias Gysi llvmlistbot at llvm.org
Wed Mar 20 06:05:28 PDT 2024


https://github.com/gysit created https://github.com/llvm/llvm-project/pull/85927

When importing from LLVM IR, the data layout of all pointer types contains an index bitwidth that should be used for index computations. This revision adds a getter to the DataLayout that provides access to the already stored bitwidth. The function returns an optional since only pointer-like types have an index bitwidth. Querying the index bitwidth of a non-pointer-like type returns std::nullopt.

The new function works for the built-in `index` type and, using a type interface, for the LLVMPointerType.

>From 47ba0caf7d07e94accb9b39ff5762147e26945fa Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 20 Mar 2024 12:14:10 +0000
Subject: [PATCH] [MLIR] Add index bitwidth to the DataLayout.

When importing from LLVM IR, the data layout of all pointer types
contains an index bitwidth that should be used for index computations.
This revision adds a getter to the DataLayout that provides access to
the already stored bitwidth. The function returns an optional since only
pointer-like types have an index bitwidth. Querying the index bitwidth
of a non-pointer-like type returns std::nullopt.

The new function works for the built-in `index` type and, using a type
interface, for the LLVMPointerType.
---
 mlir/docs/DataLayout.md                       |  4 +-
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td |  2 +-
 .../mlir/Interfaces/DataLayoutInterfaces.h    | 13 +++++
 .../mlir/Interfaces/DataLayoutInterfaces.td   | 28 ++++++++++
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      | 25 +++++++--
 mlir/lib/Interfaces/DataLayoutInterfaces.cpp  | 30 ++++++++++-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  4 +-
 mlir/test/Dialect/LLVMIR/layout.mlir          | 30 +++++++----
 .../DataLayoutInterfaces/module.mlir          |  4 +-
 .../DataLayoutInterfaces/query.mlir           | 51 ++++++++++++++++---
 .../DataLayoutInterfaces/types.mlir           |  1 +
 .../lib/Dialect/DLTI/TestDataLayoutQuery.cpp  | 12 +++--
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  3 +-
 mlir/test/lib/Dialect/Test/TestTypes.cpp      |  8 ++-
 .../Interfaces/DataLayoutInterfacesTest.cpp   |  7 +++
 15 files changed, 189 insertions(+), 33 deletions(-)

diff --git a/mlir/docs/DataLayout.md b/mlir/docs/DataLayout.md
index b9dde30519d6ed..86ad51a517ae7d 100644
--- a/mlir/docs/DataLayout.md
+++ b/mlir/docs/DataLayout.md
@@ -77,6 +77,7 @@ public:
   llvm::TypeSize getTypeSizeInBits(Type type) const;
   uint64_t getTypeABIAlignment(Type type) const;
   uint64_t getTypePreferredAlignment(Type type) const;
+  std::optional<uint64_t> getTypeIndexBitwidth(Type type) const;
 };
 ```
 
@@ -267,7 +268,8 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
 >} {}
 ```
 
-specifies that `index` has 32 bits. All other layout properties of `index` match
+specifies that `index` has 32 bits and index computations should be performed
+using 32-bit precision as well. All other layout properties of `index` match
 those of the integer type with the same bitwidth defined above.
 
 In absence of the corresponding entry, `index` is assumed to be a 64-bit
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 96cdbf01b4bd91..b7176aa93ff1f7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -123,7 +123,7 @@ def LLVMFunctionType : LLVMType<"LLVMFunction", "func"> {
 
 def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
-      "areCompatible", "verifyEntries"]>]> {
+      "getIndexBitwidth", "areCompatible", "verifyEntries"]>]> {
   let summary = "LLVM pointer type";
   let description = [{
     The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 4a21f76dfc5d1c..046354677e6a00 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -57,6 +57,13 @@ uint64_t
 getDefaultPreferredAlignment(Type type, const DataLayout &dataLayout,
                              ArrayRef<DataLayoutEntryInterface> params);
 
+/// Default handler for the index bitwidth request. Computes the result for
+/// the built-in index type and dispatches to the DataLayoutTypeInterface for
+/// other types.
+std::optional<uint64_t>
+getDefaultIndexBitwidth(Type type, const DataLayout &dataLayout,
+                        ArrayRef<DataLayoutEntryInterface> params);
+
 /// Default handler for alloca memory space request. Dispatches to the
 /// DataLayoutInterface if specified, otherwise returns the default.
 Attribute getDefaultAllocaMemorySpace(DataLayoutEntryInterface entry);
@@ -180,6 +187,11 @@ class DataLayout {
   /// Returns the preferred of the given type in the current scope.
   uint64_t getTypePreferredAlignment(Type t) const;
 
+  /// Returns the bitwidth that should be used when performing index
+  /// computations for the given pointer-like type in the current scope. If the
+  /// type is not a pointer-like type, it returns std::nullopt.
+  std::optional<uint64_t> getTypeIndexBitwidth(Type t) const;
+
   /// Returns the memory space used for AllocaOps.
   Attribute getAllocaMemorySpace() const;
 
@@ -216,6 +228,7 @@ class DataLayout {
   mutable DenseMap<Type, llvm::TypeSize> bitsizes;
   mutable DenseMap<Type, uint64_t> abiAlignments;
   mutable DenseMap<Type, uint64_t> preferredAlignments;
+  mutable DenseMap<Type, std::optional<uint64_t>> indexBitwidths;
 
   /// Cache for alloca, global, and program memory spaces.
   mutable std::optional<Attribute> allocaMemorySpace;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index a8def967fffcfa..0ee7a116d11421 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -280,6 +280,22 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
                                                             params);
       }]
     >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the bitwidth that should be used when "
+                      "performing index computations for the type computed "
+                      "using the relevant entries. The data layout object can "
+                      "be used for recursive queries.",
+      /*retTy=*/"std::optional<uint64_t>",
+      /*methodName=*/"getIndexBitwidth",
+      /*args=*/(ins "::mlir::Type":$type,
+                    "const ::mlir::DataLayout &":$dataLayout,
+                    "::mlir::DataLayoutEntryListRef":$params),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getDefaultIndexBitwidth(type, dataLayout,
+                                                       params);
+      }]
+    >,
     StaticInterfaceMethod<
       /*description=*/"Returns the memory space used by the ABI computed "
                       "using the relevant entries. The data layout object "
@@ -400,6 +416,18 @@ def DataLayoutTypeInterface : TypeInterface<"DataLayoutTypeInterface"> {
       /*args=*/(ins "const ::mlir::DataLayout &":$dataLayout,
                     "::mlir::DataLayoutEntryListRef":$params)
     >,
+    InterfaceMethod<
+      /*description=*/"Returns the bitwidth that should be used when "
+                      "performing index computations for the given "
+                      "pointer-like type. If the type is not a pointer-like "
+                      "type, returns std::nullopt.",
+      /*retTy=*/"std::optional<uint64_t>",
+      /*methodName=*/"getIndexBitwidth",
+      /*args=*/(ins "const ::mlir::DataLayout &":$dataLayout,
+                    "::mlir::DataLayoutEntryListRef":$params),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return std::nullopt; }]
+    >,
     InterfaceMethod<
       /*desc=*/"Returns true if the two lists of entries are compatible, that "
                "is, that `newLayout` spec entries can be nested in an op with "
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 443e245887ea8e..630187f220a4ba 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -287,15 +287,22 @@ getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
     }
   }
   if (currentEntry) {
-    return *extractPointerSpecValue(currentEntry, pos) /
-           (pos == PtrDLEntryPos::Size ? 1 : kBitsInByte);
+    std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
+    // If the optional `PtrDLEntryPos::Index` entry is not available, use the
+    // pointer size as the index bitwidth.
+    if (!value && pos == PtrDLEntryPos::Index)
+      value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
+    bool isSizeOrIndex =
+        pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
+    return *value / (isSizeOrIndex ? 1 : kBitsInByte);
   }
 
   // If not found, and this is the pointer to the default memory space, assume
   // 64-bit pointers.
   if (type.getAddressSpace() == 0) {
-    return pos == PtrDLEntryPos::Size ? kDefaultPointerSizeBits
-                                      : kDefaultPointerAlignment;
+    bool isSizeOrIndex =
+        pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
+    return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
   }
 
   return std::nullopt;
@@ -332,6 +339,16 @@ LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
   return dataLayout.getTypePreferredAlignment(get(getContext()));
 }
 
+std::optional<uint64_t>
+LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
+                                  DataLayoutEntryListRef params) const {
+  if (std::optional<uint64_t> indexBitwidth =
+          getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index))
+    return *indexBitwidth;
+
+  return dataLayout.getTypeIndexBitwidth(get(getContext()));
+}
+
 bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
                                     DataLayoutEntryListRef newLayout) const {
   for (DataLayoutEntryInterface newEntry : newLayout) {
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 65c41f44192a90..b5b7d78cfeff76 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -218,7 +218,23 @@ uint64_t mlir::detail::getDefaultPreferredAlignment(
   reportMissingDataLayout(type);
 }
 
-// Returns the memory space used for allocal operations if specified in the
+std::optional<uint64_t> mlir::detail::getDefaultIndexBitwidth(
+    Type type, const DataLayout &dataLayout,
+    ArrayRef<DataLayoutEntryInterface> params) {
+  if (isa<IndexType>(type))
+    return getIndexBitwidth(params);
+
+  if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
+    if (std::optional<uint64_t> indexBitwidth =
+            typeInterface.getIndexBitwidth(dataLayout, params))
+      return *indexBitwidth;
+
+  // Return std::nullopt for all other types, which are assumed to be non
+  // pointer-like types.
+  return std::nullopt;
+}
+
+// Returns the memory space used for alloca operations if specified in the
 // given entry. If the entry is empty the default memory space represented by
 // an empty attribute is returned.
 Attribute
@@ -520,6 +536,18 @@ uint64_t mlir::DataLayout::getTypePreferredAlignment(Type t) const {
   });
 }
 
+std::optional<uint64_t> mlir::DataLayout::getTypeIndexBitwidth(Type t) const {
+  checkValid();
+  return cachedLookup<std::optional<uint64_t>>(t, indexBitwidths, [&](Type ty) {
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getIndexBitwidth(ty, *this, list);
+    return detail::getDefaultIndexBitwidth(ty, *this, list);
+  });
+}
+
 mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const {
   checkValid();
   if (allocaMemorySpace)
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 995544238e4a3c..81242efc04b7e8 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -281,8 +281,8 @@ translateDataLayout(DataLayoutSpecInterface attribute,
               uint64_t preferred =
                   dataLayout.getTypePreferredAlignment(type) * 8u;
               layoutStream << size << ":" << abi << ":" << preferred;
-              if (std::optional<uint64_t> index = extractPointerSpecValue(
-                      entry.getValue(), PtrDLEntryPos::Index))
+              if (std::optional<uint64_t> index =
+                      dataLayout.getTypeIndexBitwidth(type))
                 layoutStream << ":" << *index;
               return success();
             })
diff --git a/mlir/test/Dialect/LLVMIR/layout.mlir b/mlir/test/Dialect/LLVMIR/layout.mlir
index 2868e1740f861c..a78fb771242e00 100644
--- a/mlir/test/Dialect/LLVMIR/layout.mlir
+++ b/mlir/test/Dialect/LLVMIR/layout.mlir
@@ -7,6 +7,7 @@ module {
     // CHECK: alloca_memory_space = 0
     // CHECK: bitsize = 64
     // CHECK: global_memory_space = 0
+    // CHECK: index = 64
     // CHECK: preferred = 8
     // CHECK: program_memory_space = 0
     // CHECK: size = 8
@@ -16,6 +17,7 @@ module {
     // CHECK: alloca_memory_space = 0
     // CHECK: bitsize = 64
     // CHECK: global_memory_space = 0
+    // CHECK: index = 64
     // CHECK: preferred = 8
     // CHECK: program_memory_space = 0
     // CHECK: size = 8
@@ -25,6 +27,7 @@ module {
     // CHECK: alloca_memory_space = 0
     // CHECK: bitsize = 64
     // CHECK: global_memory_space = 0
+    // CHECK: index = 64
     // CHECK: preferred = 8
     // CHECK: program_memory_space = 0
     // CHECK: size = 8
@@ -39,7 +42,7 @@ module {
 module attributes { dlti.dl_spec = #dlti.dl_spec<
   #dlti.dl_entry<!llvm.ptr, dense<[32, 32, 64]> : vector<3xi64>>,
   #dlti.dl_entry<!llvm.ptr<5>, dense<[64, 64, 64]> : vector<3xi64>>,
-  #dlti.dl_entry<!llvm.ptr<4>, dense<[32, 64, 64]> : vector<3xi64>>,
+  #dlti.dl_entry<!llvm.ptr<4>, dense<[32, 64, 64, 24]> : vector<4xi64>>,
   #dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui64>,
   #dlti.dl_entry<"dlti.global_memory_space", 2 : ui64>,
   #dlti.dl_entry<"dlti.program_memory_space", 3 : ui64>,
@@ -51,6 +54,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
     // CHECK: alloca_memory_space = 5
     // CHECK: bitsize = 32
     // CHECK: global_memory_space = 2
+    // CHECK: index = 32
     // CHECK: preferred = 8
     // CHECK: program_memory_space = 3
     // CHECK: size = 4
@@ -60,6 +64,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
     // CHECK: alloca_memory_space = 5
     // CHECK: bitsize = 32
     // CHECK: global_memory_space = 2
+    // CHECK: index = 32
     // CHECK: preferred = 8
     // CHECK: program_memory_space = 3
     // CHECK: size = 4
@@ -69,24 +74,17 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
     // CHECK: alloca_memory_space = 5
     // CHECK: bitsize = 64
     // CHECK: global_memory_space = 2
+    // CHECK: index = 64
     // CHECK: preferred = 8
     // CHECK: program_memory_space = 3
     // CHECK: size = 8
     // CHECK: stack_alignment = 128
     "test.data_layout_query"() : () -> !llvm.ptr<5>
-    // CHECK: alignment = 4
-    // CHECK: alloca_memory_space = 5
-    // CHECK: bitsize = 32
-    // CHECK: global_memory_space = 2
-    // CHECK: preferred = 8
-    // CHECK: program_memory_space = 3
-    // CHECK: size = 4
-    // CHECK: stack_alignment = 128
-    "test.data_layout_query"() : () -> !llvm.ptr<3>
     // CHECK: alignment = 8
     // CHECK: alloca_memory_space = 5
     // CHECK: bitsize = 32
     // CHECK: global_memory_space = 2
+    // CHECK: index = 24
     // CHECK: preferred = 8
     // CHECK: program_memory_space = 3
     // CHECK: size = 4
@@ -134,6 +132,7 @@ module {
         // simple case
         // CHECK: alignment = 4
         // CHECK: bitsize = 32
+        // CHECK: index = 0
         // CHECK: preferred = 4
         // CHECK: size = 4
         "test.data_layout_query"() : () -> !llvm.struct<(i32)>
@@ -141,6 +140,7 @@ module {
         // padding inbetween
         // CHECK: alignment = 8
         // CHECK: bitsize = 128
+        // CHECK: index = 0
         // CHECK: preferred = 8
         // CHECK: size = 16
         "test.data_layout_query"() : () -> !llvm.struct<(i32, f64)>
@@ -148,6 +148,7 @@ module {
         // padding at end of struct
         // CHECK: alignment = 8
         // CHECK: bitsize = 128
+        // CHECK: index = 0
         // CHECK: preferred = 8
         // CHECK: size = 16
         "test.data_layout_query"() : () -> !llvm.struct<(f64, i32)>
@@ -155,6 +156,7 @@ module {
          // packed
          // CHECK: alignment = 1
          // CHECK: bitsize = 96
+         // CHECK: index = 0
          // CHECK: preferred = 8
          // CHECK: size = 12
          "test.data_layout_query"() : () -> !llvm.struct<packed (f64, i32)>
@@ -162,6 +164,7 @@ module {
          // empty
          // CHECK: alignment = 1
          // CHECK: bitsize = 0
+         // CHECK: index = 0
          // CHECK: preferred = 1
          // CHECK: size = 0
          "test.data_layout_query"() : () -> !llvm.struct<()>
@@ -179,6 +182,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
         // Strict alignment is applied
         // CHECK: alignment = 4
         // CHECK: bitsize = 16
+        // CHECK: index = 0
         // CHECK: preferred = 4
         // CHECK: size = 2
         "test.data_layout_query"() : () -> !llvm.struct<(i16)>
@@ -186,6 +190,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
         // No impact on structs that have stricter requirements
         // CHECK: alignment = 8
         // CHECK: bitsize = 128
+        // CHECK: index = 0
         // CHECK: preferred = 8
         // CHECK: size = 16
         "test.data_layout_query"() : () -> !llvm.struct<(i32, f64)>
@@ -193,6 +198,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
          // Only the preferred alignment of structs is affected
          // CHECK: alignment = 1
          // CHECK: bitsize = 32
+         // CHECK: index = 0
          // CHECK: preferred = 4
          // CHECK: size = 4
          "test.data_layout_query"() : () -> !llvm.struct<packed (i16, i16)>
@@ -200,6 +206,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
          // empty
          // CHECK: alignment = 4
          // CHECK: bitsize = 0
+         // CHECK: index = 0
          // CHECK: preferred = 4
          // CHECK: size = 0
          "test.data_layout_query"() : () -> !llvm.struct<()>
@@ -265,6 +272,7 @@ module {
         // simple case
         // CHECK: alignment = 4
         // CHECK: bitsize = 64
+        // CHECK: index = 0
         // CHECK: preferred = 4
         // CHECK: size = 8
         "test.data_layout_query"() : () -> !llvm.array<2 x i32>
@@ -272,6 +280,7 @@ module {
         // size 0
         // CHECK: alignment = 8
         // CHECK: bitsize = 0
+        // CHECK: index = 0
         // CHECK: preferred = 8
         // CHECK: size = 0
         "test.data_layout_query"() : () -> !llvm.array<0 x f64>
@@ -279,6 +288,7 @@ module {
         // alignment info matches element type
         // CHECK: alignment = 4
         // CHECK: bitsize = 64
+        // CHECK: index = 0
         // CHECK: preferred = 8
         // CHECK: size = 8
         "test.data_layout_query"() : () -> !llvm.array<1 x i64>
diff --git a/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir
index 096e7ceb3cbcef..97286ce758069c 100644
--- a/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir
+++ b/mlir/test/Interfaces/DataLayoutInterfaces/module.mlir
@@ -2,11 +2,13 @@
 
 module attributes { dlti.dl_spec = #dlti.dl_spec<
       #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 12]>,
-      #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 32]>>} {
+      #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 32]>,
+      #dlti.dl_entry<!test.test_type_with_layout<30>, ["index", 7]>>} {
   // CHECK-LABEL: @module_level_layout
   func.func @module_level_layout() {
      // CHECK: alignment = 32
      // CHECK: bitsize = 12
+     // CHECK: index = 7
      // CHECK: preferred = 1
      // CHECK: size = 2
     "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
diff --git a/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir
index 9f9240ac6f8cea..d3bc91339d164b 100644
--- a/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir
+++ b/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir
@@ -4,24 +4,34 @@
 func.func @no_layout_builtin() {
   // CHECK: alignment = 4
   // CHECK: bitsize = 32
+  // CHECK: index = 0
   // CHECK: preferred = 4
   // CHECK: size = 4
   "test.data_layout_query"() : () -> i32
   // CHECK: alignment = 8
   // CHECK: bitsize = 64
+  // CHECK: index = 0
   // CHECK: preferred = 8
   // CHECK: size = 8
   "test.data_layout_query"() : () -> f64
   // CHECK: alignment = 4
   // CHECK: bitsize = 64
+  // CHECK: index = 0
   // CHECK: preferred = 4
   // CHECK: size = 8
   "test.data_layout_query"() : () -> complex<f32>
   // CHECK: alignment = 1
   // CHECK: bitsize = 14
+  // CHECK: index = 0
   // CHECK: preferred = 1
   // CHECK: size = 2
   "test.data_layout_query"() : () -> complex<i6>
+  // CHECK: alignment = 4
+  // CHECK: bitsize = 64
+  // CHECK: index = 64
+  // CHECK: preferred = 8
+  // CHECK: size = 8
+  "test.data_layout_query"() : () -> index
   return
 
 }
@@ -30,6 +40,7 @@ func.func @no_layout_builtin() {
 func.func @no_layout_custom() {
   // CHECK: alignment = 1
   // CHECK: bitsize = 1
+  // CHECK: index = 1
   // CHECK: preferred = 1
   // CHECK: size = 1
   "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
@@ -41,6 +52,7 @@ func.func @layout_op_no_layout() {
   "test.op_with_data_layout"() ({
     // CHECK: alignment = 1
     // CHECK: bitsize = 1
+    // CHECK: index = 1
     // CHECK: preferred = 1
     // CHECK: size = 1
     "test.data_layout_query"() : () -> !test.test_type_with_layout<1000>
@@ -54,13 +66,15 @@ func.func @layout_op() {
   "test.op_with_data_layout"() ({
     // CHECK: alignment = 20
     // CHECK: bitsize = 10
+    // CHECK: index = 30
     // CHECK: preferred = 1
     // CHECK: size = 2
     "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
     "test.maybe_terminator"() : () -> ()
   }) { dlti.dl_spec = #dlti.dl_spec<
       #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 10]>,
-      #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>
+      #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>,
+      #dlti.dl_entry<!test.test_type_with_layout<30>, ["index", 30]>
   >} : () -> ()
   return
 }
@@ -72,13 +86,15 @@ func.func @nested_inner_only() {
     "test.op_with_data_layout"() ({
       // CHECK: alignment = 20
       // CHECK: bitsize = 10
+      // CHECK: index = 30
       // CHECK: preferred = 1
       // CHECK: size = 2
       "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
       "test.maybe_terminator"() : () -> ()
     }) { dlti.dl_spec = #dlti.dl_spec<
         #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 10]>,
-        #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>
+        #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>,
+        #dlti.dl_entry<!test.test_type_with_layout<30>, ["index", 30]>
     >} : () -> ()
     "test.maybe_terminator"() : () -> ()
   }) : () -> ()
@@ -92,6 +108,7 @@ func.func @nested_outer_only() {
     "test.op_with_data_layout"() ({
       // CHECK: alignment = 20
       // CHECK: bitsize = 10
+      // CHECK: index = 30
       // CHECK: preferred = 1
       // CHECK: size = 2
       "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
@@ -100,7 +117,8 @@ func.func @nested_outer_only() {
     "test.maybe_terminator"() : () -> ()
   }) { dlti.dl_spec = #dlti.dl_spec<
       #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 10]>,
-      #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>
+      #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>,
+      #dlti.dl_entry<!test.test_type_with_layout<30>, ["index", 30]>
     >} : () -> ()
   return
 }
@@ -112,6 +130,7 @@ func.func @nested_middle_only() {
       "test.op_with_data_layout"() ({
         // CHECK: alignment = 20
         // CHECK: bitsize = 10
+        // CHECK: index = 30
         // CHECK: preferred = 1
         // CHECK: size = 2
         "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
@@ -120,7 +139,8 @@ func.func @nested_middle_only() {
     "test.maybe_terminator"() : () -> ()
     }) { dlti.dl_spec = #dlti.dl_spec<
         #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 10]>,
-        #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>
+        #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>,
+        #dlti.dl_entry<!test.test_type_with_layout<30>, ["index", 30]>
       >} : () -> ()
     "test.maybe_terminator"() : () -> ()
   }) : () -> ()
@@ -134,6 +154,7 @@ func.func @nested_combine_with_missing() {
       "test.op_with_data_layout"() ({
         // CHECK: alignment = 20
         // CHECK: bitsize = 10
+        // CHECK: index = 21
         // CHECK: preferred = 30
         // CHECK: size = 2
         "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
@@ -146,13 +167,15 @@ func.func @nested_combine_with_missing() {
       >} : () -> ()
     // CHECK: alignment = 1
     // CHECK: bitsize = 42
+    // CHECK: index = 21
     // CHECK: preferred = 30
     // CHECK: size = 6
     "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
     "test.maybe_terminator"() : () -> ()
   }) { dlti.dl_spec = #dlti.dl_spec<
       #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 42]>,
-      #dlti.dl_entry<!test.test_type_with_layout<30>, ["preferred", 30]>
+      #dlti.dl_entry<!test.test_type_with_layout<30>, ["preferred", 30]>,
+      #dlti.dl_entry<!test.test_type_with_layout<40>, ["index", 21]>
   >}: () -> ()
   return
 }
@@ -164,6 +187,7 @@ func.func @nested_combine_all() {
       "test.op_with_data_layout"() ({
         // CHECK: alignment = 20
         // CHECK: bitsize = 3
+        // CHECK: index = 40
         // CHECK: preferred = 30
         // CHECK: size = 1
         "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
@@ -174,16 +198,19 @@ func.func @nested_combine_all() {
         >} : () -> ()
       // CHECK: alignment = 20
       // CHECK: bitsize = 10
+      // CHECK: index = 40
       // CHECK: preferred = 30
       // CHECK: size = 2
       "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
       "test.maybe_terminator"() : () -> ()
     }) { dlti.dl_spec = #dlti.dl_spec<
         #dlti.dl_entry<!test.test_type_with_layout<10>, ["size", 10]>,
-        #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>
+        #dlti.dl_entry<!test.test_type_with_layout<20>, ["alignment", 20]>,
+        #dlti.dl_entry<!test.test_type_with_layout<40>, ["index", 40]>
       >} : () -> ()
     // CHECK: alignment = 1
     // CHECK: bitsize = 42
+    // CHECK: index = 1
     // CHECK: preferred = 30
     // CHECK: size = 6
     "test.data_layout_query"() : () -> !test.test_type_with_layout<10>
@@ -200,18 +227,22 @@ func.func @integers() {
   "test.op_with_data_layout"() ({
     // CHECK: alignment = 8
     // CHECK: bitsize = 32
+    // CHECK: index = 0
     // CHECK: preferred = 8
     "test.data_layout_query"() : () -> i32
     // CHECK: alignment = 16
     // CHECK: bitsize = 56
+    // CHECK: index = 0
     // CHECK: preferred = 16
     "test.data_layout_query"() : () -> i56
     // CHECK: alignment = 16
     // CHECK: bitsize = 64
+    // CHECK: index = 0
     // CHECK: preferred = 16
     "test.data_layout_query"() : () -> i64
     // CHECK: alignment = 16
     // CHECK: bitsize = 128
+    // CHECK: index = 0
     // CHECK: preferred = 16
     "test.data_layout_query"() : () -> i128
     "test.maybe_terminator"() : () -> ()
@@ -222,18 +253,22 @@ func.func @integers() {
   "test.op_with_data_layout"() ({
     // CHECK: alignment = 8
     // CHECK: bitsize = 32
+    // CHECK: index = 0
     // CHECK: preferred = 16
     "test.data_layout_query"() : () -> i32
     // CHECK: alignment = 16
     // CHECK: bitsize = 56
+    // CHECK: index = 0
     // CHECK: preferred = 32
     "test.data_layout_query"() : () -> i56
     // CHECK: alignment = 16
     // CHECK: bitsize = 64
+    // CHECK: index = 0
     // CHECK: preferred = 32
     "test.data_layout_query"() : () -> i64
     // CHECK: alignment = 16
     // CHECK: bitsize = 128
+    // CHECK: index = 0
     // CHECK: preferred = 32
     "test.data_layout_query"() : () -> i128
     "test.maybe_terminator"() : () -> ()
@@ -248,10 +283,12 @@ func.func @floats() {
   "test.op_with_data_layout"() ({
     // CHECK: alignment = 8
     // CHECK: bitsize = 32
+    // CHECK: index = 0
     // CHECK: preferred = 8
     "test.data_layout_query"() : () -> f32
     // CHECK: alignment = 16
     // CHECK: bitsize = 80
+    // CHECK: index = 0
     // CHECK: preferred = 16
     "test.data_layout_query"() : () -> f80
     "test.maybe_terminator"() : () -> ()
@@ -262,10 +299,12 @@ func.func @floats() {
   "test.op_with_data_layout"() ({
     // CHECK: alignment = 8
     // CHECK: bitsize = 32
+    // CHECK: index = 0
     // CHECK: preferred = 16
     "test.data_layout_query"() : () -> f32
     // CHECK: alignment = 16
     // CHECK: bitsize = 80
+    // CHECK: index = 0
     // CHECK: preferred = 32
     "test.data_layout_query"() : () -> f80
     "test.maybe_terminator"() : () -> ()
diff --git a/mlir/test/Interfaces/DataLayoutInterfaces/types.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/types.mlir
index 55bb1d2eac911c..82ae02cf92adff 100644
--- a/mlir/test/Interfaces/DataLayoutInterfaces/types.mlir
+++ b/mlir/test/Interfaces/DataLayoutInterfaces/types.mlir
@@ -40,6 +40,7 @@ module @index attributes { dlti.dl_spec = #dlti.dl_spec<
   #dlti.dl_entry<index, 32>>} {
   func.func @query() {
     // CHECK: bitsize = 32
+    // CHECK: index = 32
     "test.data_layout_query"() : () -> index
     return
   }
diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
index 740562e7783024..3da48ffa403ed6 100644
--- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
@@ -36,19 +36,21 @@ struct TestDataLayoutQuery
         return;
 
       const DataLayout &layout = layouts.getAbove(op);
-      unsigned size = layout.getTypeSize(op.getType());
-      unsigned bitsize = layout.getTypeSizeInBits(op.getType());
-      unsigned alignment = layout.getTypeABIAlignment(op.getType());
-      unsigned preferred = layout.getTypePreferredAlignment(op.getType());
+      llvm::TypeSize size = layout.getTypeSize(op.getType());
+      llvm::TypeSize bitsize = layout.getTypeSizeInBits(op.getType());
+      uint64_t alignment = layout.getTypeABIAlignment(op.getType());
+      uint64_t preferred = layout.getTypePreferredAlignment(op.getType());
+      uint64_t index = layout.getTypeIndexBitwidth(op.getType()).value_or(0);
       Attribute allocaMemorySpace = layout.getAllocaMemorySpace();
       Attribute programMemorySpace = layout.getProgramMemorySpace();
       Attribute globalMemorySpace = layout.getGlobalMemorySpace();
-      unsigned stackAlignment = layout.getStackAlignment();
+      uint64_t stackAlignment = layout.getStackAlignment();
       op->setAttrs(
           {builder.getNamedAttr("size", builder.getIndexAttr(size)),
            builder.getNamedAttr("bitsize", builder.getIndexAttr(bitsize)),
            builder.getNamedAttr("alignment", builder.getIndexAttr(alignment)),
            builder.getNamedAttr("preferred", builder.getIndexAttr(preferred)),
+           builder.getNamedAttr("index", builder.getIndexAttr(index)),
            builder.getNamedAttr("alloca_memory_space",
                                 allocaMemorySpace == Attribute()
                                     ? builder.getUI32IntegerAttr(0)
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 1957845c842f20..492642b711e09e 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -148,7 +148,8 @@ def TestType : Test_Type<"Test", [
 }
 
 def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
-  DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["areCompatible"]>
+  DeclareTypeInterfaceMethods<DataLayoutTypeInterface, ["getIndexBitwidth",
+                                                        "areCompatible"]>
 ]> {
   let mnemonic = "test_type_with_layout";
   let parameters = (ins "unsigned":$key);
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 2f4c9b689069b8..7a195eb25a3ba1 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -276,6 +276,12 @@ uint64_t TestTypeWithLayoutType::getPreferredAlignment(
   return extractKind(params, "preferred");
 }
 
+std::optional<uint64_t>
+TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout,
+                                         DataLayoutEntryListRef params) const {
+  return extractKind(params, "index");
+}
+
 bool TestTypeWithLayoutType::areCompatible(
     DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
   unsigned old = extractKind(oldLayout, "alignment");
@@ -297,7 +303,7 @@ TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
     (void)kind;
     assert(kind &&
            (kind.getValue() == "size" || kind.getValue() == "alignment" ||
-            kind.getValue() == "preferred") &&
+            kind.getValue() == "preferred" || kind.getValue() == "index") &&
            "unexpected kind");
     assert(llvm::isa<IntegerAttr>(array.getValue().back()));
   }
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 794e19710fadca..65818bd642d877 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -345,6 +345,8 @@ TEST(DataLayout, NullSpec) {
   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
+  EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
+  EXPECT_EQ(*layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u);
 
   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
@@ -373,6 +375,8 @@ TEST(DataLayout, EmptySpec) {
   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
+  EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
+  EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u);
 
   EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute());
   EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
@@ -385,6 +389,7 @@ TEST(DataLayout, SpecWithEntries) {
 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
   #dlti.dl_entry<i42, 5>,
   #dlti.dl_entry<i16, 6>,
+  #dlti.dl_entry<index, 42>,
   #dlti.dl_entry<"dltest.alloca_memory_space", 5 : i32>,
   #dlti.dl_entry<"dltest.program_memory_space", 3 : i32>,
   #dlti.dl_entry<"dltest.global_memory_space", 2 : i32>,
@@ -408,6 +413,8 @@ TEST(DataLayout, SpecWithEntries) {
   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u);
   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u);
   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u);
+  EXPECT_EQ(layout.getTypeIndexBitwidth(Float16Type::get(&ctx)), std::nullopt);
+  EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 42u);
 
   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u);
   EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u);



More information about the Mlir-commits mailing list