[Mlir-commits] [mlir] [mlir][Ptr] Add the `MemorySpaceAttrInterface` interface and dependencies. (PR #86870)
Fabian Mora
llvmlistbot at llvm.org
Tue Apr 2 08:41:04 PDT 2024
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/86870
>From af1ff4d4d302d0722e06ab0068eae917a30bb09d Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 19:26:30 +0000
Subject: [PATCH 1/3] [mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr`
type.
This patch initializes the `ptr` dialect directories and some base files. It
also adds the `!ptr.ptr` type, together with an implementation of the
`DataLayoutTypeInterface` interface. That implementation is cloned from
`LLVM::LLVMPointerType`.
---
mlir/include/mlir/Dialect/CMakeLists.txt | 1 +
mlir/include/mlir/Dialect/Ptr/CMakeLists.txt | 1 +
.../mlir/Dialect/Ptr/IR/CMakeLists.txt | 2 +
mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h | 20 ++
.../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 83 ++++++++
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h | 24 +++
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 15 ++
mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h | 37 ++++
mlir/include/mlir/InitAllDialects.h | 2 +
mlir/lib/Dialect/CMakeLists.txt | 1 +
mlir/lib/Dialect/Ptr/CMakeLists.txt | 1 +
mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 14 ++
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 48 +++++
mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 184 ++++++++++++++++++
mlir/test/Dialect/Ptr/types.mlir | 17 ++
15 files changed, 450 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
create mode 100644 mlir/lib/Dialect/Ptr/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
create mode 100644 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
create mode 100644 mlir/test/Dialect/Ptr/types.mlir
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 2da79011fa26a3..5f0e9806926145 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
add_subdirectory(PDLInterp)
+add_subdirectory(Ptr)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..c6ffa892e4ecba
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,2 @@
+# Tablegen targets for the `ptr` dialect (from PtrOps.td) and its op docs.
+add_mlir_dialect(PtrOps ptr)
+add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
new file mode 100644
index 00000000000000..92f877c20dbf07
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
@@ -0,0 +1,20 @@
+//===- PtrDialect.h - Pointer dialect ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+#define MLIR_DIALECT_PTR_IR_PTRDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
new file mode 100644
index 00000000000000..bffae6b1ad71bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -0,0 +1,83 @@
+//===- PtrDialect.td - Pointer dialect ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_DIALECT
+#define PTR_DIALECT
+
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect definition.
+//===----------------------------------------------------------------------===//
+
+// The `ptr` dialect, in the C++ namespace `::mlir::ptr`. Types use the
+// tablegen-generated default printer/parser; attributes do not.
+def Ptr_Dialect : Dialect {
+ let name = "ptr";
+ let summary = "Pointer dialect";
+ let cppNamespace = "::mlir::ptr";
+ let useDefaultTypePrinterParser = 1;
+ let useDefaultAttributePrinterParser = 0;
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer type definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for types of the `ptr` dialect. `name` is the C++ class name
+// prefix and `typeMnemonic` is the mnemonic used in the assembly format.
+class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
+ : TypeDef<Ptr_Dialect, name, traits> {
+ let mnemonic = typeMnemonic;
+}
+
+// The opaque `!ptr.ptr` type, optionally parameterized by a memory space
+// attribute. It can be used as a memref element type and implements the
+// data layout type interface.
+def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
+    MemRefElementTypeInterface,
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
+      "areCompatible", "getIndexBitwidth", "verifyEntries"]>
+  ]> {
+  let summary = "pointer type";
+  let description = [{
+    The `ptr` type is an opaque pointer type. This type typically represents
+    a reference to an object in memory. Pointers are optionally parameterized
+    by a memory space.
+
+    Syntax:
+
+    ```mlir
+    pointer ::= `ptr` (`<` memory-space `>`)?
+    memory-space ::= attribute-value
+    ```
+  }];
+  let parameters = (ins OptionalParameter<"Attribute">:$memorySpace);
+  let assemblyFormat = "(`<` $memorySpace^ `>`)?";
+  let builders = [
+    // Builds a pointer in the given memory space attribute; a null attribute
+    // (the default argument) produces a pointer without a memory space.
+    TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{
+      return $_get($_ctxt, memorySpace);
+    }]>,
+    // Builds a pointer whose memory space is an `i32` integer attribute
+    // holding `addressSpace`.
+    TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{
+      return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32),
+                                            addressSpace));
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
+  let extraClassDeclaration = [{
+    /// Returns the default memory space.
+    Attribute getDefaultMemorySpace() const;
+
+    /// Returns the memory space as an integer address space.
+    int64_t getAddressSpace() const;
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Base address operation definition.
+//===----------------------------------------------------------------------===//
+
+// Base class for operations of the `ptr` dialect.
+// NOTE(review): other definitions here use the `Ptr_` prefix; consider
+// naming this `Ptr_Op` for consistency.
+class Pointer_Op<string mnemonic, list<Trait> traits = []> :
+ Op<Ptr_Dialect, mnemonic, traits>;
+
+#endif // PTR_DIALECT
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
new file mode 100644
index 00000000000000..ad8a2bbcbdd8d2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -0,0 +1,24 @@
+//===- PtrOps.h - Pointer dialect ops ---------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTROPS_H
+#define MLIR_DIALECT_PTR_IR_PTROPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTROPS_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
new file mode 100644
index 00000000000000..690941337bdfb5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -0,0 +1,15 @@
+//===- PtrOps.td - Pointer dialect ops ---------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_OPS
+#define PTR_OPS
+
+include "mlir/Dialect/Ptr/IR/PtrDialect.td"
+include "mlir/IR/OpAsmInterface.td"
+
+#endif // PTR_OPS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
new file mode 100644
index 00000000000000..9984aedcbf6ce8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -0,0 +1,37 @@
+//===- PtrTypes.h - Pointer types -------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Pointer dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
+#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
+
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+namespace mlir {
+namespace ptr {
+/// The positions of different values in the data layout entry for pointers.
+/// Each enumerator's value is the index of the corresponding element in the
+/// entry's dense integer elements attribute. Verified entries hold 3 or 4
+/// `i64` values: size (bits), ABI alignment, preferred alignment, and an
+/// optional index bitwidth. Alignments appear to be stored in bits (the
+/// implementation divides them by 8 when reporting alignment).
+enum class PtrDLEntryPos { Size = 0, Abi = 1, Preferred = 2, Index = 3 };
+
+/// Returns the value that corresponds to named position `pos` from the
+/// data layout entry `attr` assuming it's a dense integer elements attribute.
+/// Returns `std::nullopt` if `pos` is not present in the entry.
+/// Currently only `PtrDLEntryPos::Index` is optional, and all other positions
+/// may be assumed to be present.
+/// NOTE(review): `attr` must be the entry's value, not the entry attribute.
+std::optional<uint64_t> extractPointerSpecValue(Attribute attr,
+ PtrDLEntryPos pos);
+} // namespace ptr
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRTYPES_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c558dc53cc7fac..b0cbb720519edd 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -61,6 +61,7 @@
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
@@ -131,6 +132,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
omp::OpenMPDialect,
pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
+ ptr::PtrDialect,
quant::QuantizationDialect,
ROCDL::ROCDLDialect,
scf::SCFDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b1ba5a3bc8817d..6fa8a610e196c4 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
add_subdirectory(PDLInterp)
+add_subdirectory(Ptr)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..359b9f02a06266
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+# The ptr dialect library: the `!ptr.ptr` type and dialect registration.
+add_mlir_dialect_library(
+  MLIRPtrDialect
+  PtrTypes.cpp
+  PtrDialect.cpp
+
+  # Public headers live in include/mlir/Dialect/Ptr (there is no `Pointer`
+  # directory).
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Ptr
+
+  DEPENDS
+  MLIRPtrOpsIncGen
+
+  LINK_LIBS
+  PUBLIC
+  MLIRIR
+  MLIRDataLayoutInterfaces
+  MLIRMemorySlotInterfaces
+)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
new file mode 100644
index 00000000000000..59c97b22f332c4
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -0,0 +1,48 @@
+//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Pointer dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect
+//===----------------------------------------------------------------------===//
+
+/// Registers the dialect's operations and types, using the lists produced by
+/// the generated `.cpp.inc` files. PtrOps.td currently defines no operations,
+/// so the operation list is expected to be empty for now.
+void PtrDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
+ >();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer API.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
new file mode 100644
index 00000000000000..51d0a45051b85e
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -0,0 +1,184 @@
+//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+// Values assumed for a pointer in the default address space when the data
+// layout has no entry for it: a 64-bit pointer with an 8-byte alignment.
+// `kBitsInByte` converts the bit-valued alignments stored in entries to bytes.
+static constexpr unsigned kDefaultPointerSizeBits = 64;
+static constexpr unsigned kBitsInByte = 8;
+static constexpr unsigned kDefaultPointerAlignment = 8;
+
+/// Looks up field `pos` for `type` in the data layout entries `params`.
+///
+/// The first type entry whose pointer shares `type`'s address space is used;
+/// a missing index bitwidth falls back to the pointer size. Sizes and index
+/// bitwidths are returned in bits, alignments in bytes. Without a matching
+/// entry, the default address space (0) yields the built-in defaults and any
+/// other address space yields std::nullopt.
+static std::optional<uint64_t>
+getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type,
+                          PtrDLEntryPos pos) {
+  const bool wantsBits =
+      pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
+  const int64_t addrSpace = type.getAddressSpace();
+
+  // Find the entry describing a pointer in the same address space.
+  Attribute spec;
+  for (DataLayoutEntryInterface entry : params) {
+    if (entry.isTypeEntry() &&
+        cast<PtrType>(entry.getKey().get<Type>()).getAddressSpace() ==
+            addrSpace) {
+      spec = entry.getValue();
+      break;
+    }
+  }
+
+  if (!spec) {
+    // Only the default address space has built-in fallback values.
+    if (addrSpace != 0)
+      return std::nullopt;
+    return wantsBits ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
+  }
+
+  std::optional<uint64_t> raw = extractPointerSpecValue(spec, pos);
+  if (!raw && pos == PtrDLEntryPos::Index)
+    raw = extractPointerSpecValue(spec, PtrDLEntryPos::Size);
+  return *raw / (wantsBits ? 1 : kBitsInByte);
+}
+
+/// Returns the integer address space encoded by the memory space parameter:
+/// the value of an `IntegerAttr` memory space, or 0 when the memory space is
+/// absent or is not an integer attribute. This mirrors
+/// `MemorySpace::getAddressSpace` for non-interface attributes and lets data
+/// layout entries be distinguished per address space.
+int64_t PtrType::getAddressSpace() const {
+  if (auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(getMemorySpace()))
+    return intAttr.getInt();
+  return 0;
+}
+
+/// Returns the default memory space, which is represented by a null attribute.
+Attribute PtrType::getDefaultMemorySpace() const { return nullptr; }
+
+/// Returns true if every pointer entry in `newLayout` is compatible with the
+/// corresponding entry in `oldLayout`. An old entry for the same memory space
+/// is preferred; otherwise the old entry for address space 0 is used;
+/// otherwise the built-in defaults are used. Compatibility requires equal
+/// sizes and an old ABI alignment that is a non-zero multiple of the new one.
+bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
+                            DataLayoutEntryListRef newLayout) const {
+  for (DataLayoutEntryInterface newEntry : newLayout) {
+    if (!newEntry.isTypeEntry())
+      continue;
+    // Entry values are in bits, so the default alignment (in bytes) is
+    // converted to bits before being compared against entry values.
+    unsigned size = kDefaultPointerSizeBits;
+    unsigned abi = kDefaultPointerAlignment * kBitsInByte;
+    auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>());
+    const auto *it =
+        llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+          if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+            return llvm::cast<PtrType>(type).getMemorySpace() ==
+                   newType.getMemorySpace();
+          }
+          return false;
+        });
+    if (it == oldLayout.end()) {
+      // Fall back to the entry for the default address space. The result must
+      // be assigned, otherwise the fallback search has no effect.
+      it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+        if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+          return llvm::cast<PtrType>(type).getAddressSpace() == 0;
+        }
+        return false;
+      });
+    }
+    if (it != oldLayout.end()) {
+      // Decode the entry's value; the entry attribute itself is not a dense
+      // integer elements attribute.
+      Attribute oldSpec = it->getValue();
+      size = *extractPointerSpecValue(oldSpec, PtrDLEntryPos::Size);
+      abi = *extractPointerSpecValue(oldSpec, PtrDLEntryPos::Abi);
+    }
+
+    Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
+    unsigned newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
+    unsigned newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
+    // Guard the modulo: a zero ABI alignment is never compatible.
+    if (size != newSize || newAbi == 0 || abi < newAbi || abi % newAbi != 0)
+      return false;
+  }
+  return true;
+}
+
+/// Returns the ABI alignment (in bytes) from the matching data layout entry,
+/// deferring to the pointer in the default memory space when none applies.
+uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout,
+                                  DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> abiAlign =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi);
+  if (!abiAlign)
+    return dataLayout.getTypeABIAlignment(
+        get(getContext(), getDefaultMemorySpace()));
+  return *abiAlign;
+}
+
+/// Returns the index bitwidth from the matching data layout entry, deferring
+/// to the pointer in the default memory space when none applies.
+std::optional<uint64_t>
+PtrType::getIndexBitwidth(const DataLayout &dataLayout,
+                          DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> bitwidth =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index);
+  if (!bitwidth)
+    return dataLayout.getTypeIndexBitwidth(
+        get(getContext(), getDefaultMemorySpace()));
+  return bitwidth;
+}
+
+/// Returns the pointer size in bits from the matching data layout entry.
+llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout,
+                                          DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> bits =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size);
+  if (!bits) {
+    // For other memory spaces, use the size of the pointer to the default
+    // memory space.
+    return dataLayout.getTypeSizeInBits(
+        get(getContext(), getDefaultMemorySpace()));
+  }
+  return llvm::TypeSize::getFixed(*bits);
+}
+
+/// Returns the preferred alignment (in bytes) from the matching data layout
+/// entry, deferring to the pointer in the default memory space otherwise.
+uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout,
+                                        DataLayoutEntryListRef params) const {
+  std::optional<uint64_t> prefAlign =
+      getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred);
+  if (!prefAlign)
+    return dataLayout.getTypePreferredAlignment(
+        get(getContext(), getDefaultMemorySpace()));
+  return *prefAlign;
+}
+
+/// Reads element `pos` of the dense integer spec `attr`, or returns
+/// std::nullopt when the spec is too short to contain it.
+std::optional<uint64_t> mlir::ptr::extractPointerSpecValue(Attribute attr,
+                                                           PtrDLEntryPos pos) {
+  auto values = cast<DenseIntElementsAttr>(attr);
+  const int64_t index = static_cast<int64_t>(pos);
+  if (index < values.size())
+    return values.getValues<uint64_t>()[index];
+  return std::nullopt;
+}
+
+/// Verifies every type-keyed pointer entry. Each value must be a dense `i64`
+/// elements attribute with 3 or 4 elements, and the ABI alignment must not
+/// exceed the preferred alignment.
+LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
+                                     Location loc) const {
+  for (DataLayoutEntryInterface entry : entries) {
+    // Identifier-keyed entries are not checked here.
+    if (!entry.isTypeEntry())
+      continue;
+    Type key = entry.getKey().get<Type>();
+    auto spec = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
+    const bool hasValidShape = spec && (spec.size() == 3 || spec.size() == 4);
+    if (!hasValidShape)
+      return emitError(loc)
+             << "expected layout attribute for " << key
+             << " to be a dense integer elements attribute with 3 or 4 "
+                "elements";
+    if (!spec.getElementType().isInteger(64))
+      return emitError(loc) << "expected i64 parameters for " << key;
+
+    std::optional<uint64_t> abiAlign =
+        extractPointerSpecValue(spec, PtrDLEntryPos::Abi);
+    std::optional<uint64_t> preferredAlign =
+        extractPointerSpecValue(spec, PtrDLEntryPos::Preferred);
+    if (abiAlign > preferredAlign)
+      return emitError(loc) << "preferred alignment is expected to be at least "
+                               "as large as ABI alignment";
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir
new file mode 100644
index 00000000000000..279213bd6fc3e5
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/types.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
+
+// Round-trips pointers without a memory space and with an integer memory
+// space through function arguments and results.
+// CHECK-LABEL: func @ptr_test
+// CHECK: (%[[ARG0:.*]]: !ptr.ptr, %[[ARG1:.*]]: !ptr.ptr<1 : i32>)
+// CHECK: -> (!ptr.ptr<1 : i32>, !ptr.ptr)
+func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 : i32>, !ptr.ptr) {
+ // CHECK: return %[[ARG1]], %[[ARG0]] : !ptr.ptr<1 : i32>, !ptr.ptr
+ return %arg1, %arg0 : !ptr.ptr<1 : i32>, !ptr.ptr
+}
+
+// -----
+
+// Pointers are accepted as memref element types.
+// CHECK-LABEL: func @ptr_test
+// CHECK: %[[ARG:.*]]: memref<!ptr.ptr>
+func.func @ptr_test(%arg0: memref<!ptr.ptr>) {
+ return
+}
>From 827ba48faec31ded77b64b488533624345cba831 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 27 Mar 2024 20:31:55 +0000
Subject: [PATCH 2/3] [mlir][Ptr] Add the `MemorySpaceAttrInterface` interface
and dependencies.
This patch introduces the `MemorySpaceAttrInterface` interface. This
interface is responsible for handling the semantics of `ptr` operations.
For example, this interface can be used to create read-only memory spaces,
where any operation other than a load is a verification error; see
`TestConstMemorySpaceAttr` for a possible implementation of this concept.
This patch also introduces the enum dependencies `AtomicOrdering` and
`AtomicBinOp`; both enumerations are clones of the enums with the same names
in the LLVM dialect.
---
.../mlir/Dialect/Ptr/IR/CMakeLists.txt | 12 ++
.../include/mlir/Dialect/Ptr/IR/MemorySpace.h | 161 ++++++++++++++++
.../Dialect/Ptr/IR/MemorySpaceInterfaces.td | 182 ++++++++++++++++++
mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h | 20 ++
mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td | 69 +++++++
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h | 2 +
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 1 +
mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 2 +
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 69 +++++++
mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 9 +-
mlir/test/Dialect/Ptr/types.mlir | 7 +
mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 +
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 10 +
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 48 +++++
mlir/test/lib/Dialect/Test/TestAttributes.h | 1 +
15 files changed, 592 insertions(+), 2 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
index c6ffa892e4ecba..80dffdfd402cf2 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
@@ -1,2 +1,14 @@
add_mlir_dialect(PtrOps ptr)
add_mlir_doc(PtrOps PtrOps Dialects/ -gen-op-doc)
+
+# Op (if any) and attribute interface declarations/definitions generated from
+# MemorySpaceInterfaces.td.
+set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td)
+mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen)
+
+# Enum declarations/definitions generated from PtrOps.td.
+set(LLVM_TARGET_DEFINITIONS PtrOps.td)
+mlir_tablegen(PtrOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(PtrOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRPtrOpsEnumsGen)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
new file mode 100644
index 00000000000000..e467d121f2c886
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
@@ -0,0 +1,161 @@
+//===-- MemorySpace.h - ptr dialect memory space ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ptr's dialect memory space class and related
+// interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_MEMORYSPACE_H
+#define MLIR_DIALECT_PTR_IR_MEMORYSPACE_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class Operation;
+namespace ptr {
+/// This method checks if it's valid to perform an `addrspacecast` op in the
+/// memory space.
+/// Compatible types are:
+/// Vectors of rank 1, or scalars of `ptr` type.
+/// `MemorySpace::isValidAddrSpaceCast` uses this when the memory space does not
+/// implement `MemorySpaceAttrInterface`.
+LogicalResult isValidAddrSpaceCastImpl(Type tgt, Type src,
+ Operation *diagnosticOp);
+
+/// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
+/// the memory space.
+/// Compatible types are:
+/// IntLikeTy: Vectors of rank 1, or scalars of integer types or `index` type.
+/// PtrLikeTy: Vectors of rank 1, or scalars of `ptr` type.
+/// `MemorySpace::isValidPtrIntCast` uses this when the memory space does not
+/// implement `MemorySpaceAttrInterface`.
+LogicalResult isValidPtrIntCastImpl(Type intLikeTy, Type ptrLikeTy,
+ Operation *diagnosticOp);
+
+// Forward declarations so the generated interface header can name these
+// enums. Presumably they are defined by the generated PtrOpsEnums header;
+// confirm.
+enum class AtomicBinOp : uint64_t;
+enum class AtomicOrdering : uint64_t;
+} // namespace ptr
+} // namespace mlir
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc"
+
+namespace mlir {
+namespace ptr {
+/// This class wraps the `MemorySpaceAttrInterface` interface, providing a safe
+/// mechanism to specify the default behavior assumed by the ptr dialect.
+///
+/// The wrapped attribute may be null, an attribute implementing
+/// `MemorySpaceAttrInterface`, or any other attribute. When the attribute does
+/// not implement the interface (including when it is null), the "default
+/// model" applies:
+///  - loads, stores and atomic operations are accepted;
+///  - casts are checked by `isValidAddrSpaceCastImpl`/`isValidPtrIntCastImpl`.
+class MemorySpace {
+public:
+  MemorySpace() = default;
+  MemorySpace(std::nullptr_t) {}
+  MemorySpace(MemorySpaceAttrInterface memorySpace)
+      : memorySpaceAttr(memorySpace), memorySpace(memorySpace) {}
+  MemorySpace(Attribute memorySpace)
+      : memorySpaceAttr(memorySpace),
+        memorySpace(dyn_cast_or_null<MemorySpaceAttrInterface>(memorySpace)) {}
+
+  operator Attribute() const { return memorySpaceAttr; }
+  operator MemorySpaceAttrInterface() const { return memorySpace; }
+
+  /// Two memory spaces are equal iff their underlying attributes are equal.
+  bool operator==(const MemorySpace &memSpace) const {
+    return memSpace.memorySpaceAttr == memorySpaceAttr;
+  }
+  /// C++17 does not derive `!=` from `==`, so it is provided explicitly.
+  bool operator!=(const MemorySpace &memSpace) const {
+    return !(*this == memSpace);
+  }
+
+  /// Returns the underlying memory space.
+  Attribute getUnderlyingSpace() const { return memorySpaceAttr; }
+
+  /// Returns true if the underlying attribute does not implement
+  /// `MemorySpaceAttrInterface` (this includes a null attribute), i.e. the
+  /// default model is used.
+  bool isDefaultModel() const { return memorySpace == nullptr; }
+
+  /// Returns the address space reported by the interface when implemented,
+  /// otherwise the value of an `IntegerAttr` memory space, otherwise 0.
+  /// NOTE(review): `IntegerAttr::getInt()` yields an int64_t that is narrowed
+  /// to `unsigned` here.
+  unsigned getAddressSpace() const {
+    if (memorySpace)
+      return memorySpace.getAddressSpace();
+    if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(memorySpaceAttr))
+      return intAttr.getInt();
+    return 0;
+  }
+
+  /// Returns the default memory space as an attribute, or nullptr if using the
+  /// default model.
+  Attribute getDefaultMemorySpace() const {
+    return memorySpace ? memorySpace.getDefaultMemorySpace() : nullptr;
+  }
+
+  /// This method checks if it's valid to load a value from the memory space
+  /// with a specific type, alignment, and atomic ordering. The default model
+  /// assumes all values are loadable.
+  LogicalResult isValidLoad(Type type, AtomicOrdering ordering,
+                            IntegerAttr alignment,
+                            Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidLoad(type, ordering, alignment,
+                                                 diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to store a value in the memory space with
+  /// a specific type, alignment, and atomic ordering. The default model assumes
+  /// all values are storable.
+  LogicalResult isValidStore(Type type, AtomicOrdering ordering,
+                             IntegerAttr alignment,
+                             Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidStore(type, ordering, alignment,
+                                                  diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to perform the atomic operation `op` in
+  /// the memory space with a specific type, alignment, and atomic ordering.
+  /// The default model assumes all atomic operations are valid.
+  LogicalResult isValidAtomicOp(AtomicBinOp op, Type type,
+                                AtomicOrdering ordering, IntegerAttr alignment,
+                                Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidAtomicOp(op, type, ordering,
+                                                     alignment, diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to perform an atomic exchange operation
+  /// in the memory space with a specific type, alignment, and success/failure
+  /// atomic orderings. The default model assumes all exchanges are valid.
+  LogicalResult isValidAtomicXchg(Type type, AtomicOrdering successOrdering,
+                                  AtomicOrdering failureOrdering,
+                                  IntegerAttr alignment,
+                                  Operation *diagnosticOp = nullptr) const {
+    return memorySpace ? memorySpace.isValidAtomicXchg(type, successOrdering,
+                                                       failureOrdering,
+                                                       alignment, diagnosticOp)
+                       : success();
+  }
+
+  /// This method checks if it's valid to perform an `addrspacecast` op in the
+  /// memory space. The default model defers to `isValidAddrSpaceCastImpl`.
+  LogicalResult isValidAddrSpaceCast(Type tgt, Type src,
+                                     Operation *diagnosticOp = nullptr) const {
+    return memorySpace
+               ? memorySpace.isValidAddrSpaceCast(tgt, src, diagnosticOp)
+               : isValidAddrSpaceCastImpl(tgt, src, diagnosticOp);
+  }
+
+  /// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op
+  /// in the memory space. The default model defers to `isValidPtrIntCastImpl`.
+  LogicalResult isValidPtrIntCast(Type intLikeTy, Type ptrLikeTy,
+                                  Operation *diagnosticOp = nullptr) const {
+    return memorySpace
+               ? memorySpace.isValidPtrIntCast(intLikeTy, ptrLikeTy,
+                                               diagnosticOp)
+               : isValidPtrIntCastImpl(intLikeTy, ptrLikeTy, diagnosticOp);
+  }
+
+protected:
+  /// Underlying memory space attribute (may be null).
+  Attribute memorySpaceAttr{};
+  /// The same attribute viewed through `MemorySpaceAttrInterface`; null when
+  /// the attribute does not implement the interface.
+  MemorySpaceAttrInterface memorySpace{};
+};
+} // namespace ptr
+} // namespace mlir
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_MEMORYSPACE_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
new file mode 100644
index 00000000000000..b7bed95434839e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
@@ -0,0 +1,182 @@
+//===-- MemorySpaceInterfaces.td - Memory space interfaces ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines memory space attribute interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_MEMORYSPACEINTERFACES
+#define PTR_MEMORYSPACEINTERFACES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Memory space attribute interface.
+//===----------------------------------------------------------------------===//
+
+def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
+ let description = [{
+ This interface defines a common API for interacting with the memory model of
+ a memory space and the operations in the pointer dialect, giving proper
+ semantic meaning to the ops.
+
+ Furthermore, this interface allows concepts such as read-only memory to be
+ adequately modeled and enforced.
+ }];
+ let cppNamespace = "::mlir::ptr";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ [{
+ Returns the dialect implementing the memory space.
+ }],
+ /*returnType=*/ "::mlir::Dialect*",
+ /*methodName=*/ "getMemorySpaceDialect",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{ return nullptr; }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ Returns the default memory space as an attribute.
+ }],
+ /*returnType=*/ "::mlir::Attribute",
+ /*methodName=*/ "getDefaultMemorySpace",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ Returns the memory space as an integer, or 0 if using the default model.
+ }],
+ /*returnType=*/ "unsigned",
+ /*methodName=*/ "getAddressSpace",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{ return 0; }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ This method checks if it's valid to load a value from the memory space
+ with a specific type, alignment, and atomic ordering.
+ If `diagnosticOp` is non-null then the method might emit diagnostics.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "isValidLoad",
+ /*args=*/ (ins "::mlir::Type":$type,
+ "::mlir::ptr::AtomicOrdering":$ordering,
+ "::mlir::IntegerAttr":$alignment,
+ "::mlir::Operation*":$diagnosticOp),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ This method checks if it's valid to store a value in the memory space
+ with a specific type, alignment, and atomic ordering.
+ If `diagnosticOp` is non-null then the method might emit diagnostics.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "isValidStore",
+ /*args=*/ (ins "::mlir::Type":$type,
+ "::mlir::ptr::AtomicOrdering":$ordering,
+ "::mlir::IntegerAttr":$alignment,
+ "::mlir::Operation*":$diagnosticOp),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ This method checks if it's valid to perform an atomic operation in the
+ memory space with a specific type, alignment, and atomic ordering.
+ If `diagnosticOp` is non-null then the method might emit diagnostics.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "isValidAtomicOp",
+ /*args=*/ (ins "::mlir::ptr::AtomicBinOp":$op,
+ "::mlir::Type":$type,
+ "::mlir::ptr::AtomicOrdering":$ordering,
+ "::mlir::IntegerAttr":$alignment,
+ "::mlir::Operation*":$diagnosticOp),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ This method checks if it's valid to perform an atomic compare-and-exchange
+ operation in the memory space with a specific type, alignment, and the
+ given success and failure atomic orderings.
+ If `diagnosticOp` is non-null then the method might emit diagnostics.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "isValidAtomicXchg",
+ /*args=*/ (ins "::mlir::Type":$type,
+ "::mlir::ptr::AtomicOrdering":$successOrdering,
+ "::mlir::ptr::AtomicOrdering":$failureOrdering,
+ "::mlir::IntegerAttr":$alignment,
+ "::mlir::Operation*":$diagnosticOp),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ This method checks if it's valid to perform an `addrspacecast` op
+ in the memory space.
+ Both types are expected to be vectors of rank 1, or scalars of `ptr`
+ type.
+ If `diagnosticOp` is non-null then the method might emit diagnostics.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "isValidAddrSpaceCast",
+ /*args=*/ (ins "::mlir::Type":$tgt,
+ "::mlir::Type":$src,
+ "::mlir::Operation*":$diagnosticOp),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ InterfaceMethod<
+ /*desc=*/ [{
+ This method checks if it's valid to perform a `ptrtoint` or `inttoptr`
+ op in the memory space. `intLikeTy` is always the integer-like type and
+ `ptrLikeTy` is always the pointer-like type, regardless of the direction
+ of the cast.
+ Both types are expected to be vectors of rank 1, or scalars, of integer or
+ `index` type (for `intLikeTy`) and of `ptr` type (for `ptrLikeTy`).
+ If `diagnosticOp` is non-null then the method might emit diagnostics.
+ }],
+ /*returnType=*/ "::mlir::LogicalResult",
+ /*methodName=*/ "isValidPtrIntCast",
+ /*args=*/ (ins "::mlir::Type":$intLikeTy,
+ "::mlir::Type":$ptrLikeTy,
+ "::mlir::Operation*":$diagnosticOp),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ ];
+}
+
+def MemorySpaceOpInterface : OpInterface<"MemorySpaceOpInterface"> {
+ let description = [{
+ An interface for operations that expose an associated `ptr::MemorySpace`.
+ }];
+
+ let cppNamespace = "::mlir::ptr";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns the memory space of the op.",
+ /*returnType=*/ "::mlir::ptr::MemorySpace",
+ /*methodName=*/ "getMemorySpace",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{}]
+ >,
+ ];
+}
+#endif // PTR_MEMORYSPACEINTERFACES
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
new file mode 100644
index 00000000000000..e6aa7635919f6d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
@@ -0,0 +1,20 @@
+//===- PtrAttrs.h - Pointer dialect attributes ------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Ptr dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRATTRS_H
+#define MLIR_DIALECT_PTR_IR_PTRATTRS_H
+
+#include "mlir/IR/OpImplementation.h"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
new file mode 100644
index 00000000000000..3a921e4de08d8d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
@@ -0,0 +1,69 @@
+//===-- PtrEnums.td - Ptr dialect enumerations -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTR_ENUMS
+#define PTR_ENUMS
+
+include "mlir/IR/EnumAttr.td"
+
+//===----------------------------------------------------------------------===//
+// Atomic read-modify-write (`ptr.atomicrmw`) binary op enum attribute
+//===----------------------------------------------------------------------===//
+
+def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0, "xchg">;
+def AtomicBinOpAdd : I64EnumAttrCase<"add", 1, "add">;
+def AtomicBinOpSub : I64EnumAttrCase<"sub", 2, "sub">;
+def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3, "_and">;
+def AtomicBinOpNand : I64EnumAttrCase<"nand", 4, "nand">;
+def AtomicBinOpOr : I64EnumAttrCase<"_or", 5, "_or">;
+def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6, "_xor">;
+def AtomicBinOpMax : I64EnumAttrCase<"max", 7, "max">;
+def AtomicBinOpMin : I64EnumAttrCase<"min", 8, "min">;
+def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9, "umax">;
+def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10, "umin">;
+def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11, "fadd">;
+def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12, "fsub">;
+def AtomicBinOpFMax : I64EnumAttrCase<"fmax", 13, "fmax">;
+def AtomicBinOpFMin : I64EnumAttrCase<"fmin", 14, "fmin">;
+def AtomicBinOpUIncWrap : I64EnumAttrCase<"uinc_wrap", 15, "uinc_wrap">;
+def AtomicBinOpUDecWrap : I64EnumAttrCase<"udec_wrap", 16, "udec_wrap">;
+
+def AtomicBinOp : I64EnumAttr<
+ "AtomicBinOp",
+ "ptr.atomicrmw binary operations",
+ [AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
+ AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
+ AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
+ AtomicBinOpFSub, AtomicBinOpFMax, AtomicBinOpFMin, AtomicBinOpUIncWrap,
+ AtomicBinOpUDecWrap]> {
+ let cppNamespace = "::mlir::ptr";
+}
+
+//===----------------------------------------------------------------------===//
+// Atomic ordering enum attribute, following LLVM's memory model orderings
+//===----------------------------------------------------------------------===//
+
+def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0, "not_atomic">;
+def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1, "unordered">;
+def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2, "monotonic">;
+def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 3, "acquire">;
+def AtomicOrderingRelease : I64EnumAttrCase<"release", 4, "release">;
+def AtomicOrderingAcqRel : I64EnumAttrCase<"acq_rel", 5, "acq_rel">;
+def AtomicOrderingSeqCst : I64EnumAttrCase<"seq_cst", 6, "seq_cst">;
+
+def AtomicOrdering : I64EnumAttr<
+ "AtomicOrdering",
+ "Atomic ordering for LLVM's memory model",
+ [AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
+ AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcqRel,
+ AtomicOrderingSeqCst
+ ]> {
+ let cppNamespace = "::mlir::ptr";
+}
+
+#endif // PTR_ENUMS
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
index ad8a2bbcbdd8d2..23d41a11a6da03 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -14,6 +14,8 @@
#define MLIR_DIALECT_PTR_IR_PTROPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpace.h"
+#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 690941337bdfb5..91c73804c27133 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -10,6 +10,7 @@
#define PTR_OPS
include "mlir/Dialect/Ptr/IR/PtrDialect.td"
+include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
#endif // PTR_OPS
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index 359b9f02a06266..24cc3bc6ef3b57 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -6,6 +6,8 @@ add_mlir_dialect_library(
${PROJECT_SOURCE_DIR}/mlir/Dialect/Pointer
DEPENDS
MLIRPtrOpsIncGen
+ MLIRPtrOpsEnumsGen
+ MLIRPtrMemorySpaceInterfacesIncGen
LINK_LIBS
PUBLIC
MLIRIR
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 59c97b22f332c4..c97a0626ae169e 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,8 +39,77 @@ void PtrDialect::initialize() {
// Pointer API.
//===----------------------------------------------------------------------===//
+// Returns a pair containing:
+// The underlying type of a vector or the type itself if it's not a vector.
+// The number of elements in the vector or an error code if the type is not
+// supported.
+static std::pair<Type, int64_t> getVecOrScalarInfo(Type ty) {
+ if (auto vecTy = dyn_cast<VectorType>(ty)) {
+ auto elemTy = vecTy.getElementType();
+ // Vectors of rank greater than one or with scalable dimensions are not
+ // supported.
+ if (vecTy.getRank() != 1)
+ return {elemTy, -1};
+ else if (vecTy.getScalableDims()[0])
+ return {elemTy, -2};
+ return {elemTy, vecTy.getShape()[0]};
+ }
+ // `ty` is a scalar type.
+ return {ty, 0};
+}
+
+LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
+ Operation *op) {
+ std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
+ std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
+ if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
+ return op ? op->emitError("invalid ptr-like operand") : failure();
+ // Check shape validity.
+ if (tgtInfo.second == -1 || srcInfo.second == -1)
+ return op ? op->emitError("vectors of rank != 1 are not supported")
+ : failure();
+ if (tgtInfo.second == -2 || srcInfo.second == -2)
+ return op ? op->emitError(
+ "vectors with scalable dimensions are not supported")
+ : failure();
+ if (tgtInfo.second != srcInfo.second)
+ return op ? op->emitError("incompatible operand shapes") : failure();
+ return success();
+}
+
+LogicalResult mlir::ptr::isValidPtrIntCastImpl(Type intLikeTy, Type ptrLikeTy,
+ Operation *op) {
+ // Check int-like type.
+ std::pair<Type, int64_t> intInfo = getVecOrScalarInfo(intLikeTy);
+ if (!intInfo.first.isSignlessIntOrIndex())
+ /// The int-like operand is invalid.
+ return op ? op->emitError("invalid int-like type") : failure();
+ // Check ptr-like type.
+ std::pair<Type, int64_t> ptrInfo = getVecOrScalarInfo(ptrLikeTy);
+ if (!isa<PtrType>(ptrInfo.first))
+ /// The pointer-like operand is invalid.
+ return op ? op->emitError("invalid ptr-like type") : failure();
+ // Check shape validity.
+ if (intInfo.second == -1 || ptrInfo.second == -1)
+ return op ? op->emitError("vectors of rank != 1 are not supported")
+ : failure();
+ if (intInfo.second == -2 || ptrInfo.second == -2)
+ return op ? op->emitError(
+ "vectors with scalable dimensions are not supported")
+ : failure();
+ if (intInfo.second != ptrInfo.second)
+ return op ? op->emitError("incompatible operand shapes") : failure();
+ return success();
+}
+
#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
+
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index 51d0a45051b85e..8f83cc210d44b7 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpace.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -64,9 +65,13 @@ getPointerDataLayoutEntry(DataLayoutEntryListRef params, PtrType type,
return std::nullopt;
}
-int64_t PtrType::getAddressSpace() const { return 0; }
+int64_t PtrType::getAddressSpace() const {
+ return MemorySpace(getMemorySpace()).getAddressSpace(); // Falls back to 0.
+}
-Attribute PtrType::getDefaultMemorySpace() const { return nullptr; }
+Attribute PtrType::getDefaultMemorySpace() const {
+ return MemorySpace(getMemorySpace()).getDefaultMemorySpace(); // May be null.
+}
bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
DataLayoutEntryListRef newLayout) const {
diff --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir
index 279213bd6fc3e5..9bec8f01803e43 100644
--- a/mlir/test/Dialect/Ptr/types.mlir
+++ b/mlir/test/Dialect/Ptr/types.mlir
@@ -15,3 +15,10 @@ func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 :
func.func @ptr_test(%arg0: memref<!ptr.ptr>) {
return
}
+
+// CHECK-LABEL: func @ptr_test_1
+// CHECK: (%[[ARG0:.*]]: !ptr.ptr<#test.const_memory_space>, %[[ARG1:.*]]: !ptr.ptr<#test.const_memory_space<3>>)
+func.func @ptr_test_1(%arg0: !ptr.ptr<#test.const_memory_space>,
+ %arg1: !ptr.ptr<#test.const_memory_space<3>>) {
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index b82b1631eead59..b7e4c8ae1145bb 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -78,6 +78,7 @@ add_mlir_library(MLIRTestDialect
MLIRInferTypeOpInterface
MLIRLinalgDialect
MLIRLinalgTransforms
+ MLIRPtrDialect
MLIRLLVMDialect
MLIRPass
MLIRReduce
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4e5..82dbc92333e30e 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -16,6 +16,7 @@
// To get the test dialect definition.
include "TestDialect.td"
include "TestEnumDefs.td"
+include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
include "mlir/Dialect/Utils/StructuredOpsUtils.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
@@ -340,4 +341,13 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
}];
}
+// Test a read-only ptr memory space: loads are valid, writes are rejected.
+def TestConstMemorySpaceAttr : Test_Attr<"TestConstMemorySpace", [
+ DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>
+ ]> {
+ let mnemonic = "const_memory_space";
+ let parameters = (ins DefaultValuedParameter<"unsigned", "0">:$addressSpace);
+ let assemblyFormat = "(`<` $addressSpace^ `>`)?";
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index d41d495c38e553..60c362465f12fe 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -239,6 +239,54 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
p.printKeywordOrString(value);
}
+//===----------------------------------------------------------------------===//
+// TestConstMemorySpaceAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestConstMemorySpaceAttr::getDefaultMemorySpace() const {
+ return TestConstMemorySpaceAttr::get(getContext(), /*addressSpace=*/0);
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidLoad(
+ Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Operation *diagnosticOp) const {
+ return success(); // Reading from a constant memory space is always valid.
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidStore(
+ Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Operation *diagnosticOp) const {
+ return diagnosticOp ? diagnosticOp->emitError("memory space is read-only")
+ : failure(); // Stores always fail: the memory space is read-only.
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidAtomicOp(
+ mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering,
+ IntegerAttr alignment, Operation *diagnosticOp) const {
+ return diagnosticOp ? diagnosticOp->emitError("memory space is read-only")
+ : failure(); // Read-modify-write ops always fail here.
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidAtomicXchg(
+ Type type, mlir::ptr::AtomicOrdering successOrdering,
+ mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
+ Operation *diagnosticOp) const {
+ return diagnosticOp ? diagnosticOp->emitError("memory space is read-only")
+ : failure(); // Exchanges may write memory, so they always fail here.
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidAddrSpaceCast(
+ Type tgt, Type src, Operation *diagnosticOp) const {
+ // Casts are unaffected by the read-only restriction: use the generic check.
+ return ptr::isValidAddrSpaceCastImpl(tgt, src, diagnosticOp);
+}
+
+LogicalResult TestConstMemorySpaceAttr::isValidPtrIntCast(
+ Type intLikeTy, Type ptrLikeTy, Operation *diagnosticOp) const {
+ // Casts are unaffected by the read-only restriction: use the generic check.
+ return ptr::isValidPtrIntCastImpl(intLikeTy, ptrLikeTy, diagnosticOp);
+}
+
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index ef6eae51fdd628..a84e26fba9d912 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -17,6 +17,7 @@
#include <tuple>
#include "TestTraits.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpace.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
>From 161d1fcf888466b2e439b8aa54662283fef6f9a1 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 2 Apr 2024 15:40:49 +0000
Subject: [PATCH 3/3] address reviewer comments
---
.../include/mlir/Dialect/Ptr/IR/MemorySpace.h | 50 ++++++++---------
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 53 +++++++++++--------
2 files changed, 55 insertions(+), 48 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
index e467d121f2c886..948f97a0c95626 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h
@@ -21,14 +21,14 @@
namespace mlir {
class Operation;
namespace ptr {
-/// This method checks if it's valid to perform an `addrspacecast` op in the
+/// Checks if it's valid to perform an `addrspacecast` op in the
/// memory space.
/// Compatible types are:
/// Vectors of rank 1, or scalars of `ptr` type.
LogicalResult isValidAddrSpaceCastImpl(Type tgt, Type src,
Operation *diagnosticOp);
-/// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
+/// Checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
/// the memory space.
/// Compatible types are:
/// IntLikeTy: Vectors of rank 1, or scalars of integer types or `index` type.
@@ -52,28 +52,29 @@ class MemorySpace {
MemorySpace() = default;
MemorySpace(std::nullptr_t) {}
MemorySpace(MemorySpaceAttrInterface memorySpace)
- : memorySpaceAttr(memorySpace), memorySpace(memorySpace) {}
- MemorySpace(Attribute memorySpace)
- : memorySpaceAttr(memorySpace),
+ : underlyingMemorySpace(memorySpace), memorySpace(memorySpace) {}
+ explicit MemorySpace(Attribute memorySpace)
+ : underlyingMemorySpace(memorySpace),
memorySpace(dyn_cast_or_null<MemorySpaceAttrInterface>(memorySpace)) {}
- operator Attribute() const { return memorySpaceAttr; }
+ operator Attribute() const { return underlyingMemorySpace; }
operator MemorySpaceAttrInterface() const { return memorySpace; }
bool operator==(const MemorySpace &memSpace) const {
- return memSpace.memorySpaceAttr == memorySpaceAttr;
+ return memSpace.underlyingMemorySpace == underlyingMemorySpace;
}
/// Returns the underlying memory space.
- Attribute getUnderlyingSpace() const { return memorySpaceAttr; }
+ Attribute getUnderlyingSpace() const { return underlyingMemorySpace; }
- /// Returns true if the underlying memory space is null.
+ /// Returns true if the memory space is null.
bool isDefaultModel() const { return memorySpace == nullptr; }
/// Returns the memory space as an integer, or 0 if using the default space.
unsigned getAddressSpace() const {
if (memorySpace)
return memorySpace.getAddressSpace();
- if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(memorySpaceAttr))
+ if (auto intAttr =
+ llvm::dyn_cast_or_null<IntegerAttr>(underlyingMemorySpace))
return intAttr.getInt();
return 0;
}
@@ -84,9 +85,9 @@ class MemorySpace {
return memorySpace ? memorySpace.getDefaultMemorySpace() : nullptr;
}
- /// This method checks if it's valid to load a value from the memory space
- /// with a specific type, alignment, and atomic ordering. The default model
- /// assumes all values are loadable.
+ /// Checks if it's valid to load a value from the memory space with a specific
+ /// type, alignment, and atomic ordering. The default model assumes all values
+ /// can be loaded.
LogicalResult isValidLoad(Type type, AtomicOrdering ordering,
IntegerAttr alignment,
Operation *diagnosticOp = nullptr) const {
@@ -95,9 +96,9 @@ class MemorySpace {
: success();
}
- /// This method checks if it's valid to store a value in the memory space with
- /// a specific type, alignment, and atomic ordering. The default model assumes
- /// all values are loadable.
+ /// Checks if it's valid to store a value in the memory space with a specific
+ /// type, alignment, and atomic ordering. The default model assumes all values
+ /// can be stored.
LogicalResult isValidStore(Type type, AtomicOrdering ordering,
IntegerAttr alignment,
Operation *diagnosticOp = nullptr) const {
@@ -106,8 +107,8 @@ class MemorySpace {
: success();
}
- /// This method checks if it's valid to perform an atomic operation in the
- /// memory space with a specific type, alignment, and atomic ordering.
+ /// Checks if it's valid to perform an atomic operation in the memory space
+ /// with a specific type, alignment, and atomic ordering.
LogicalResult isValidAtomicOp(AtomicBinOp op, Type type,
AtomicOrdering ordering, IntegerAttr alignment,
Operation *diagnosticOp = nullptr) const {
@@ -116,8 +117,8 @@ class MemorySpace {
: success();
}
- /// This method checks if it's valid to perform an atomic operation in the
- /// memory space with a specific type, alignment, and atomic ordering.
+ /// Checks if it's valid to perform an atomic exchange operation in the memory
+ /// space with a specific type, alignment, and atomic ordering.
LogicalResult isValidAtomicXchg(Type type, AtomicOrdering successOrdering,
AtomicOrdering failureOrdering,
IntegerAttr alignment,
@@ -128,8 +129,7 @@ class MemorySpace {
: success();
}
- /// This method checks if it's valid to perform an `addrspacecast` op in the
- /// memory space.
+ /// Checks if it's valid to perform an `addrspacecast` op in the memory space.
LogicalResult isValidAddrSpaceCast(Type tgt, Type src,
Operation *diagnosticOp = nullptr) const {
return memorySpace
@@ -137,8 +137,8 @@ class MemorySpace {
: isValidAddrSpaceCastImpl(tgt, src, diagnosticOp);
}
- /// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op
- /// in the memory space.
+ /// Checks if it's valid to perform a `ptrtoint` or `inttoptr` op in the
+ /// memory space.
LogicalResult isValidPtrIntCast(Type intLikeTy, Type ptrLikeTy,
Operation *diagnosticOp = nullptr) const {
return memorySpace
@@ -149,7 +149,7 @@ class MemorySpace {
protected:
/// Underlying memory space.
- Attribute memorySpaceAttr{};
+ Attribute underlyingMemorySpace{};
/// Memory space.
MemorySpaceAttrInterface memorySpace{};
};
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c97a0626ae169e..195c1a56492902 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,6 +39,10 @@ void PtrDialect::initialize() {
// Pointer API.
//===----------------------------------------------------------------------===//
+// Error sentinels stored as the element count for unsupported vector types.
+static constexpr int64_t kInvalidRankError = -1;
+static constexpr int64_t kScalableDimsError = -2;
+
// Returns a pair containing:
// The underlying type of a vector or the type itself if it's not a vector.
// The number of elements in the vector or an error code if the type is not
@@ -49,26 +53,28 @@ static std::pair<Type, int64_t> getVecOrScalarInfo(Type ty) {
// Vectors of rank greater than one or with scalable dimensions are not
// supported.
if (vecTy.getRank() != 1)
- return {elemTy, -1};
+ return {elemTy, kInvalidRankError}; // e.g. vector<2x4x!ptr.ptr>
else if (vecTy.getScalableDims()[0])
- return {elemTy, -2};
+ return {elemTy, kScalableDimsError}; // e.g. vector<[4]x!ptr.ptr>
return {elemTy, vecTy.getShape()[0]};
}
// `ty` is a scalar type.
return {ty, 0};
}
-LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
- Operation *op) {
- std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
- std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
- if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
- return op ? op->emitError("invalid ptr-like operand") : failure();
+/// Checks that two operands, as described by `getVecOrScalarInfo`, have
+/// compatible shapes: both must be scalars, or fixed-length vectors of rank 1
+/// with the same number of elements. Emits an error on `op` when non-null.
+static LogicalResult verifyShapeInfo(mlir::Operation *op,
+ const std::pair<Type, int64_t> &tgtInfo,
+ const std::pair<Type, int64_t> &srcInfo) {
// Check shape validity.
- if (tgtInfo.second == -1 || srcInfo.second == -1)
+ if (tgtInfo.second == kInvalidRankError ||
+ srcInfo.second == kInvalidRankError)
return op ? op->emitError("vectors of rank != 1 are not supported")
: failure();
- if (tgtInfo.second == -2 || srcInfo.second == -2)
+ if (tgtInfo.second == kScalableDimsError ||
+ srcInfo.second == kScalableDimsError)
return op ? op->emitError(
"vectors with scalable dimensions are not supported")
: failure();
@@ -77,29 +83,30 @@ LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
return success();
}
+LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
+ Operation *op) {
+ std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
+ std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
+ if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
+ return op ? op->emitError("invalid ptr-like operand") : failure();
+ // Both operands must also have compatible scalar/vector shapes.
+ return verifyShapeInfo(op, tgtInfo, srcInfo);
+}
+
LogicalResult mlir::ptr::isValidPtrIntCastImpl(Type intLikeTy, Type ptrLikeTy,
Operation *op) {
// Check int-like type.
std::pair<Type, int64_t> intInfo = getVecOrScalarInfo(intLikeTy);
+ // Reject element types other than signless integers and `index`.
if (!intInfo.first.isSignlessIntOrIndex())
- /// The int-like operand is invalid.
return op ? op->emitError("invalid int-like type") : failure();
// Check ptr-like type.
std::pair<Type, int64_t> ptrInfo = getVecOrScalarInfo(ptrLikeTy);
+ // Reject element types other than `!ptr.ptr`.
if (!isa<PtrType>(ptrInfo.first))
- /// The pointer-like operand is invalid.
return op ? op->emitError("invalid ptr-like type") : failure();
- // Check shape validity.
- if (intInfo.second == -1 || ptrInfo.second == -1)
- return op ? op->emitError("vectors of rank != 1 are not supported")
- : failure();
- if (intInfo.second == -2 || ptrInfo.second == -2)
- return op ? op->emitError(
- "vectors with scalable dimensions are not supported")
- : failure();
- if (intInfo.second != ptrInfo.second)
- return op ? op->emitError("incompatible operand shapes") : failure();
- return success();
+ // Both operands must also have compatible scalar/vector shapes.
+ return verifyShapeInfo(op, intInfo, ptrInfo);
}
#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
More information about the Mlir-commits
mailing list