[Mlir-commits] [mlir] e035ef0 - [mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr` type. (#86860)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 27 05:14:38 PDT 2024


Author: Fabian Mora
Date: 2024-06-27T07:14:34-05:00
New Revision: e035ef0e7423c1a4c78e922508da817dbd5b6a02

URL: https://github.com/llvm/llvm-project/commit/e035ef0e7423c1a4c78e922508da817dbd5b6a02
DIFF: https://github.com/llvm/llvm-project/commit/e035ef0e7423c1a4c78e922508da817dbd5b6a02.diff

LOG: [mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr` type. (#86860)

This patch initializes the `ptr` dialect directories and base files,
adding the `!ptr.ptr` type and the `#ptr.spec<...>` data layout spec
attribute.

The `!ptr.ptr` type is an opaque pointer type optionally parameterized
by a memory space. This type typically represents a handle to an object
in memory or target-dependent values like `nullptr`.

The implementation of the `DataLayoutTypeInterface` interface for
`!ptr.ptr` was adapted from `!llvm.ptr`'s implementation. This
implementation uses the `#ptr.spec<...>` attribute for defining the data
layout specification.

See [[RFC] `ptr` dialect & modularizing ptr ops in the LLVM
dialect](https://discourse.llvm.org/t/rfc-ptr-dialect-modularizing-ptr-ops-in-the-llvm-dialect/75142)
for rationale and roadmap.

Added: 
    mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
    mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
    mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
    mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
    mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
    mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
    mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
    mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
    mlir/lib/Dialect/Ptr/CMakeLists.txt
    mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
    mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
    mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
    mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
    mlir/test/Dialect/Ptr/layout.mlir
    mlir/test/Dialect/Ptr/types.mlir

Modified: 
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 4bd7f12fabf7b..f710235197334 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -29,6 +29,7 @@ add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
 add_subdirectory(Polynomial)
+add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)

diff  --git a/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..df07b8d5a63d9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_dialect(PtrOps ptr)
+add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS PtrOps.td)
+mlir_tablegen(PtrOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=ptr)
+mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr)
+add_public_tablegen_target(MLIRPtrOpsAttributesIncGen)

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
new file mode 100644
index 0000000000000..e75038f300f1a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
@@ -0,0 +1,70 @@
+//===-- PtrAttrDefs.td - Ptr Attributes definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_ATTRDEFS
+#define PTR_ATTRDEFS
+
+include "mlir/Dialect/Ptr/IR/PtrDialect.td"
+include "mlir/IR/AttrTypeBase.td"
+
+// All of the attributes will extend this class.
+class Ptr_Attr<string name, string attrMnemonic,
+                list<Trait> traits = [],
+                string baseCppClass = "::mlir::Attribute">
+    : AttrDef<Ptr_Dialect, name, traits, baseCppClass> {
+  let mnemonic = attrMnemonic;
+}
+
+//===----------------------------------------------------------------------===//
+// SpecAttr
+//===----------------------------------------------------------------------===//
+
+def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
+  let summary = "ptr data layout spec";
+  let description = [{
+    Defines the data layout spec for a pointer type. This attribute has 4
+    fields:
+     - [Required] size: size of the pointer in bits.
+     - [Required] abi: ABI-required alignment for the pointer in bits.
+     - [Required] preferred: preferred alignment for the pointer in bits.
+     - [Optional] index: bitwidth that should be used when performing index
+     computations for the type. When the field is omitted it holds
+     `kOptionalSpecValue`, and `size` is used as the index bitwidth.
+    
+    Furthermore, the attribute will verify that all present values are divisible
+    by 8 (number of bits in a byte), and that `preferred` >= `abi`.
+
+    Example:
+    ```mlir
+    // Spec for a 64 bit ptr, with a required alignment of 64 bits, but with
+    // a preferred alignment of 128 bits and an index bitwidth of 64 bits.
+    #ptr.spec<size = 64, abi = 64, preferred = 128, index = 64>
+    ```
+  }];
+  let parameters = (ins
+    "uint32_t":$size,
+    "uint32_t":$abi,
+    "uint32_t":$preferred,
+    DefaultValuedParameter<"uint32_t", "kOptionalSpecValue">:$index
+  );
+  let skipDefaultBuilders = 1;
+  let builders = [
+    AttrBuilder<(ins "uint32_t":$size, "uint32_t":$abi, "uint32_t":$preferred,
+                     CArg<"uint32_t", "kOptionalSpecValue">:$index), [{
+      return $_get($_ctxt, size, abi, preferred, index);
+    }]>
+  ];
+  let assemblyFormat = "`<` struct(params) `>`";
+  let extraClassDeclaration = [{
+    /// Constant for specifying a spec entry is optional.
+    static constexpr uint32_t kOptionalSpecValue = std::numeric_limits<uint32_t>::max();
+  }];
+  let genVerifyDecl = 1;
+}
+
+#endif // PTR_ATTRDEFS

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
new file mode 100644
index 0000000000000..72e767764d98b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
@@ -0,0 +1,21 @@
+//===- PtrAttrs.h - Pointer dialect attributes ------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Ptr dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRATTRS_H
+#define MLIR_DIALECT_PTR_IR_PTRATTRS_H
+
+#include "mlir/IR/OpImplementation.h"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
new file mode 100644
index 0000000000000..92f877c20dbf0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
@@ -0,0 +1,20 @@
+//===- PtrDialect.h - Pointer dialect ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+#define MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRDIALECT_H

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
new file mode 100644
index 0000000000000..14d72c3001d91
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -0,0 +1,73 @@
+//===- PtrDialect.td - Pointer dialect ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_DIALECT
+#define PTR_DIALECT
+
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect definition.
+//===----------------------------------------------------------------------===//
+
+def Ptr_Dialect : Dialect {
+  let name = "ptr";
+  let summary = "Pointer dialect";
+  let cppNamespace = "::mlir::ptr";
+  let useDefaultTypePrinterParser = 1;
+  let useDefaultAttributePrinterParser = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer type definitions
+//===----------------------------------------------------------------------===//
+
+class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Ptr_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
+    MemRefElementTypeInterface,
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
+      "areCompatible", "getIndexBitwidth", "verifyEntries"]>
+  ]> {
+  let summary = "pointer type";
+  let description = [{
+    The `ptr` type is an opaque pointer type. This type typically represents a
+    handle to an object in memory or target-dependent values like `nullptr`.
+    Pointers are optionally parameterized by a memory space.
+
+    Syntax:
+
+    ```mlir
+    pointer ::= `ptr` (`<` memory-space `>`)?
+    memory-space ::= attribute-value
+    ```
+  }];
+  let parameters = (ins OptionalParameter<"Attribute">:$memorySpace);
+  let assemblyFormat = "(`<` $memorySpace^ `>`)?";
+  let builders = [
+    TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{
+      return $_get($_ctxt, memorySpace);
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Base address operation definition.
+//===----------------------------------------------------------------------===//
+
+class Pointer_Op<string mnemonic, list<Trait> traits = []> :
+        Op<Ptr_Dialect, mnemonic, traits>;
+
+#endif // PTR_DIALECT

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
new file mode 100644
index 0000000000000..6a0c1429c6be9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -0,0 +1,25 @@
+//===- PtrOps.h - Pointer dialect ops ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Ptr dialect operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTROPS_H
+#define MLIR_DIALECT_PTR_IR_PTROPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTROPS_H

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
new file mode 100644
index 0000000000000..c63a0b220e501
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -0,0 +1,16 @@
+//===- PtrOps.td - Pointer dialect ops ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_OPS
+#define PTR_OPS
+
+include "mlir/Dialect/Ptr/IR/PtrDialect.td"
+include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
+include "mlir/IR/OpAsmInterface.td"
+
+#endif // PTR_OPS

diff  --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
new file mode 100644
index 0000000000000..264a97c80722a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -0,0 +1,22 @@
+//===- PtrTypes.h - Pointer types -------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Pointer dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
+#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
+
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRTYPES_H

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index d9db21073e15c..549c26c72d8a1 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -63,6 +63,7 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
@@ -134,6 +135,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   pdl::PDLDialect,
                   pdl_interp::PDLInterpDialect,
                   polynomial::PolynomialDialect,
+                  ptr::PtrDialect,
                   quant::QuantizationDialect,
                   ROCDL::ROCDLDialect,
                   scf::SCFDialect,

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index a324ce7f9b19f..80b0ef068d96d 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -29,6 +29,7 @@ add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
 add_subdirectory(Polynomial)
+add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)

diff  --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..9cf3643c73d3e
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(
+  MLIRPtrDialect
+  PtrAttrs.cpp
+  PtrTypes.cpp
+  PtrDialect.cpp
+
+  DEPENDS
+  MLIRPtrOpsAttributesIncGen
+  MLIRPtrOpsIncGen
+
+  LINK_LIBS
+  PUBLIC
+  MLIRIR
+  MLIRDataLayoutInterfaces
+  MLIRMemorySlotInterfaces
+)

diff  --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
new file mode 100644
index 0000000000000..f8ce820d0bcbd
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
@@ -0,0 +1,40 @@
+//===- PtrAttrs.cpp - Pointer dialect attributes ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+constexpr const static unsigned kBitsInByte = 8;
+
+//===----------------------------------------------------------------------===//
+// SpecAttr
+//===----------------------------------------------------------------------===//
+
+/// Verifies the entries of a `#ptr.spec` attribute. `size`, `abi` and
+/// `preferred` must be divisible by 8, `index` must be divisible by 8 unless it
+/// holds `kOptionalSpecValue`, and `abi` must not exceed `preferred`. The first
+/// violated rule is reported through `emitError`.
+LogicalResult SpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                               uint32_t size, uint32_t abi, uint32_t preferred,
+                               uint32_t index) {
+  // True when `bits` covers a whole number of bytes.
+  auto isWholeBytes = [](uint32_t bits) { return bits % kBitsInByte == 0; };
+
+  if (!isWholeBytes(size))
+    return emitError() << "size entry must be divisible by 8";
+  if (!isWholeBytes(abi))
+    return emitError() << "abi entry must be divisible by 8";
+  if (!isWholeBytes(preferred))
+    return emitError() << "preferred entry must be divisible by 8";
+  // The sentinel marks an absent `index`, which is exempt from the check.
+  const bool hasIndex = index != kOptionalSpecValue;
+  if (hasIndex && !isWholeBytes(index))
+    return emitError() << "index entry must be divisible by 8";
+  if (preferred < abi)
+    return emitError() << "preferred alignment is expected to be at least "
+                          "as large as ABI alignment";
+  return success();
+}

diff  --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
new file mode 100644
index 0000000000000..7830ffe893dfd
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -0,0 +1,55 @@
+//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Pointer dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect
+//===----------------------------------------------------------------------===//
+
+/// Registers the generated operations, attributes and types of the dialect.
+void PtrDialect::initialize() {
+  // Each GET_*_LIST define selects the list section of the generated file
+  // included right after it; that list forms the template arguments of the
+  // enclosing add* call, so the define/include pairs must stay together.
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
+      >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
+      >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer API.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
new file mode 100644
index 0000000000000..2866d4eb10feb
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -0,0 +1,144 @@
+//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+// Pointer size, in bits, assumed when the data layout has no matching spec.
+constexpr const static unsigned kDefaultPointerSizeBits = 64;
+// Number of bits in a byte; used to convert spec values from bits to bytes.
+constexpr const static unsigned kBitsInByte = 8;
+// Default ABI and preferred alignment used when no spec is found.
+// NOTE(review): this value is stored in the `abi`/`preferred` fields of the
+// fallback spec, and those fields are divided by kBitsInByte when queried, so
+// it is effectively interpreted in bits -- confirm that 8 is intended.
+constexpr const static unsigned kDefaultPointerAlignment = 8;
+
+/// Returns the memory space treated as the default one for `ptr`; this is
+/// always the null attribute.
+static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; }
+
+/// Looks up the `#ptr.spec` stored in the entry of `params` whose pointer key
+/// has the same memory space as `type`. When no entry provides one, a default
+/// spec is returned for pointers in the default memory space and a null
+/// attribute for pointers in any other memory space.
+static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) {
+  Attribute wantedSpace = type.getMemorySpace();
+  for (DataLayoutEntryInterface candidate : params) {
+    // Only type-keyed entries can describe a pointer type.
+    if (!candidate.isTypeEntry())
+      continue;
+    auto keyType = cast<PtrType>(candidate.getKey().get<Type>());
+    if (keyType.getMemorySpace() != wantedSpace)
+      continue;
+    if (auto spec = dyn_cast<SpecAttr>(candidate.getValue()))
+      return spec;
+  }
+  // Nothing matched: only the default memory space gets an implicit spec.
+  if (wantedSpace != getDefaultMemorySpace(type))
+    return nullptr;
+  return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits,
+                       kDefaultPointerAlignment, kDefaultPointerAlignment,
+                       kDefaultPointerSizeBits);
+}
+
+/// Checks whether the pointer entries of `newLayout` are compatible with
+/// `oldLayout`. Each type entry of `newLayout` is compared against the old
+/// entry with the same memory space, else against the old entry for the
+/// default memory space, else against the built-in default size and alignment.
+/// Returns false as soon as one new entry has a different size, a larger ABI
+/// alignment, or an ABI alignment that does not divide the old one.
+bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
+                            DataLayoutEntryListRef newLayout) const {
+  for (DataLayoutEntryInterface newEntry : newLayout) {
+    if (!newEntry.isTypeEntry())
+      continue;
+    uint32_t size = kDefaultPointerSizeBits;
+    uint32_t abi = kDefaultPointerAlignment;
+    auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>());
+    // First choice: the old entry describing the same memory space.
+    const auto *it =
+        llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+          if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+            return llvm::cast<PtrType>(type).getMemorySpace() ==
+                   newType.getMemorySpace();
+          }
+          return false;
+        });
+    // Second choice: the old entry describing the default memory space.
+    if (it == oldLayout.end()) {
+      it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+        if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+          auto ptrTy = llvm::cast<PtrType>(type);
+          return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy);
+        }
+        return false;
+      });
+    }
+    if (it != oldLayout.end()) {
+      // `*it` is the layout entry itself; the spec is the value it carries,
+      // so the cast must be applied to the entry value, not to the entry.
+      auto spec = llvm::cast<SpecAttr>(it->getValue());
+      size = spec.getSize();
+      abi = spec.getAbi();
+    }
+
+    auto newSpec = llvm::cast<SpecAttr>(newEntry.getValue());
+    uint32_t newSize = newSpec.getSize();
+    uint32_t newAbi = newSpec.getAbi();
+    // NOTE(review): a new ABI alignment of 0 makes the modulo below undefined;
+    // SpecAttr::verify only checks divisibility by 8 -- confirm whether 0 can
+    // reach this point.
+    if (size != newSize || abi < newAbi || abi % newAbi != 0)
+      return false;
+  }
+  return true;
+}
+
+/// Returns the ABI alignment of this pointer type: the `abi` entry of the
+/// matching spec divided by the number of bits in a byte. Without a matching
+/// spec, the query is answered with the pointer type of the default memory
+/// space.
+uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout,
+                                  DataLayoutEntryListRef params) const {
+  SpecAttr spec = getPointerSpec(params, *this);
+  if (!spec) {
+    PtrType defaultPtr = get(getContext(), getDefaultMemorySpace(*this));
+    return dataLayout.getTypeABIAlignment(defaultPtr);
+  }
+  return spec.getAbi() / kBitsInByte;
+}
+
+/// Returns the bitwidth used for index computations on this pointer type: the
+/// `index` entry of the matching spec, or its `size` entry when `index` holds
+/// `kOptionalSpecValue`. Without a matching spec, the query is answered with
+/// the pointer type of the default memory space.
+std::optional<uint64_t>
+PtrType::getIndexBitwidth(const DataLayout &dataLayout,
+                          DataLayoutEntryListRef params) const {
+  SpecAttr spec = getPointerSpec(params, *this);
+  if (!spec) {
+    PtrType defaultPtr = get(getContext(), getDefaultMemorySpace(*this));
+    return dataLayout.getTypeIndexBitwidth(defaultPtr);
+  }
+  // An absent `index` entry falls back to the pointer size.
+  if (spec.getIndex() == SpecAttr::kOptionalSpecValue)
+    return spec.getSize();
+  return spec.getIndex();
+}
+
+/// Returns the size of this pointer type in bits as a fixed `TypeSize`, taken
+/// from the `size` entry of the matching spec. Without a matching spec, the
+/// size of the pointer to the default memory space is used.
+llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout,
+                                          DataLayoutEntryListRef params) const {
+  SpecAttr spec = getPointerSpec(params, *this);
+  if (!spec) {
+    // No spec for this memory space: defer to the default memory space.
+    PtrType defaultPtr = get(getContext(), getDefaultMemorySpace(*this));
+    return dataLayout.getTypeSizeInBits(defaultPtr);
+  }
+  return llvm::TypeSize::getFixed(spec.getSize());
+}
+
+/// Returns the preferred alignment of this pointer type: the `preferred` entry
+/// of the matching spec divided by the number of bits in a byte. Without a
+/// matching spec, the query is answered with the pointer type of the default
+/// memory space.
+uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout,
+                                        DataLayoutEntryListRef params) const {
+  SpecAttr spec = getPointerSpec(params, *this);
+  if (!spec) {
+    PtrType defaultPtr = get(getContext(), getDefaultMemorySpace(*this));
+    return dataLayout.getTypePreferredAlignment(defaultPtr);
+  }
+  return spec.getPreferred() / kBitsInByte;
+}
+
+/// Verifies that the value of every type entry in `entries` is a `#ptr.spec`
+/// attribute; the first offending entry is reported at `loc`.
+LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
+                                     Location loc) const {
+  for (DataLayoutEntryInterface entry : entries) {
+    // Entries that are not keyed by a type are skipped.
+    if (!entry.isTypeEntry())
+      continue;
+    if (llvm::isa<SpecAttr>(entry.getValue()))
+      continue;
+    Type key = entry.getKey().get<Type>();
+    return emitError(loc) << "expected layout attribute for " << key
+                          << " to be a #ptr.spec attribute";
+  }
+  return success();
+}

diff  --git a/mlir/test/Dialect/Ptr/layout.mlir b/mlir/test/Dialect/Ptr/layout.mlir
new file mode 100644
index 0000000000000..73189a388942a
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/layout.mlir
@@ -0,0 +1,114 @@
+// RUN: mlir-opt --test-data-layout-query --split-input-file --verify-diagnostics %s | FileCheck %s
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, #ptr.spec<size = 32, abi = 32, preferred = 64>>,
+  #dlti.dl_entry<!ptr.ptr<5>,#ptr.spec<size = 64, abi = 64, preferred = 64>>,
+  #dlti.dl_entry<!ptr.ptr<4>, #ptr.spec<size = 32, abi = 64, preferred = 64, index = 24>>,
+  #dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui64>,
+  #dlti.dl_entry<"dlti.global_memory_space", 2 : ui64>,
+  #dlti.dl_entry<"dlti.program_memory_space", 3 : ui64>,
+  #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>
+>} {
+  // CHECK: @spec
+  func.func @spec() {
+    // CHECK: alignment = 4
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 32
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 32
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 4
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr
+    // CHECK: alignment = 4
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 32
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 32
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 4
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr<3>
+    // CHECK: alignment = 8
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 64
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 64
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 8
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr<5>
+    // CHECK: alignment = 8
+    // CHECK: alloca_memory_space = 5
+    // CHECK: bitsize = 32
+    // CHECK: global_memory_space = 2
+    // CHECK: index = 24
+    // CHECK: preferred = 8
+    // CHECK: program_memory_space = 3
+    // CHECK: size = 4
+    // CHECK: stack_alignment = 128
+    "test.data_layout_query"() : () -> !ptr.ptr<4>
+    return
+  }
+}
+
+// -----
+
+// expected-error@+2 {{preferred alignment is expected to be at least as large as ABI alignment}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, #ptr.spec<size = 64, abi = 64, preferred = 32>>
+>} {
+  func.func @pointer() {
+    return
+  }
+}
+
+// -----
+
+// expected-error@+2 {{size entry must be divisible by 8}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, #ptr.spec<size = 33, abi = 32, preferred = 32>>
+>} {
+  func.func @pointer() {
+    return
+  }
+}
+
+
+// -----
+
+// expected-error@+2 {{abi entry must be divisible by 8}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, #ptr.spec<size = 32, abi = 33, preferred = 64>>
+>} {
+  func.func @pointer() {
+    return
+  }
+}
+
+
+// -----
+
+// expected-error@+2 {{preferred entry must be divisible by 8}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, #ptr.spec<size = 32, abi = 32, preferred = 33>>
+>} {
+  func.func @pointer() {
+    return
+  }
+}
+
+
+// -----
+
+// expected-error@+2 {{index entry must be divisible by 8}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!ptr.ptr, #ptr.spec<size = 32, abi = 32, preferred = 32, index = 33>>
+>} {
+  func.func @pointer() {
+    return
+  }
+}

diff  --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir
new file mode 100644
index 0000000000000..279213bd6fc3e
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/types.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @ptr_test
+// CHECK: (%[[ARG0:.*]]: !ptr.ptr, %[[ARG1:.*]]: !ptr.ptr<1 : i32>)
+// CHECK: -> (!ptr.ptr<1 : i32>, !ptr.ptr)
+func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 : i32>, !ptr.ptr) {
+  // CHECK: return %[[ARG1]], %[[ARG0]] : !ptr.ptr<1 : i32>, !ptr.ptr
+  return %arg1, %arg0 : !ptr.ptr<1 : i32>, !ptr.ptr
+}
+
+// -----
+
+// CHECK-LABEL: func @ptr_test
+// CHECK: %[[ARG:.*]]: memref<!ptr.ptr>
+func.func @ptr_test(%arg0: memref<!ptr.ptr>) {
+  return
+}


        


More information about the Mlir-commits mailing list