[Mlir-commits] [mlir] [mlir][Ptr] Add the `MemorySpaceAttrInterface` interface and dependencies. (PR #86870)

Fabian Mora llvmlistbot at llvm.org
Tue Apr 2 08:41:04 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/86870

>From af1ff4d4d302d0722e06ab0068eae917a30bb09d Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 19:26:30 +0000
Subject: [PATCH 1/3] [mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr`
 type.

This patch initializes the `ptr` dialect directories and some base files. It
also adds the `!ptr.ptr` type, together with the `DataLayoutTypeInterface`
interface. The implementation of the `DataLayoutTypeInterface` interface
clones the implementation from `LLVM::LLVMPointerType`.
---
 mlir/include/mlir/Dialect/CMakeLists.txt      |   1 +
 mlir/include/mlir/Dialect/Ptr/CMakeLists.txt  |   1 +
 .../mlir/Dialect/Ptr/IR/CMakeLists.txt        |   2 +
 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h |  20 ++
 .../include/mlir/Dialect/Ptr/IR/PtrDialect.td |  83 ++++++++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h     |  24 +++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    |  15 ++
 mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h   |  37 ++++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/lib/Dialect/CMakeLists.txt               |   1 +
 mlir/lib/Dialect/Ptr/CMakeLists.txt           |   1 +
 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt        |  14 ++
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        |  48 +++++
 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp          | 184 ++++++++++++++++++
 mlir/test/Dialect/Ptr/types.mlir              |  17 ++
 15 files changed, 450 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
 create mode 100644 mlir/lib/Dialect/Ptr/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
 create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
 create mode 100644 mlir/test/Dialect/Ptr/types.mlir

diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 2da79011fa26a3..5f0e9806926145 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
+add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
diff --git a/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..c6ffa892e4ecba
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(PtrOps ptr)
+add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
new file mode 100644
index 00000000000000..92f877c20dbf07
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
@@ -0,0 +1,20 @@
+//===- PtrDialect.h - Pointer dialect ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+#define MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
new file mode 100644
index 00000000000000..bffae6b1ad71bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -0,0 +1,83 @@
+//===- PtrDialect.td - Pointer dialect ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_DIALECT
+#define PTR_DIALECT
+
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect definition.
+//===----------------------------------------------------------------------===//
+
+def Ptr_Dialect : Dialect {
+  let name = "ptr";
+  let summary = "Pointer dialect";
+  let cppNamespace = "::mlir::ptr";
+  let useDefaultTypePrinterParser = 1;
+  let useDefaultAttributePrinterParser = 0;
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer type definitions
+//===----------------------------------------------------------------------===//
+
+class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Ptr_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
+    MemRefElementTypeInterface,
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
+      "areCompatible", "getIndexBitwidth", "verifyEntries"]>
+  ]> {
+  let summary = "pointer type";
+  let description = [{
+    The `ptr` type is an opaque pointer type. This type typically represents
+    a reference to an object in memory. Pointers are optionally parameterized
+    by a memory space.
+    Syntax:
+
+    ```mlir
+    pointer ::= `ptr` (`<` memory-space `>`)?
+    memory-space ::= attribute-value
+    ```
+  }];
+  let parameters = (ins OptionalParameter<"Attribute">:$memorySpace);
+  let assemblyFormat = "(`<` $memorySpace^ `>`)?";
+  let builders = [
+    TypeBuilder<(ins CArg<"Attribute", "nullptr">:$addressSpace), [{
+      return $_get($_ctxt, addressSpace);
+    }]>,
+    TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{
+      return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32),
+                                            addressSpace));
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
+  let extraClassDeclaration = [{
+    /// Returns the default memory space.
+    Attribute getDefaultMemorySpace() const;
+
+    /// Returns the address space as an integer.
+    int64_t getAddressSpace() const;
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Base address operation definition.
+//===----------------------------------------------------------------------===//
+
+class Pointer_Op<string mnemonic, list<Trait> traits = []> :
+        Op<Ptr_Dialect, mnemonic, traits>;
+
+#endif // PTR_DIALECT
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
new file mode 100644
index 00000000000000..ad8a2bbcbdd8d2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -0,0 +1,24 @@
+//===- PtrOps.h - Pointer dialect ops ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTROPS_H
+#define MLIR_DIALECT_PTR_IR_PTROPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTROPS_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
new file mode 100644
index 00000000000000..690941337bdfb5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -0,0 +1,15 @@
+//===- PtrOps.td - Pointer dialect ops ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_OPS
+#define PTR_OPS
+
+include "mlir/Dialect/Ptr/IR/PtrDialect.td"
+include "mlir/IR/OpAsmInterface.td"
+
+#endif // PTR_OPS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
new file mode 100644
index 00000000000000..9984aedcbf6ce8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -0,0 +1,37 @@
+//===- PtrTypes.h - Pointer types -------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Pointer dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
+#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
+
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+namespace mlir {
+namespace ptr {
+/// The positions of different values in the data layout entry for pointers.
+enum class PtrDLEntryPos { Size = 0, Abi = 1, Preferred = 2, Index = 3 };
+
+/// Returns the value that corresponds to named position `pos` from the
+/// data layout entry `attr` assuming it's a dense integer elements attribute.
+/// Returns `std::nullopt` if `pos` is not present in the entry.
+/// Currently only `PtrDLEntryPos::Index` is optional, and all other positions
+/// may be assumed to be present.
+std::optional<uint64_t> extractPointerSpecValue(Attribute attr,
+                                                PtrDLEntryPos pos);
+} // namespace ptr
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRTYPES_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c558dc53cc7fac..b0cbb720519edd 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -61,6 +61,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
@@ -131,6 +132,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   omp::OpenMPDialect,
                   pdl::PDLDialect,
                   pdl_interp::PDLInterpDialect,
+                  ptr::PtrDialect,
                   quant::QuantizationDialect,
                   ROCDL::ROCDLDialect,
                   scf::SCFDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b1ba5a3bc8817d..6fa8a610e196c4 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
+add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
diff --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..359b9f02a06266
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(
+  MLIRPtrDialect
+  PtrTypes.cpp
+  PtrDialect.cpp
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/mlir/Dialect/Pointer
+  DEPENDS
+  MLIRPtrOpsIncGen
+  LINK_LIBS
+  PUBLIC
+  MLIRIR
+  MLIRDataLayoutInterfaces
+  MLIRMemorySlotInterfaces
+)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
new file mode 100644
index 00000000000000..59c97b22f332c4
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -0,0 +1,48 @@
+//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Pointer dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect
+//===----------------------------------------------------------------------===//
+
+void PtrDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
+      >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer API.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
new file mode 100644
index 00000000000000..51d0a45051b85e
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -0,0 +1,184 @@
+//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+constexpr const static unsigned kDefaultPointerSizeBits = 64;
+constexpr const static unsigned kBitsInByte = 8;
+constexpr const static unsigned kDefaultPointerAlignment = 8;
+
+/// Looks up component `pos` of the data layout spec that `params` provides for
+/// the address space of `type`. Size and index are reported in bits and the
+/// alignments in bytes. Without a matching entry, the default address space
+/// gets the built-in defaults and every other space yields std::nullopt.
+static std::optional<uint64_t>
+getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type,
+                          PtrDLEntryPos pos) {
+  // Size and index are kept in bits; the alignments are converted to bytes.
+  const bool inBits =
+      pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
+
+  // Pick the spec of the first type entry in the same address space.
+  Attribute spec;
+  for (DataLayoutEntryInterface candidate : params) {
+    if (!candidate.isTypeEntry())
+      continue;
+    auto keyType = cast<PtrType>(candidate.getKey().get<Type>());
+    if (keyType.getAddressSpace() != type.getAddressSpace())
+      continue;
+    spec = candidate.getValue();
+    break;
+  }
+
+  if (spec) {
+    std::optional<uint64_t> raw = extractPointerSpecValue(spec, pos);
+    // The index bitwidth is optional in the spec; it defaults to the size.
+    if (pos == PtrDLEntryPos::Index && !raw)
+      raw = extractPointerSpecValue(spec, PtrDLEntryPos::Size);
+    return inBits ? *raw : *raw / kBitsInByte;
+  }
+
+  // No entry was found: only the pointer in address space 0 has built-in
+  // defaults (64-bit pointers).
+  if (type.getAddressSpace() == 0) {
+    return inBits ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
+  }
+
+  return std::nullopt;
+}
+
+int64_t PtrType::getAddressSpace() const { return 0; }
+
+Attribute PtrType::getDefaultMemorySpace() const { return nullptr; }
+
+bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
+                            DataLayoutEntryListRef newLayout) const {
+  // A new pointer entry is compatible when it keeps the old pointer size and
+  // its ABI alignment evenly divides the old one. The old spec is the entry for
+  // the same memory space, else the address space 0 entry, else the defaults.
+  for (DataLayoutEntryInterface newEntry : newLayout) {
+    if (!newEntry.isTypeEntry())
+      continue;
+    unsigned size = kDefaultPointerSizeBits;
+    unsigned abi = kDefaultPointerAlignment;
+    auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>());
+    const auto *it =
+        llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+          if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+            return llvm::cast<PtrType>(type).getMemorySpace() ==
+                   newType.getMemorySpace();
+          return false;
+        });
+    if (it == oldLayout.end()) {
+      it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+        if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+          return llvm::cast<PtrType>(type).getAddressSpace() == 0;
+        return false;
+      });
+    }
+    if (it != oldLayout.end()) {
+      size = *extractPointerSpecValue(it->getValue(), PtrDLEntryPos::Size);
+      abi = *extractPointerSpecValue(it->getValue(), PtrDLEntryPos::Abi);
+    }
+    Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
+    unsigned newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
+    unsigned newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
+    if (size != newSize || abi < newAbi || abi % newAbi != 0)
+      return false;
+  }
+  return true;
+}
+
+uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout,
+                                  DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> entry =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi);
+  // Without an entry, defer to the pointer in the default memory space.
+  return entry ? *entry
+               : dataLayout.getTypeABIAlignment(
+                     get(getContext(), getDefaultMemorySpace()));
+}
+
+std::optional<uint64_t>
+PtrType::getIndexBitwidth(const DataLayout &dataLayout,
+                          DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> entry =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index);
+  // Without an entry, defer to the pointer in the default memory space.
+  return entry ? entry
+               : dataLayout.getTypeIndexBitwidth(
+                     get(getContext(), getDefaultMemorySpace()));
+}
+
+llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout,
+                                          DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> bits =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size);
+  if (bits)
+    return llvm::TypeSize::getFixed(*bits);
+  // No size is known for this memory space: report the size of the pointer in
+  // the default memory space instead.
+  return dataLayout.getTypeSizeInBits(
+      get(getContext(), getDefaultMemorySpace()));
+}
+
+uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout,
+                                        DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> entry =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred);
+  // Without an entry, defer to the pointer in the default memory space.
+  return entry ? *entry
+               : dataLayout.getTypePreferredAlignment(
+                     get(getContext(), getDefaultMemorySpace()));
+}
+
+std::optional<uint64_t> mlir::ptr::extractPointerSpecValue(Attribute attr,
+                                                           PtrDLEntryPos pos) {
+  auto values = cast<DenseIntElementsAttr>(attr);
+  // Trailing positions (e.g. the optional index) may be missing from the spec.
+  if (auto index = static_cast<int64_t>(pos); index < values.size())
+    return values.getValues<uint64_t>()[index];
+  return std::nullopt;
+}
+
+LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
+                                     Location loc) const {
+  for (DataLayoutEntryInterface entry : entries) {
+    if (!entry.isTypeEntry())
+      continue;
+    Type pointerKey = entry.getKey().get<Type>();
+    auto spec = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
+    if (!(spec && (spec.size() == 3 || spec.size() == 4))) {
+      return emitError(loc)
+             << "expected layout attribute for " << pointerKey
+             << " to be a dense integer elements attribute with 3 or 4 "
+                "elements";
+    }
+    if (!spec.getElementType().isInteger(64))
+      return emitError(loc) << "expected i64 parameters for " << pointerKey;
+    // The preferred alignment may not be smaller than the ABI alignment.
+    if (extractPointerSpecValue(spec, PtrDLEntryPos::Preferred) <
+        extractPointerSpecValue(spec, PtrDLEntryPos::Abi)) {
+      return emitError(loc) << "preferred alignment is expected to be at least "
+                               "as large as ABI alignment";
+    }
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir
new file mode 100644
index 00000000000000..279213bd6fc3e5
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/types.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @ptr_test
+// CHECK: (%[[ARG0:.*]]: !ptr.ptr, %[[ARG1:.*]]: !ptr.ptr<1 : i32>)
+// CHECK: -> (!ptr.ptr<1 : i32>, !ptr.ptr)
+func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 : i32>, !ptr.ptr) {
+  // CHECK: return %[[ARG1]], %[[ARG0]] : !ptr.ptr<1 : i32>, !ptr.ptr
+  return %arg1, %arg0 : !ptr.ptr<1 : i32>, !ptr.ptr
+}
+
+// -----
+
+// CHECK-LABEL: func @ptr_test
+// CHECK: %[[ARG:.*]]: memref<!ptr.ptr>
+func.func @ptr_test(%arg0: memref<!ptr.ptr>) {
+  return
+}

>From 827ba48faec31ded77b64b488533624345cba831 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 20:31:55 +0000
Subject: [PATCH 2/3] [mlir][Ptr] Add the `MemorySpaceAttrInterface` interface
 and dependencies.

This patch introduces the `MemorySpaceAttrInterface` interface. This
interface is responsible for handling the semantics of `ptr` operations.
For example, this interface can be used to create read-only memory spaces,
making any operation other than a load a verification error; see
`TestConstMemorySpaceAttr` for a possible implementation of this concept.

This patch also introduces the enum dependencies `AtomicOrdering` and `AtomicBinOp`;
both enumerations are clones of the enums with the same names in the LLVM dialect.
---
 .../mlir/Dialect/Ptr/IR/CMakeLists.txt        |  12 ++
 .../include/mlir/Dialect/Ptr/IR/MemorySpace.h | 161 ++++++++++++++++
 .../Dialect/Ptr/IR/MemorySpaceInterfaces.td   | 182 ++++++++++++++++++
 mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h   |  20 ++
 mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td  |  69 +++++++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h     |   2 +
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    |   1 +
 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt        |   2 +
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        |  69 +++++++
 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp          |   9 +-
 mlir/test/Dialect/Ptr/types.mlir              |   7 +
 mlir/test/lib/Dialect/Test/CMakeLists.txt     |   1 +
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  10 +
 mlir/test/lib/Dialect/Test/TestAttributes.cpp |  48 +++++
 mlir/test/lib/Dialect/Test/TestAttributes.h   |   1 +
 15 files changed, 592 insertions(+), 2 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
 create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
index c6ffa892e4ecba..80dffdfd402cf2 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
@@ -1,2 +1,14 @@
 add_mlir_dialect(PtrOps ptr)
 add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td)
+mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS PtrOps.td)
+mlir_tablegen(PtrOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(PtrOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRPtrOpsEnumsGen)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
new file mode 100644
index 00000000000000..e467d121f2c886
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
@@ -0,0 +1,161 @@
+//===-- MemorySpace.h - ptr dialect memory space  ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ptr dialect's memory space class and related
+// interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_MEMORYSPACE_H
+#define MLIR_DIALECT_PTR_IR_MEMORYSPACE_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class Operation;
+namespace ptr {
+/// This method checks if it's valid to perform an `addrspacecast` op in the
+/// memory space.
+/// Compatible types are:
+/// Vectors of rank 1, or scalars of `ptr` type.
+LogicalResult isValidAddrSpaceCastImpl(Type tgt, Type src,
+                                       Operation *diagnosticOp);
+
+/// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
+/// the memory space.
+/// Compatible types are:
+/// IntLikeTy: Vectors of rank 1, or scalars of integer types or `index` type.
+/// PtrLikeTy: Vectors of rank 1, or scalars of `ptr` type.
+LogicalResult isValidPtrIntCastImpl(Type intLikeTy, Type ptrLikeTy,
+                                    Operation *diagnosticOp);
+
+enum class AtomicBinOp : uint64_t;
+enum class AtomicOrdering : uint64_t;
+} // namespace ptr
+} // namespace mlir
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc"
+
+namespace mlir {
+namespace ptr {
+/// This class wraps the `MemorySpaceAttrInterface` interface, providing a safe
+/// mechanism to specify the default behavior assumed by the ptr dialect.
+class MemorySpace {
+public:
+  MemorySpace() = default;
+  MemorySpace(std::nullptr_t) {}
+  MemorySpace(MemorySpaceAttrInterface memorySpace)
+      : memorySpaceAttr(memorySpace), memorySpace(memorySpace) {}
+  MemorySpace(Attribute memorySpace)
+      : memorySpaceAttr(memorySpace),
+        memorySpace(dyn_cast_or_null<MemorySpaceAttrInterface>(memorySpace)) {}
+
+  operator Attribute() const { return memorySpaceAttr; }
+  operator MemorySpaceAttrInterface() const { return memorySpace; }
+  bool operator==(const MemorySpace &memSpace) const {
+    return memSpace.memorySpaceAttr == memorySpaceAttr;
+  }
+
+  /// Returns the underlying memory space.
+  Attribute getUnderlyingSpace() const { return memorySpaceAttr; }
+
+  /// Returns true if no `MemorySpaceAttrInterface` backs this memory space.
+  bool isDefaultModel() const { return memorySpace == nullptr; }
+
+  /// Returns the interface's address space, else an `IntegerAttr` value, else 0.
+  unsigned getAddressSpace() const {
+    if (memorySpace)
+      return memorySpace.getAddressSpace();
+    if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(memorySpaceAttr))
+      return intAttr.getInt();
+    return 0;
+  }
+
+  /// Returns the default memory space as an attribute, or nullptr if using the
+  /// default model.
+  Attribute getDefaultMemorySpace() const {
+    return memorySpace ? memorySpace.getDefaultMemorySpace() : nullptr;
+  }
+
+  /// This method checks if it's valid to load a value from the memory space
+  /// with a specific type, alignment, and atomic ordering. The default model
+  /// assumes all values are loadable.
+  LogicalResult isValidLoad(Type type, AtomicOrdering ordering,
+                            IntegerAttr alignment,
+                            Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidLoad(type, ordering, alignment,
+                                                 diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to store a value in the memory space with
+  /// a specific type, alignment, and atomic ordering. The default model assumes
+  /// all values are storable.
+  LogicalResult isValidStore(Type type, AtomicOrdering ordering,
+                             IntegerAttr alignment,
+                             Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidStore(type, ordering, alignment,
+                                                  diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to perform an atomic operation in the
+  /// memory space with a specific type, alignment, and atomic ordering.
+  LogicalResult isValidAtomicOp(AtomicBinOp op, Type type,
+                                AtomicOrdering ordering, IntegerAttr alignment,
+                                Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidAtomicOp(op, type, ordering,
+                                                     alignment, diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to perform an atomic exchange in the
+  /// memory space with a specific type, alignment, and atomic orderings.
+  LogicalResult isValidAtomicXchg(Type type, AtomicOrdering successOrdering,
+                                  AtomicOrdering failureOrdering,
+                                  IntegerAttr alignment,
+                                  Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidAtomicXchg(type, successOrdering,
+                                                       failureOrdering,
+                                                       alignment, diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to perform an `addrspacecast` op in the
+  /// memory space.
+  LogicalResult isValidAddrSpaceCast(Type tgt, Type src,
+                                     Operation *diagnosticOp = nullptr) const {
+    return memorySpace
+               ? memorySpace.isValidAddrSpaceCast(tgt, src, diagnosticOp)
+               : isValidAddrSpaceCastImpl(tgt, src, diagnosticOp);
+  }
+
+  /// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op
+  /// in the memory space.
+  LogicalResult isValidPtrIntCast(Type intLikeTy, Type ptrLikeTy,
+                                  Operation *diagnosticOp = nullptr) const {
+    return memorySpace
+               ? memorySpace.isValidPtrIntCast(intLikeTy, ptrLikeTy,
+                                               diagnosticOp)
+               : isValidPtrIntCastImpl(intLikeTy, ptrLikeTy, diagnosticOp);
+  }
+
+protected:
+  /// Underlying memory space attribute as given at construction; may be null.
+  Attribute memorySpaceAttr{};
+  /// Interface view of `memorySpaceAttr`; null if the cast to it failed.
+  MemorySpaceAttrInterface memorySpace{};
+};
+} // namespace ptr
+} // namespace mlir
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_MEMORYSPACE_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
new file mode 100644
index 00000000000000..b7bed95434839e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
@@ -0,0 +1,182 @@
+//===-- MemorySpaceInterfaces.td - Memory space interfaces ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines memory space attribute interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_MEMORYSPACEINTERFACES
+#define PTR_MEMORYSPACEINTERFACES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Memory space attribute interface.
+//===----------------------------------------------------------------------===//
+
+def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
+  let description = [{
+    This interface defines a common API for interacting with the memory model of
+    a memory space and the operations in the pointer dialect, giving proper
+    semantical meaning to the ops.
+
+    Furthermore, this interface allows concepts such as read-only memory to be
+    adequately modeled and enforced.
+  }];
+  let cppNamespace = "::mlir::ptr";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        [{
+        Returns the dialect implementing the memory space.
+      }],
+      /*returnType=*/  "::mlir::Dialect*",
+      /*methodName=*/  "getMemorySpaceDialect",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{ return nullptr; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        Returns the default memory space as an attribute.
+      }],
+      /*returnType=*/  "::mlir::Attribute",
+      /*methodName=*/  "getDefaultMemorySpace",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        Returns the memory space as an integer, or 0 if using the default model.
+      }],
+      /*returnType=*/  "unsigned",
+      /*methodName=*/  "getAddressSpace",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{ return 0; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        This method checks if it's valid to load a value from the memory space
+        with a specific type, alignment, and atomic ordering.
+        If `diagnosticOp` is non-null then the method might emit diagnostics.
+      }],
+      /*returnType=*/  "::mlir::LogicalResult",
+      /*methodName=*/  "isValidLoad",
+      /*args=*/        (ins "::mlir::Type":$type,
+                            "::mlir::ptr::AtomicOrdering":$ordering,
+                            "::mlir::IntegerAttr":$alignment,
+                            "::mlir::Operation*":$diagnosticOp),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        This method checks if it's valid to store a value in the memory space
+        with a specific type, alignment, and atomic ordering.
+        If `diagnosticOp` is non-null then the method might emit diagnostics.
+      }],
+      /*returnType=*/  "::mlir::LogicalResult",
+      /*methodName=*/  "isValidStore",
+      /*args=*/        (ins "::mlir::Type":$type,
+                            "::mlir::ptr::AtomicOrdering":$ordering,
+                            "::mlir::IntegerAttr":$alignment,
+                            "::mlir::Operation*":$diagnosticOp),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        This method checks if it's valid to perform an atomic operation in the
+        memory space with a specific type, alignment, and atomic ordering.
+        If `diagnosticOp` is non-null then the method might emit diagnostics.
+      }],
+      /*returnType=*/  "::mlir::LogicalResult",
+      /*methodName=*/  "isValidAtomicOp",
+      /*args=*/        (ins "::mlir::ptr::AtomicBinOp":$op,
+                            "::mlir::Type":$type,
+                            "::mlir::ptr::AtomicOrdering":$ordering,
+                            "::mlir::IntegerAttr":$alignment,
+                            "::mlir::Operation*":$diagnosticOp),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        This method checks if it's valid to perform an atomic exchange operation
+        in the memory space with a specific type, alignment, and atomic
+        orderings.
+        If `diagnosticOp` is non-null then the method might emit diagnostics.
+      }],
+      /*returnType=*/  "::mlir::LogicalResult",
+      /*methodName=*/  "isValidAtomicXchg",
+      /*args=*/        (ins "::mlir::Type":$type,
+                            "::mlir::ptr::AtomicOrdering":$successOrdering,
+                            "::mlir::ptr::AtomicOrdering":$failureOrdering,
+                            "::mlir::IntegerAttr":$alignment,
+                            "::mlir::Operation*":$diagnosticOp),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        This method checks if it's valid to perform an `addrspacecast` op
+        in the memory space.
+        Both types are expected to be vectors of rank 1, or scalars of `ptr`
+        type.
+        If `diagnosticOp` is non-null then the method might emit diagnostics.
+      }],
+      /*returnType=*/  "::mlir::LogicalResult",
+      /*methodName=*/  "isValidAddrSpaceCast",
+      /*args=*/        (ins "::mlir::Type":$tgt,
+                            "::mlir::Type":$src,
+                            "::mlir::Operation*":$diagnosticOp),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+    >,
+    InterfaceMethod<
+      /*desc=*/        [{
+        This method checks if it's valid to perform a `ptrtoint` or `inttoptr`
+        op in the memory space. The returned value only signals success or
+        failure; it does not identify whether the `ptr-like` or the `int-like`
+        type is the invalid one.
+        The first type is expected to be integer-like, while the second must be a
+        ptr-like type.
+        If `diagnosticOp` is non-null then the method might emit diagnostics.
+      }],
+      /*returnType=*/  "::mlir::LogicalResult",
+      /*methodName=*/  "isValidPtrIntCast",
+      /*args=*/        (ins "::mlir::Type":$intLikeTy,
+                            "::mlir::Type":$ptrLikeTy,
+                            "::mlir::Operation*":$diagnosticOp),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+    >,
+  ];
+}
+
+def MemorySpaceOpInterface : OpInterface<"MemorySpaceOpInterface"> {
+  let description = [{
+    An interface for operations with a memory space.
+  }];
+
+  let cppNamespace = "::mlir::ptr";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns the memory space of the op.",
+      /*returnType=*/  "::mlir::ptr::MemorySpace",
+      /*methodName=*/  "getMemorySpace",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{}]
+      >,
+  ];
+}
+#endif // PTR_MEMORYSPACEINTERFACES
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
new file mode 100644
index 00000000000000..e6aa7635919f6d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
@@ -0,0 +1,20 @@
+//===- PtrAttrs.h - Pointer dialect attributes ------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Ptr dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRATTRS_H
+#define MLIR_DIALECT_PTR_IR_PTRATTRS_H
+
+#include "mlir/IR/OpImplementation.h"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
new file mode 100644
index 00000000000000..3a921e4de08d8d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
@@ -0,0 +1,69 @@
+//===-- PtrEnums.td - Ptr dialect enumerations -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_ENUMS
+#define PTR_ENUMS
+
+include "mlir/IR/EnumAttr.td"
+
+//===----------------------------------------------------------------------===//
+// Atomic binary op enum attribute
+//===----------------------------------------------------------------------===//
+
+def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0, "xchg">;
+def AtomicBinOpAdd  : I64EnumAttrCase<"add", 1, "add">;
+def AtomicBinOpSub  : I64EnumAttrCase<"sub", 2, "sub">;
+def AtomicBinOpAnd  : I64EnumAttrCase<"_and", 3, "_and">;
+def AtomicBinOpNand : I64EnumAttrCase<"nand", 4, "nand">;
+def AtomicBinOpOr   : I64EnumAttrCase<"_or", 5, "_or">;
+def AtomicBinOpXor  : I64EnumAttrCase<"_xor", 6, "_xor">;
+def AtomicBinOpMax  : I64EnumAttrCase<"max", 7, "max">;
+def AtomicBinOpMin  : I64EnumAttrCase<"min", 8, "min">;
+def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9, "umax">;
+def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10, "umin">;
+def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11, "fadd">;
+def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12, "fsub">;
+def AtomicBinOpFMax : I64EnumAttrCase<"fmax", 13, "fmax">;
+def AtomicBinOpFMin : I64EnumAttrCase<"fmin", 14, "fmin">;
+def AtomicBinOpUIncWrap : I64EnumAttrCase<"uinc_wrap", 15, "uinc_wrap">;
+def AtomicBinOpUDecWrap : I64EnumAttrCase<"udec_wrap", 16, "udec_wrap">;
+
+def AtomicBinOp : I64EnumAttr<
+    "AtomicBinOp",
+    "ptr.atomicrmw binary operations",
+    [AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
+     AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
+     AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
+     AtomicBinOpFSub, AtomicBinOpFMax, AtomicBinOpFMin, AtomicBinOpUIncWrap,
+     AtomicBinOpUDecWrap]> {
+  let cppNamespace = "::mlir::ptr";
+}
+
+//===----------------------------------------------------------------------===//
+// Atomic ordering enum attribute
+//===----------------------------------------------------------------------===//
+
+def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0, "not_atomic">;
+def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1, "unordered">;
+def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2, "monotonic">;
+def AtomicOrderingAcquire   : I64EnumAttrCase<"acquire", 3, "acquire">;
+def AtomicOrderingRelease   : I64EnumAttrCase<"release", 4, "release">;
+def AtomicOrderingAcqRel    : I64EnumAttrCase<"acq_rel", 5, "acq_rel">;
+def AtomicOrderingSeqCst    : I64EnumAttrCase<"seq_cst", 6, "seq_cst">;
+
+def AtomicOrdering : I64EnumAttr<
+    "AtomicOrdering",
+    "Atomic ordering for LLVM's memory model",
+    [AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
+     AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcqRel,
+     AtomicOrderingSeqCst
+    ]> {
+  let cppNamespace = "::mlir::ptr";
+}
+
+#endif // PTR_ENUMS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
index ad8a2bbcbdd8d2..23d41a11a6da03 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -14,6 +14,8 @@
 #define MLIR_DIALECT_PTR_IR_PTROPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpace.h"
+#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 690941337bdfb5..91c73804c27133 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -10,6 +10,7 @@
 #define PTR_OPS
 
 include "mlir/Dialect/Ptr/IR/PtrDialect.td"
+include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 
 #endif // PTR_OPS
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index 359b9f02a06266..24cc3bc6ef3b57 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -6,6 +6,8 @@ add_mlir_dialect_library(
   ${PROJECT_SOURCE_DIR}/mlir/Dialect/Pointer
   DEPENDS
   MLIRPtrOpsIncGen
+  MLIRPtrOpsEnumsGen
+  MLIRPtrMemorySpaceInterfacesIncGen
   LINK_LIBS
   PUBLIC
   MLIRIR
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 59c97b22f332c4..c97a0626ae169e 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,8 +39,77 @@ void PtrDialect::initialize() {
 // Pointer API.
 //===----------------------------------------------------------------------===//
 
+// Returns a pair containing:
+// The underlying type of a vector or the type itself if it's not a vector.
+// The number of elements in the vector or an error code if the type is not
+// supported.
+static std::pair<Type, int64_t> getVecOrScalarInfo(Type ty) {
+  if (auto vecTy = dyn_cast<VectorType>(ty)) {
+    auto elemTy = vecTy.getElementType();
+    // Vectors of rank greater than one or with scalable dimensions are not
+    // supported.
+    if (vecTy.getRank() != 1)
+      return {elemTy, -1};
+    else if (vecTy.getScalableDims()[0])
+      return {elemTy, -2};
+    return {elemTy, vecTy.getShape()[0]};
+  }
+  // `ty` is a scalar type.
+  return {ty, 0};
+}
+
+LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
+                                                  Operation *op) {
+  std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
+  std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
+  if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
+    return op ? op->emitError("invalid ptr-like operand") : failure();
+  // Check shape validity.
+  if (tgtInfo.second == -1 || srcInfo.second == -1)
+    return op ? op->emitError("vectors of rank != 1 are not supported")
+              : failure();
+  if (tgtInfo.second == -2 || srcInfo.second == -2)
+    return op ? op->emitError(
+                    "vectors with scalable dimensions are not supported")
+              : failure();
+  if (tgtInfo.second != srcInfo.second)
+    return op ? op->emitError("incompatible operand shapes") : failure();
+  return success();
+}
+
+LogicalResult mlir::ptr::isValidPtrIntCastImpl(Type intLikeTy, Type ptrLikeTy,
+                                               Operation *op) {
+  // Check int-like type.
+  std::pair<Type, int64_t> intInfo = getVecOrScalarInfo(intLikeTy);
+  if (!intInfo.first.isSignlessIntOrIndex())
+    /// The int-like operand is invalid.
+    return op ? op->emitError("invalid int-like type") : failure();
+  // Check ptr-like type.
+  std::pair<Type, int64_t> ptrInfo = getVecOrScalarInfo(ptrLikeTy);
+  if (!isa<PtrType>(ptrInfo.first))
+    /// The pointer-like operand is invalid.
+    return op ? op->emitError("invalid ptr-like type") : failure();
+  // Check shape validity.
+  if (intInfo.second == -1 || ptrInfo.second == -1)
+    return op ? op->emitError("vectors of rank != 1 are not supported")
+              : failure();
+  if (intInfo.second == -2 || ptrInfo.second == -2)
+    return op ? op->emitError(
+                    "vectors with scalable dimensions are not supported")
+              : failure();
+  if (intInfo.second != ptrInfo.second)
+    return op ? op->emitError("incompatible operand shapes") : failure();
+  return success();
+}
+
 #include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
 
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
 
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index 51d0a45051b85e..8f83cc210d44b7 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpace.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -64,9 +65,13 @@ getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type,
   return std::nullopt;
 }
 
-int64_t PtrType::getAddressSpace() const { return 0; }
+int64_t PtrType::getAddressSpace() const {
+  return MemorySpace(getMemorySpace()).getAddressSpace();
+}
 
-Attribute PtrType::getDefaultMemorySpace() const { return nullptr; }
+Attribute PtrType::getDefaultMemorySpace() const {
+  return MemorySpace(getMemorySpace()).getDefaultMemorySpace();
+}
 
 bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
                             DataLayoutEntryListRef newLayout) const {
diff --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir
index 279213bd6fc3e5..9bec8f01803e43 100644
--- a/mlir/test/Dialect/Ptr/types.mlir
+++ b/mlir/test/Dialect/Ptr/types.mlir
@@ -15,3 +15,10 @@ func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 :
 func.func @ptr_test(%arg0: memref<!ptr.ptr>) {
   return
 }
+
+// CHECK-LABEL: func @ptr_test_1
+// CHECK: (%[[ARG0:.*]]: !ptr.ptr<#test.const_memory_space>, %[[ARG1:.*]]: !ptr.ptr<#test.const_memory_space<3>>)
+func.func @ptr_test_1(%arg0: !ptr.ptr<#test.const_memory_space>,
+                      %arg1: !ptr.ptr<#test.const_memory_space<3>>) {
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index b82b1631eead59..b7e4c8ae1145bb 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -78,6 +78,7 @@ add_mlir_library(MLIRTestDialect
   MLIRInferTypeOpInterface
   MLIRLinalgDialect
   MLIRLinalgTransforms
+  MLIRPtrDialect
   MLIRLLVMDialect
   MLIRPass
   MLIRReduce
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4e5..82dbc92333e30e 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -16,6 +16,7 @@
 // To get the test dialect definition.
 include "TestDialect.td"
 include "TestEnumDefs.td"
+include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
 include "mlir/Dialect/Utils/StructuredOpsUtils.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
@@ -340,4 +341,13 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
   }];
 }
 
+// Test a ptr constant memory space.
+def TestConstMemorySpaceAttr : Test_Attr<"TestConstMemorySpace", [
+    DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>
+  ]> {
+  let mnemonic = "const_memory_space";
+  let parameters = (ins DefaultValuedParameter<"unsigned", "0">:$addressSpace);
+  let assemblyFormat = "(`<` $addressSpace^ `>`)?";
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index d41d495c38e553..60c362465f12fe 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -239,6 +239,54 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
   p.printKeywordOrString(value);
 }
 
+//===----------------------------------------------------------------------===//
+// TestConstMemorySpaceAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestConstMemorySpaceAttr::getDefaultMemorySpace() const {
+  return TestConstMemorySpaceAttr::get(getContext(), 0);
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidLoad(
+    Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
+    Operation *diagnosticOp) const {
+  return success();
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidStore(
+    Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
+    Operation *diagnosticOp) const {
+  return diagnosticOp ? diagnosticOp->emitError("memory space is read-only")
+                      : failure();
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidAtomicOp(
+    mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering,
+    IntegerAttr alignment, Operation *diagnosticOp) const {
+  return diagnosticOp ? diagnosticOp->emitError("memory space is read-only")
+                      : failure();
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidAtomicXchg(
+    Type type, mlir::ptr::AtomicOrdering successOrdering,
+    mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
+    Operation *diagnosticOp) const {
+  return diagnosticOp ? diagnosticOp->emitError("memory space is read-only")
+                      : failure();
+}
+
+LogicalResult
+TestConstMemorySpaceAttr::isValidAddrSpaceCast(Type tgt, Type src,
+                                               Operation *diagnosticOp) const {
+  return ptr::isValidAddrSpaceCastImpl(tgt, src, diagnosticOp);
+}
+
+LogicalResult
+TestConstMemorySpaceAttr::isValidPtrIntCast(Type intLikeTy, Type ptrLikeTy,
+                                            Operation *diagnosticOp) const {
+  return ptr::isValidPtrIntCastImpl(intLikeTy, ptrLikeTy, diagnosticOp);
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index ef6eae51fdd628..a84e26fba9d912 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -17,6 +17,7 @@
 #include <tuple>
 
 #include "TestTraits.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpace.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Diagnostics.h"

>From 161d1fcf888466b2e439b8aa54662283fef6f9a1 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 2 Apr 2024 15:40:49 +0000
Subject: [PATCH 3/3] address reviewer comments

---
 .../include/mlir/Dialect/Ptr/IR/MemorySpace.h | 50 ++++++++---------
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        | 53 +++++++++++--------
 2 files changed, 55 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
index e467d121f2c886..948f97a0c95626 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
@@ -21,14 +21,14 @@
 namespace mlir {
 class Operation;
 namespace ptr {
-/// This method checks if it's valid to perform an `addrspacecast` op in the
+/// Checks if it's valid to perform an `addrspacecast` op in the
 /// memory space.
 /// Compatible types are:
 /// Vectors of rank 1, or scalars of `ptr` type.
 LogicalResult isValidAddrSpaceCastImpl(Type tgt, Type src,
                                        Operation *diagnosticOp);
 
-/// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
+/// Checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
 /// the memory space.
 /// Compatible types are:
 /// IntLikeTy: Vectors of rank 1, or scalars of integer types or `index` type.
@@ -52,28 +52,29 @@ class MemorySpace {
   MemorySpace() = default;
   MemorySpace(std::nullptr_t) {}
   MemorySpace(MemorySpaceAttrInterface memorySpace)
-      : memorySpaceAttr(memorySpace), memorySpace(memorySpace) {}
-  MemorySpace(Attribute memorySpace)
-      : memorySpaceAttr(memorySpace),
+      : underlyingMemorySpace(memorySpace), memorySpace(memorySpace) {}
+  explicit MemorySpace(Attribute memorySpace)
+      : underlyingMemorySpace(memorySpace),
         memorySpace(dyn_cast_or_null<MemorySpaceAttrInterface>(memorySpace)) {}
 
-  operator Attribute() const { return memorySpaceAttr; }
+  operator Attribute() const { return underlyingMemorySpace; }
   operator MemorySpaceAttrInterface() const { return memorySpace; }
   bool operator==(const MemorySpace &memSpace) const {
-    return memSpace.memorySpaceAttr == memorySpaceAttr;
+    return memSpace.underlyingMemorySpace == underlyingMemorySpace;
   }
 
   /// Returns the underlying memory space.
-  Attribute getUnderlyingSpace() const { return memorySpaceAttr; }
+  Attribute getUnderlyingSpace() const { return underlyingMemorySpace; }
 
-  /// Returns true if the underlying memory space is null.
+  /// Returns true if the memory space is null.
   bool isDefaultModel() const { return memorySpace == nullptr; }
 
   /// Returns the memory space as an integer, or 0 if using the default space.
   unsigned getAddressSpace() const {
     if (memorySpace)
       return memorySpace.getAddressSpace();
-    if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(memorySpaceAttr))
+    if (auto intAttr =
+            llvm::dyn_cast_or_null<IntegerAttr>(underlyingMemorySpace))
       return intAttr.getInt();
     return 0;
   }
@@ -84,9 +85,9 @@ class MemorySpace {
     return memorySpace ? memorySpace.getDefaultMemorySpace() : nullptr;
   }
 
-  /// This method checks if it's valid to load a value from the memory space
-  /// with a specific type, alignment, and atomic ordering. The default model
-  /// assumes all values are loadable.
+  /// Checks if it's valid to load a value from the memory space with a specific
+  /// type, alignment, and atomic ordering. The default model assumes all values
+  /// can be loaded.
   LogicalResult isValidLoad(Type type, AtomicOrdering ordering,
                             IntegerAttr alignment,
                             Operation *diagnosticOp = nullptr) const {
@@ -95,9 +96,9 @@ class MemorySpace {
                        : success();
   }
 
-  /// This method checks if it's valid to store a value in the memory space with
-  /// a specific type, alignment, and atomic ordering. The default model assumes
-  /// all values are loadable.
+  /// Checks if it's valid to store a value in the memory space with a specific
+  /// type, alignment, and atomic ordering. The default model assumes all values
+  /// can be stored.
   LogicalResult isValidStore(Type type, AtomicOrdering ordering,
                              IntegerAttr alignment,
                              Operation *diagnosticOp = nullptr) const {
@@ -106,8 +107,8 @@ class MemorySpace {
                        : success();
   }
 
-  /// This method checks if it's valid to perform an atomic operation in the
-  /// memory space with a specific type, alignment, and atomic ordering.
+  /// Checks if it's valid to perform an atomic operation in the memory space
+  /// with a specific type, alignment, and atomic ordering.
   LogicalResult isValidAtomicOp(AtomicBinOp op, Type type,
                                 AtomicOrdering ordering, IntegerAttr alignment,
                                 Operation *diagnosticOp = nullptr) const {
@@ -116,8 +117,8 @@ class MemorySpace {
                        : success();
   }
 
-  /// This method checks if it's valid to perform an atomic operation in the
-  /// memory space with a specific type, alignment, and atomic ordering.
+  /// Checks if it's valid to perform an atomic exchange operation in the memory
+  /// space with a specific type, alignment, and atomic orderings.
   LogicalResult isValidAtomicXchg(Type type, AtomicOrdering successOrdering,
                                   AtomicOrdering failureOrdering,
                                   IntegerAttr alignment,
@@ -128,8 +129,7 @@ class MemorySpace {
                        : success();
   }
 
-  /// This method checks if it's valid to perform an `addrspacecast` op in the
-  /// memory space.
+  /// Checks if it's valid to perform an `addrspacecast` op in the memory space.
   LogicalResult isValidAddrSpaceCast(Type tgt, Type src,
                                      Operation *diagnosticOp = nullptr) const {
     return memorySpace
@@ -137,8 +137,8 @@ class MemorySpace {
                : isValidAddrSpaceCastImpl(tgt, src, diagnosticOp);
   }
 
-  /// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op
-  /// in the memory space.
+  /// Checks if it's valid to perform a `ptrtoint` or `inttoptr` op in the
+  /// memory space.
   LogicalResult isValidPtrIntCast(Type intLikeTy, Type ptrLikeTy,
                                   Operation *diagnosticOp = nullptr) const {
     return memorySpace
@@ -149,7 +149,7 @@ class MemorySpace {
 
 protected:
   /// Underlying memory space.
-  Attribute memorySpaceAttr{};
+  Attribute underlyingMemorySpace{};
   /// Memory space.
   MemorySpaceAttrInterface memorySpace{};
 };
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c97a0626ae169e..195c1a56492902 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,6 +39,10 @@ void PtrDialect::initialize() {
 // Pointer API.
 //===----------------------------------------------------------------------===//
 
+// Error constants for vector data types.
+constexpr const static unsigned kInvalidRankError = -1;
+constexpr const static unsigned kScalableDimsError = -2;
+
 // Returns a pair containing:
 // The underlying type of a vector or the type itself if it's not a vector.
 // The number of elements in the vector or an error code if the type is not
@@ -49,26 +53,28 @@ static std::pair<Type, int64_t> getVecOrScalarInfo(Type ty) {
     // Vectors of rank greater than one or with scalable dimensions are not
     // supported.
     if (vecTy.getRank() != 1)
-      return {elemTy, -1};
+      return {elemTy, kInvalidRankError};
     else if (vecTy.getScalableDims()[0])
-      return {elemTy, -2};
+      return {elemTy, kScalableDimsError};
     return {elemTy, vecTy.getShape()[0]};
   }
   // `ty` is a scalar type.
   return {ty, 0};
 }
 
-LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
-                                                  Operation *op) {
-  std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
-  std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
-  if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
-    return op ? op->emitError("invalid ptr-like operand") : failure();
+/// Checks whether the shape of the operands is compatible with the operation.
+/// Operands must both be scalars or both be rank-1, non-scalable vectors with
+/// the same number of elements.
+static LogicalResult verifyShapeInfo(mlir::Operation *op,
+                                     const std::pair<Type, int64_t> &tgtInfo,
+                                     const std::pair<Type, int64_t> &srcInfo) {
   // Check shape validity.
-  if (tgtInfo.second == -1 || srcInfo.second == -1)
+  if (tgtInfo.second == kInvalidRankError ||
+      srcInfo.second == kInvalidRankError)
     return op ? op->emitError("vectors of rank != 1 are not supported")
               : failure();
-  if (tgtInfo.second == -2 || srcInfo.second == -2)
+  if (tgtInfo.second == kScalableDimsError ||
+      srcInfo.second == kScalableDimsError)
     return op ? op->emitError(
                     "vectors with scalable dimensions are not supported")
               : failure();
@@ -77,29 +83,30 @@ LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
   return success();
 }
 
+LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
+                                                  Operation *op) {
+  std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
+  std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
+  if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
+    return op ? op->emitError("invalid ptr-like operand") : failure();
+  // Verify shape validity.
+  return verifyShapeInfo(op, tgtInfo, srcInfo);
+}
+
 LogicalResult mlir::ptr::isValidPtrIntCastImpl(Type intLikeTy, Type ptrLikeTy,
                                                Operation *op) {
   // Check int-like type.
   std::pair<Type, int64_t> intInfo = getVecOrScalarInfo(intLikeTy);
+  // Reject operands whose element type is not a signless integer or index.
   if (!intInfo.first.isSignlessIntOrIndex())
-    /// The int-like operand is invalid.
     return op ? op->emitError("invalid int-like type") : failure();
   // Check ptr-like type.
   std::pair<Type, int64_t> ptrInfo = getVecOrScalarInfo(ptrLikeTy);
+  // Reject operands whose element type is not a `PtrType`.
   if (!isa<PtrType>(ptrInfo.first))
-    /// The pointer-like operand is invalid.
     return op ? op->emitError("invalid ptr-like type") : failure();
-  // Check shape validity.
-  if (intInfo.second == -1 || ptrInfo.second == -1)
-    return op ? op->emitError("vectors of rank != 1 are not supported")
-              : failure();
-  if (intInfo.second == -2 || ptrInfo.second == -2)
-    return op ? op->emitError(
-                    "vectors with scalable dimensions are not supported")
-              : failure();
-  if (intInfo.second != ptrInfo.second)
-    return op ? op->emitError("incompatible operand shapes") : failure();
-  return success();
+  // Verify shape validity.
+  return verifyShapeInfo(op, intInfo, ptrInfo);
 }
 
 #include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"



More information about the Mlir-commits mailing list