[Mlir-commits] [mlir] [mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr` type. (PR #86860)

Fabian Mora llvmlistbot at llvm.org
Sat Jun 8 06:23:18 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/86860

>From fb9267a5f7da74845b82343b7a076244d59b637c Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 19:26:30 +0000
Subject: [PATCH 1/2] [mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr`
 type.

This patch initializes the `ptr` dialect directories and some base files. It
also adds the `!ptr.ptr` type, together with the `DataLayoutTypeInterface`
interface. The implementation of the `DataLayoutTypeInterface` interface
clones the implementation from `LLVM::LLVMPointerType`.
---
 mlir/include/mlir/Dialect/CMakeLists.txt      |   1 +
 mlir/include/mlir/Dialect/Ptr/CMakeLists.txt  |   1 +
 .../mlir/Dialect/Ptr/IR/CMakeLists.txt        |   2 +
 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h |  20 ++
 .../include/mlir/Dialect/Ptr/IR/PtrDialect.td |  83 ++++++++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h     |  24 +++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    |  15 ++
 mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h   |  37 ++++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/lib/Dialect/CMakeLists.txt               |   1 +
 mlir/lib/Dialect/Ptr/CMakeLists.txt           |   1 +
 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt        |  14 ++
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        |  48 +++++
 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp          | 184 ++++++++++++++++++
 mlir/test/Dialect/Ptr/types.mlir              |  17 ++
 15 files changed, 450 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
 create mode 100644 mlir/lib/Dialect/Ptr/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
 create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
 create mode 100644 mlir/test/Dialect/Ptr/types.mlir

diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 4bd7f12fabf7b..f710235197334 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -29,6 +29,7 @@ add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
 add_subdirectory(Polynomial)
+add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
diff --git a/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..c6ffa892e4ecb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(PtrOps ptr)
+add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
new file mode 100644
index 0000000000000..92f877c20dbf0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
@@ -0,0 +1,20 @@
+//===- PtrDialect.h - Pointer dialect ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+#define MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
new file mode 100644
index 0000000000000..bffae6b1ad71b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -0,0 +1,85 @@
+//===- PtrDialect.td - Pointer dialect ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_DIALECT
+#define PTR_DIALECT
+
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect definition.
+//===----------------------------------------------------------------------===//
+
+def Ptr_Dialect : Dialect {
+  let name = "ptr";
+  let summary = "Pointer dialect";
+  let cppNamespace = "::mlir::ptr";
+  let useDefaultTypePrinterParser = 1;
+  let useDefaultAttributePrinterParser = 0;
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer type definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for Ptr dialect types; `typeMnemonic` is the assembly name.
+class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Ptr_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
+    MemRefElementTypeInterface,
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
+      "areCompatible", "getIndexBitwidth", "verifyEntries"]>
+  ]> {
+  let summary = "pointer type";
+  let description = [{
+    The `ptr` type is an opaque pointer type. This type typically represents
+    a reference to an object in memory. Pointers are optionally parameterized
+    by a memory space.
+
+    Syntax:
+
+    ```mlir
+    pointer ::= `ptr` (`<` memory-space `>`)?
+    memory-space ::= attribute-value
+    ```
+  }];
+  let parameters = (ins OptionalParameter<"Attribute">:$memorySpace);
+  let assemblyFormat = "(`<` $memorySpace^ `>`)?";
+  let builders = [
+    TypeBuilder<(ins CArg<"Attribute", "nullptr">:$addressSpace), [{
+      return $_get($_ctxt, addressSpace);
+    }]>,
+    TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{
+      return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32),
+                                            addressSpace));
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
+  let extraClassDeclaration = [{
+    /// Returns the default memory space.
+    Attribute getDefaultMemorySpace() const;
+
+    /// Returns the memory space as an unsigned number.
+    int64_t getAddressSpace() const;
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Base operation definition.
+//===----------------------------------------------------------------------===//
+
+class Pointer_Op<string mnemonic, list<Trait> traits = []> :
+        Op<Ptr_Dialect, mnemonic, traits>;
+
+#endif // PTR_DIALECT
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
new file mode 100644
index 0000000000000..ad8a2bbcbdd8d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -0,0 +1,24 @@
+//===- PtrOps.h - Pointer dialect ops ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTROPS_H
+#define MLIR_DIALECT_PTR_IR_PTROPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTROPS_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
new file mode 100644
index 0000000000000..690941337bdfb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -0,0 +1,15 @@
+//===- PtrOps.td - Pointer dialect ops ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_OPS
+#define PTR_OPS
+
+include "mlir/Dialect/Ptr/IR/PtrDialect.td"
+include "mlir/IR/OpAsmInterface.td"
+
+#endif // PTR_OPS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
new file mode 100644
index 0000000000000..9984aedcbf6ce
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -0,0 +1,37 @@
+//===- PtrTypes.h - Pointer types -------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Pointer dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
+#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
+
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+namespace mlir {
+namespace ptr {
+/// The positions of different values in the data layout entry for pointers.
+enum class PtrDLEntryPos { Size = 0, Abi = 1, Preferred = 2, Index = 3 };
+
+/// Returns the value that corresponds to named position `pos` from the
+/// data layout entry `attr` assuming it's a dense integer elements attribute.
+/// Returns `std::nullopt` if `pos` is not present in the entry.
+/// Currently only `PtrDLEntryPos::Index` is optional, and all other positions
+/// may be assumed to be present.
+std::optional<uint64_t> extractPointerSpecValue(Attribute attr,
+                                                PtrDLEntryPos pos);
+} // namespace ptr
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRTYPES_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index d9db21073e15c..549c26c72d8a1 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -63,6 +63,7 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
@@ -134,6 +135,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   pdl::PDLDialect,
                   pdl_interp::PDLInterpDialect,
                   polynomial::PolynomialDialect,
+                  ptr::PtrDialect,
                   quant::QuantizationDialect,
                   ROCDL::ROCDLDialect,
                   scf::SCFDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index a324ce7f9b19f..80b0ef068d96d 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -29,6 +29,7 @@ add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
 add_subdirectory(Polynomial)
+add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
diff --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..359b9f02a0626
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(
+  MLIRPtrDialect
+  PtrTypes.cpp
+  PtrDialect.cpp
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Ptr
+  DEPENDS
+  MLIRPtrOpsIncGen
+  LINK_LIBS
+  PUBLIC
+  MLIRIR
+  MLIRDataLayoutInterfaces
+  MLIRMemorySlotInterfaces
+)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
new file mode 100644
index 0000000000000..59c97b22f332c
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -0,0 +1,50 @@
+//===- PtrDialect.cpp - Pointer dialect -------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Pointer dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect
+//===----------------------------------------------------------------------===//
+
+/// Registers the operations and types generated from the dialect's ODS
+/// definitions (`PtrOps.cpp.inc` and `PtrOpsTypes.cpp.inc`).
+void PtrDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
+      >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer API.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
new file mode 100644
index 0000000000000..51d0a45051b85
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -0,0 +1,210 @@
+//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+constexpr const static unsigned kDefaultPointerSizeBits = 64;
+constexpr const static unsigned kBitsInByte = 8;
+constexpr const static unsigned kDefaultPointerAlignment = 8;
+
+/// Returns the part of the data layout entry that corresponds to `pos` for the
+/// given `type` by interpreting the list of entries `params`. For the pointer
+/// type in the default address space, returns the default value if the entries
+/// do not provide a custom one, for other address spaces returns std::nullopt.
+static std::optional<uint64_t>
+getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type,
+                          PtrDLEntryPos pos) {
+  // First, look for the entry for the pointer in the current address space.
+  Attribute currentEntry;
+  for (DataLayoutEntryInterface entry : params) {
+    if (!entry.isTypeEntry())
+      continue;
+    if (cast<PtrType>(entry.getKey().get<Type>()).getAddressSpace() ==
+        type.getAddressSpace()) {
+      currentEntry = entry.getValue();
+      break;
+    }
+  }
+  if (currentEntry) {
+    std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
+    // If the optional `PtrDLEntryPos::Index` entry is not available, use the
+    // pointer size as the index bitwidth.
+    if (!value && pos == PtrDLEntryPos::Index)
+      value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
+    bool isSizeOrIndex =
+        pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
+    return *value / (isSizeOrIndex ? 1 : kBitsInByte);
+  }
+
+  // If not found, and this is the pointer to the default memory space, assume
+  // 64-bit pointers.
+  if (type.getAddressSpace() == 0) {
+    bool isSizeOrIndex =
+        pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
+    return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
+  }
+
+  return std::nullopt;
+}
+
+int64_t PtrType::getAddressSpace() const { return 0; }
+
+Attribute PtrType::getDefaultMemorySpace() const { return nullptr; }
+
+/// Returns true if the pointer entries in `newLayout` are compatible with
+/// `oldLayout`: each new entry must have the same size as the corresponding
+/// old entry (or the 64-bit default when none is found), and the old ABI
+/// alignment must be a multiple of the new, non-zero, ABI alignment. All
+/// values are compared in bits.
+bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
+                            DataLayoutEntryListRef newLayout) const {
+  for (DataLayoutEntryInterface newEntry : newLayout) {
+    if (!newEntry.isTypeEntry())
+      continue;
+    // Layout entries store bit values, but `kDefaultPointerAlignment` is in
+    // bytes: convert it so every comparison below is made in bits.
+    uint64_t size = kDefaultPointerSizeBits;
+    uint64_t abi = kDefaultPointerAlignment * kBitsInByte;
+    auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>());
+    const auto *it =
+        llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+          if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+            return llvm::cast<PtrType>(type).getMemorySpace() ==
+                   newType.getMemorySpace();
+          }
+          return false;
+        });
+    if (it == oldLayout.end()) {
+      llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+        if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+          return llvm::cast<PtrType>(type).getAddressSpace() == 0;
+        }
+        return false;
+      });
+    }
+    if (it != oldLayout.end()) {
+      // `*it` is the whole entry; the dense spec is its value attribute.
+      Attribute oldSpec = it->getValue();
+      size = *extractPointerSpecValue(oldSpec, PtrDLEntryPos::Size);
+      abi = *extractPointerSpecValue(oldSpec, PtrDLEntryPos::Abi);
+    }
+
+    Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
+    uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
+    uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
+    // Reject a zero ABI alignment up front: it is not a meaningful alignment
+    // and would make the modulo below divide by zero.
+    if (size != newSize || newAbi == 0 || abi < newAbi || abi % newAbi != 0)
+      return false;
+  }
+  return true;
+}
+
+/// Returns the ABI alignment in bytes as resolved from the layout entries;
+/// if no value is found, defers to the pointer in the default memory space.
+uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout,
+                                  DataLayoutEntryListRef params) const {
+  if (std::optional<uint64_t> alignment =
+          getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi))
+    return *alignment;
+
+  return dataLayout.getTypeABIAlignment(
+      get(getContext(), getDefaultMemorySpace()));
+}
+
+/// Returns the index bitwidth as resolved from the layout entries (using the
+/// pointer size when an entry has no index value); if no value is found,
+/// defers to the pointer in the default memory space.
+std::optional<uint64_t>
+PtrType::getIndexBitwidth(const DataLayout &dataLayout,
+                          DataLayoutEntryListRef params) const {
+  if (std::optional<uint64_t> indexBitwidth =
+          getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index))
+    return *indexBitwidth;
+
+  return dataLayout.getTypeIndexBitwidth(
+      get(getContext(), getDefaultMemorySpace()));
+}
+
+/// Returns the pointer size in bits as resolved from the layout entries; if
+/// no value is found, defers to the pointer in the default memory space.
+llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout,
+                                          DataLayoutEntryListRef params) const {
+  if (std::optional<uint64_t> size =
+          getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size))
+    return llvm::TypeSize::getFixed(*size);
+
+  // For other memory spaces, use the size of the pointer to the default memory
+  // space.
+  return dataLayout.getTypeSizeInBits(
+      get(getContext(), getDefaultMemorySpace()));
+}
+
+/// Returns the preferred alignment in bytes as resolved from the layout
+/// entries; if no value is found, defers to the pointer in the default memory
+/// space.
+uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout,
+                                        DataLayoutEntryListRef params) const {
+  if (std::optional<uint64_t> alignment =
+          getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred))
+    return *alignment;
+
+  return dataLayout.getTypePreferredAlignment(
+      get(getContext(), getDefaultMemorySpace()));
+}
+
+// `attr` must be a `DenseIntElementsAttr`; the `cast` below asserts otherwise.
+std::optional<uint64_t> mlir::ptr::extractPointerSpecValue(Attribute attr,
+                                                           PtrDLEntryPos pos) {
+  auto spec = cast<DenseIntElementsAttr>(attr);
+  auto idx = static_cast<int64_t>(pos);
+  if (idx >= spec.size())
+    return std::nullopt;
+  return spec.getValues<uint64_t>()[idx];
+}
+
+/// Verifies that each pointer layout entry is a dense i64 elements attribute
+/// holding 3 or 4 values (size, ABI alignment, preferred alignment and the
+/// optional index bitwidth), and that the preferred alignment is not smaller
+/// than the ABI alignment.
+LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
+                                     Location loc) const {
+  for (DataLayoutEntryInterface entry : entries) {
+    if (!entry.isTypeEntry())
+      continue;
+    auto key = entry.getKey().get<Type>();
+    auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
+    if (!values || (values.size() != 3 && values.size() != 4)) {
+      return emitError(loc)
+             << "expected layout attribute for " << key
+             << " to be a dense integer elements attribute with 3 or 4 "
+                "elements";
+    }
+    if (!values.getElementType().isInteger(64))
+      return emitError(loc) << "expected i64 parameters for " << key;
+
+    if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) >
+        extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) {
+      return emitError(loc) << "preferred alignment is expected to be at least "
+                               "as large as ABI alignment";
+    }
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir
new file mode 100644
index 0000000000000..279213bd6fc3e
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/types.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @ptr_test
+// CHECK: (%[[ARG0:.*]]: !ptr.ptr, %[[ARG1:.*]]: !ptr.ptr<1 : i32>)
+// CHECK: -> (!ptr.ptr<1 : i32>, !ptr.ptr)
+func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 : i32>, !ptr.ptr) {
+  // CHECK: return %[[ARG1]], %[[ARG0]] : !ptr.ptr<1 : i32>, !ptr.ptr
+  return %arg1, %arg0 : !ptr.ptr<1 : i32>, !ptr.ptr
+}
+
+// -----
+
+// CHECK-LABEL: func @ptr_test
+// CHECK: %[[ARG:.*]]: memref<!ptr.ptr>
+func.func @ptr_test(%arg0: memref<!ptr.ptr>) {
+  return
+}

>From c57fb25b48c2e13e27e1d14106084ff4fc345fdf Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 2 Apr 2024 15:12:48 +0000
Subject: [PATCH 2/2] add layout test, address reviewer comments and fix a bug
 in layout methods

---
 .../include/mlir/Dialect/Ptr/IR/PtrDialect.td |  6 +-
 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp          | 20 ++---
 mlir/test/Dialect/Ptr/layout.mlir             | 87 +++++++++++++++++++
 3 files changed, 99 insertions(+), 14 deletions(-)
 create mode 100644 mlir/test/Dialect/Ptr/layout.mlir

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index bffae6b1ad71b..315ecdfb6609e 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -55,8 +55,8 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
   let parameters = (ins OptionalParameter<"Attribute">:$memorySpace);
   let assemblyFormat = "(`<` $memorySpace^ `>`)?";
   let builders = [
-    TypeBuilder<(ins CArg<"Attribute", "nullptr">:$addressSpace), [{
-      return $_get($_ctxt, addressSpace);
+    TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{
+      return $_get($_ctxt, memorySpace);
     }]>,
     TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{
       return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32),
@@ -69,7 +69,7 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     Attribute getDefaultMemorySpace() const;
 
     /// Returns the memory space as an unsigned number.
-    int64_t getAddressSpace() const;
+    uint64_t getAddressSpace() const;
   }];
 }
 
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index 51d0a45051b85..95c72f5743afc 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -36,35 +36,32 @@ getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type,
   for (DataLayoutEntryInterface entry : params) {
     if (!entry.isTypeEntry())
       continue;
-    if (cast<PtrType>(entry.getKey().get<Type>()).getAddressSpace() ==
-        type.getAddressSpace()) {
+    if (cast<PtrType>(entry.getKey().get<Type>()).getMemorySpace() ==
+        type.getMemorySpace()) {
       currentEntry = entry.getValue();
       break;
     }
   }
+  bool isSizeOrIndex =
+      pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
   if (currentEntry) {
     std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
     // If the optional `PtrDLEntryPos::Index` entry is not available, use the
     // pointer size as the index bitwidth.
     if (!value && pos == PtrDLEntryPos::Index)
       value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
-    bool isSizeOrIndex =
-        pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
     return *value / (isSizeOrIndex ? 1 : kBitsInByte);
   }
 
   // If not found, and this is the pointer to the default memory space, assume
   // 64-bit pointers.
-  if (type.getAddressSpace() == 0) {
-    bool isSizeOrIndex =
-        pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
+  if (type.getMemorySpace() == type.getDefaultMemorySpace())
     return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
-  }
 
   return std::nullopt;
 }
 
-int64_t PtrType::getAddressSpace() const { return 0; }
+uint64_t PtrType::getAddressSpace() const { return 0; }
 
 Attribute PtrType::getDefaultMemorySpace() const { return nullptr; }
 
@@ -85,9 +82,10 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
           return false;
         });
     if (it == oldLayout.end()) {
-      llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+      it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
         if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
-          return llvm::cast<PtrType>(type).getAddressSpace() == 0;
+          auto ptrTy = llvm::cast<PtrType>(type);
+          return ptrTy.getMemorySpace() == ptrTy.getDefaultMemorySpace();
         }
         return false;
       });
diff --git a/mlir/test/Dialect/Ptr/layout.mlir b/mlir/test/Dialect/Ptr/layout.mlir
new file mode 100644
index 0000000000000..b345fbd6f6fbb
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/layout.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt --test-data-layout-query --split-input-file --verify-diagnostics %s | FileCheck %s
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, dense<[32, 32, 64]> : vector<3xi64>>,
+  #dlti.dl_entry<!ptr.ptr<5>, dense<[64, 64, 64]> : vector<3xi64>>,
+  #dlti.dl_entry<!ptr.ptr<4>, dense<[32, 64, 64, 24]> : vector<4xi64>>,
+  #dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui64>,
+  #dlti.dl_entry<"dlti.global_memory_space", 2 : ui64>,
+  #dlti.dl_entry<"dlti.program_memory_space", 3 : ui64>,
+  #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>
+>} {
+  // CHECK: @spec
+  func.func @spec() {
+    // CHECK: alignment = 4
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 32
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 32
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 4
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr
+    // CHECK: alignment = 4
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 32
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 32
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 4
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr<3>
+    // CHECK: alignment = 8
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 64
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 64
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 8
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr<5>
+    // CHECK: alignment = 8
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 32
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 24
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 4
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr<4>
+    return
+  }
+}
+
+// -----
+
+// expected-error@below {{expected layout attribute for '!ptr.ptr' to be a dense integer elements attribute with 3 or 4 elements}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, dense<[64.0, 64.0, 64.0]> : vector<3xf32>>
+>} {
+  func.func @pointer() {
+    return
+  }
+}
+
+// -----
+
+// expected-error@below {{preferred alignment is expected to be at least as large as ABI alignment}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, dense<[64, 64, 32]> : vector<3xi64>>
+>} {
+  func.func @pointer() {
+    return
+  }
+}
+
+// -----
+
+// expected-error @below {{expected i64 parameters for '!ptr.ptr'}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, dense<[32, 32, 64]> : vector<3xi32>>
+>} {
+}
+



More information about the Mlir-commits mailing list