[Mlir-commits] [mlir] [mlir][core|ptr] Add `PtrLikeTypeInterface` and casting ops to the `ptr` dialect (PR #137469)
Fabian Mora
llvmlistbot at llvm.org
Mon Jun 16 12:19:04 PDT 2025
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/137469
>From fe22487907de4153dd492b2b6d3bcd83d6468993 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 26 Apr 2025 19:25:54 +0000
Subject: [PATCH 1/7] [mlir][core|ptr] Add `PtrLikeTypeInterface` and casting
ops to the `ptr` dialect
This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types.
This interface is defined as:
```
A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. The base pointer is an
indivisible object.
- Optional metadata about the pointer. For example, the size of the memory
region associated with the pointer.
Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
pointer is considered opaque.
```
This patch adds this interface to `!ptr.ptr` and the `memref` type.
Furthermore, this patch adds the ops and types needed to cast between `!ptr.ptr`
and ptr-like types.
First, it defines the `!ptr.ptr_metadata` type, an opaque type that represents the metadata
of a ptr-like type. The rationale for adding this type is that, at a high level, the
metadata of a type like `memref` cannot be specified, as its structure is tied to its
lowering.
The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The
concrete structure of the metadata is only known when the op is lowered.
Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations, which allow
casting back and forth between `!ptr.ptr` and ptr-like types.
```mlir
func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
%mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
%res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
return %res : memref<f32, #ptr.generic_space>
}
```
---
.../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 49 +++++++++
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 99 +++++++++++++++++++
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 49 +++++++++
mlir/include/mlir/IR/BuiltinTypes.h | 18 +++-
mlir/include/mlir/IR/BuiltinTypes.td | 2 +
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 75 ++++++++++++++
mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 12 +++
mlir/lib/IR/BuiltinTypes.cpp | 14 +++
mlir/test/Dialect/Ptr/canonicalize.mlir | 58 +++++++++++
mlir/test/Dialect/Ptr/invalid.mlir | 33 +++++++
mlir/test/Dialect/Ptr/ops.mlir | 10 ++
11 files changed, 418 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/Ptr/invalid.mlir
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 73b2a0857cef3..6631b338db199 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
+ PtrLikeTypeInterface,
VectorElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "getIndexBitwidth", "verifyEntries",
@@ -63,6 +64,54 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
return $_get(memorySpace.getContext(), memorySpace);
}]>
];
+ let extraClassDeclaration = [{
+ // `PtrLikeTypeInterface` interface methods.
+ /// Returns `Type()` as this pointer type is opaque.
+ Type getElementType() const {
+ return Type();
+ }
+ /// Clones the pointer with the specified memory space. Returns failure if
+ /// an `elementType` was specified, or if the memory space is null or does
+ /// not implement `MemorySpaceAttrInterface`.
+ FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
+ std::optional<Type> elementType) const {
+ if (elementType)
+ return failure();
+ if (auto ms = dyn_cast_if_present<MemorySpaceAttrInterface>(memorySpace))
+ return cast<PtrLikeTypeInterface>(get(ms));
+ return failure();
+ }
+ /// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
+ bool hasPtrMetadata() const {
+ return false;
+ }
+ }];
+}
+
+def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
+ let summary = "Pointer metadata type";
+ let description = [{
+ The `ptr_metadata` type represents an opaque view of the metadata associated
+ with a ptr-like object type.
+ It is an error to build a `ptr_metadata` type from a ptr-like type with no
+ metadata.
+
+ Example:
+
+ ```mlir
+ // The metadata associated with a `memref` type.
+ !ptr.ptr_metadata<memref<f32>>
+ ```
+ }];
+ let parameters = (ins "PtrLikeTypeInterface":$type);
+ let assemblyFormat = "`<` $type `>`";
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "PtrLikeTypeInterface":$ptrLike), [{
+ return $_get(ptrLike.getContext(), ptrLike);
+ }]>
+ ];
+ let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 791b95ad3559e..8ad475c41c8d3 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,75 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
+ Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
+ "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+ ]> {
+ let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
+ let description = [{
+ The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
+ important to note that:
+ - The ptr-like object cannot be a `!ptr.ptr`.
+ - The memory-space of both the `ptr` and ptr-like object must match.
+ - The cast is side-effect free.
+
+ If the ptr-like object type has metadata, then the operation expects the
+ metadata as an argument or expects that the flag `trivial_metadata` is set.
+ If `trivial_metadata` is set, then it is assumed that the metadata can be
+ reconstructed statically from the pointer-like type.
+
+ Example:
+
+ ```mlir
+ %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
+ %memref0 = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
+ %memref1 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
+ ```
+ }];
+
+ let arguments = (ins Ptr_PtrType:$ptr,
+ Optional<Ptr_PtrMetadata>:$metadata,
+ UnitProp:$hasTrivialMetadata);
+ let results = (outs PtrLikeTypeInterface:$result);
+ let assemblyFormat = [{
+ $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
+ attr-dict `:` type($ptr) `->` type($result)
+ }];
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
+ Pure, TypesMatchWith<"metadata type", "ptr", "result",
+ "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+ ]> {
+ let summary = "Extracts the opaque metadata of a ptr-like value.";
+ let description = [{
+ The `get_metadata` operation produces an opaque value that encodes the
+ metadata of the ptr-like value. The ptr-like type must have metadata.
+
+ Example:
+
+ ```mlir
+ %metadata = ptr.get_metadata %memref : memref<?x?xf32>
+ ```
+ }];
+
+ let arguments = (ins PtrLikeTypeInterface:$ptr);
+ let results = (outs Ptr_PtrMetadata:$result);
+ let assemblyFormat = [{
+ $ptr attr-dict `:` type($ptr)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
@@ -52,6 +121,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
}
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
+ let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
+ let description = [{
+ The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
+ important to note that:
+ - The ptr-like object cannot be a `!ptr.ptr`.
+ - The memory-space of both the `ptr` and ptr-like object must match.
+ - The cast is Pure (no UB and side-effect free).
+
+ Example:
+
+ ```mlir
+ %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
+ %ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
+ ```
+ }];
+
+ let arguments = (ins PtrLikeTypeInterface:$ptr);
+ let results = (outs Ptr_PtrType:$result);
+ let assemblyFormat = [{
+ $ptr attr-dict `:` type($ptr) `->` type($result)
+ }];
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..d058f6c4d9651 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
}];
}
+//===----------------------------------------------------------------------===//
+// PtrLikeTypeInterface
+//===----------------------------------------------------------------------===//
+
+def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ A ptr-like type represents an object storing a memory address. This object
+ is constituted by:
+ - A memory address called the base pointer. The base pointer is an
+ indivisible object.
+ - Optional metadata about the pointer. For example, the size of the memory
+ region associated with the pointer.
+
+ Furthermore, all ptr-like types have two properties:
+ - The memory space associated with the address held by the pointer.
+ - An optional element type. If the element type is not specified, the
+ pointer is considered opaque.
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns the memory space of this ptr-like type.
+ }],
+ "::mlir::Attribute", "getMemorySpace">,
+ InterfaceMethod<[{
+ Returns the element type of this ptr-like type. Note: this method can
+ return `::mlir::Type()`, in which case the pointer is considered opaque.
+ }],
+ "::mlir::Type", "getElementType">,
+ InterfaceMethod<[{
+ Returns true if this ptr-like type has non-empty metadata.
+ }],
+ "bool", "hasPtrMetadata">,
+ InterfaceMethod<[{
+ Returns a clone of this type with the given memory space and element type,
+ or `failure` if the type cannot be cloned with the specified arguments.
+ If the pointer is opaque and `elementType` is not `std::nullopt`, the
+ method returns `failure`.
+
+ If no `elementType` is provided and the pointer is not opaque, the
+ element type of this type is used.
+ }],
+ "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
+ "::mlir::Attribute":$memorySpace,
+ "::std::optional<::mlir::Type>":$elementType
+ )>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..86ec5c43970b1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
+class BaseMemRefType : public Type,
+ public PtrLikeTypeInterface::Trait<BaseMemRefType>,
+ public ShapedType::Trait<BaseMemRefType> {
public:
using Type::Type;
@@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;
+ /// Clones this type with the given memory space and element type. If the
+ /// provided element type is `std::nullopt`, the current element type is
+ /// kept.
+ FailureOr<PtrLikeTypeInterface>
+ clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;
+
// Make sure that base class overloads are visible.
using ShapedType::Trait<BaseMemRefType>::clone;
@@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
+ /// Returns true: memref types are ptr-like objects with non-empty metadata.
+ bool hasPtrMetadata() const { return true; }
+
/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ /// Allow implicit conversion to PtrLikeTypeInterface.
+ operator PtrLikeTypeInterface() const {
+ return llvm::cast<PtrLikeTypeInterface>(*this);
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..9ad24e45c8315 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
+ PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
@@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
+ PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c21783011452f..80fd7617c9354 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -41,6 +41,54 @@ void PtrDialect::initialize() {
>();
}
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
+ // Fold the `to_ptr` -> `from_ptr` round trip:
+ // %ptr = ptr.to_ptr %v : type -> ptr
+ // (%mda = ptr.get_metadata %v : type)?
+ // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
+ // To:
+ // %val -> %v
+ auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
+ // Cannot fold if it's not a `to_ptr` op or the initial and final types are
+ // different.
+ if (!toPtr || toPtr.getPtr().getType() != getType())
+ return nullptr;
+ Value md = getMetadata();
+ if (!md)
+ return toPtr.getPtr();
+ // Fold if the metadata can be verified to be equal.
+ if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+ mdOp && mdOp.getPtr() == toPtr.getPtr())
+ return toPtr.getPtr();
+ return nullptr;
+}
+
+LogicalResult FromPtrOp::verify() {
+ if (isa<PtrType>(getType()))
+ return emitError() << "the result type cannot be `!ptr.ptr`";
+ if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+ return emitError()
+ << "expected the input and output to have the same memory space";
+ }
+ bool hasMD = getMetadata() != Value();
+ bool hasTrivialMD = getHasTrivialMetadata();
+ if (hasMD && hasTrivialMD) {
+ return emitError() << "expected either a metadata argument or the "
+ "`trivial_metadata` flag, not both";
+ }
+ if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) {
+ return emitError() << "expected either a metadata argument or the "
+ "`trivial_metadata` flag to be set";
+ }
+ if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD))
+ return emitError() << "expected no metadata specification";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
@@ -55,6 +103,33 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
+ // Fold the `from_ptr` -> `to_ptr` round trip:
+ // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
+ // %ptr = ptr.to_ptr %val : type -> ptr
+ // To:
+ // %ptr -> %p
+ auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
+ // Cannot fold if it's not a `from_ptr` op.
+ if (!fromPtr)
+ return nullptr;
+ return fromPtr.getPtr();
+}
+
+LogicalResult ToPtrOp::verify() {
+ if (isa<PtrType>(getPtr().getType()))
+ return emitError() << "the input value cannot be of type `!ptr.ptr`";
+ if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+ return emitError()
+ << "expected the input and output to have the same memory space";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index cab9ca11e679e..7ad2a6bc4c80b 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -151,3 +151,16 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
}
return success();
}
+
+//===----------------------------------------------------------------------===//
+// Pointer metadata
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the wrapped ptr-like type has non-empty metadata.
+LogicalResult
+PtrMetadataType::verify(function_ref<InFlightDiagnostic()> emitError,
+ PtrLikeTypeInterface type) {
+ if (!type.hasPtrMetadata())
+ return emitError() << "the ptr-like type has no metadata";
+ return success();
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..97bab479c79bf 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -376,6 +376,26 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
return builder;
}
+FailureOr<PtrLikeTypeInterface>
+BaseMemRefType::clonePtrWith(Attribute memorySpace,
+ std::optional<Type> elementType) const {
+ // Use the requested element type, or keep the current one.
+ Type eTy = elementType ? *elementType : getElementType();
+ // Return failure instead of constructing a memref type with an invalid
+ // element type or an unsupported memory space.
+ if (!BaseMemRefType::isValidElementType(eTy) ||
+ !mlir::detail::isSupportedMemorySpace(memorySpace))
+ return failure();
+ if (llvm::isa<UnrankedMemRefType>(*this))
+ return cast<PtrLikeTypeInterface>(
+ UnrankedMemRefType::get(eTy, memorySpace));
+
+ MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
+ builder.setElementType(eTy);
+ builder.setMemorySpace(memorySpace);
+ return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
+}
+
MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
Type elementType) const {
return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index ad363d554f247..837f364242beb 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -13,3 +13,61 @@ func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.gene
%res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
return %res0 : !ptr.ptr<#ptr.generic_space>
}
+
+/// Tests the `from_ptr` folder.
+// CHECK-LABEL: @test_from_ptr_0
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_0(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK-NOT: ptr.get_metadata
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK: return %[[MEM_REF]]
+ %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+ %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_from_ptr_1
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK: return %[[MEM_REF]]
+ %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Check that the ops cannot be folded because the metadata cannot be guaranteed to be the same.
+// CHECK-LABEL: @test_from_ptr_2
+func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+ // CHECK: ptr.to_ptr
+ // CHECK: ptr.from_ptr
+ %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %res = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Tests the the `to_ptr` folder.
+// CHECK-LABEL: @test_to_ptr_0
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_0(%ptr: !ptr.ptr<#ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> !ptr.ptr<#ptr.generic_space> {
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK: return %[[PTR]]
+ %mrf = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ return %res : !ptr.ptr<#ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_to_ptr_1
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>)
+func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK: return %[[PTR]]
+ %mrf = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ return %res : !ptr.ptr<#ptr.generic_space>
+}
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
new file mode 100644
index 0000000000000..e776e0ee04f90
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+/// Test `to_ptr` verifiers.
+func.func @invalid_to_ptr(%v: memref<f32, 0>) {
+ // expected-error@+1 {{expected the input and output to have the same memory space}}
+ %r = ptr.to_ptr %v : memref<f32, 0> -> !ptr.ptr<#ptr.generic_space>
+ return
+}
+
+// -----
+
+func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+ // expected-error@+1 {{the input value cannot be of type `!ptr.ptr`}}
+ %r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ return
+}
+
+// -----
+
+/// Test `from_ptr` verifiers.
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+ // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag to be set}}
+ %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return
+}
+
+// -----
+
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>, %m: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) {
+ // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag, not both}}
+ %r = ptr.from_ptr %v metadata %m trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return
+}
+
+// -----
+
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+ // expected-error@+1 {{the result type cannot be `!ptr.ptr`}}
+ %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ return
+}
+
+// -----
+
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+ // expected-error@+1 {{expected the input and output to have the same memory space}}
+ %r = ptr.from_ptr %v trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32>
+ return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index d763ea221944b..74bff25b4f3e1 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -17,3 +17,17 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
%res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
return %res : !ptr.ptr<#ptr.generic_space>
}
+
+/// Check cast ops assembly.
+// CHECK-LABEL: @cast_ops
+func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+ // CHECK: ptr.to_ptr %{{.*}} : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ // CHECK: ptr.get_metadata %{{.*}} : memref<f32, #ptr.generic_space>
+ %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+ // CHECK: ptr.from_ptr %{{.*}} metadata %{{.*}} : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ // CHECK: ptr.from_ptr %{{.*}} trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %mr0 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
>From afac5f4118573633bc89451c36ee05144cd4baf2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sun, 27 Apr 2025 07:29:15 -0400
Subject: [PATCH 2/7] Update mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 8ad475c41c8d3..55cc47a41d03b 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -31,7 +31,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
important to note that:
- The ptr-like object cannot be a `!ptr.ptr`.
- The memory-space of both the `ptr` and ptr-like object must match.
- - The cast is side-effect free.
+ - The cast is Pure (no UB and side-effect free).
If the ptr-like object type has metadata, then the operation expects the
metadata as an argument or expects that the flag `trivial_metadata` is set.
>From d9fd27ec7f65eda27587cc602d43da762c475fd5 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 27 Apr 2025 13:33:46 +0000
Subject: [PATCH 3/7] add tests for chains of casts
---
mlir/test/Dialect/Ptr/canonicalize.mlir | 48 +++++++++++++++++++++++++
1 file changed, 48 insertions(+)
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index 837f364242beb..2b9c8489f352e 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -49,6 +49,21 @@ func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_m
return %res : memref<f32, #ptr.generic_space>
}
+// Check the folding of `to_ptr -> from_ptr` chains.
+// CHECK-LABEL: @test_from_ptr_3
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_3(%mr0: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK: return %[[MEM_REF]]
+ %mda = ptr.get_metadata %mr0 : memref<f32, #ptr.generic_space>
+ %ptr0 = ptr.to_ptr %mr0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %mrf0 = ptr.from_ptr %ptr0 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %mrf1 = ptr.from_ptr %ptr1 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %mrf1 : memref<f32, #ptr.generic_space>
+}
+
/// Tests the the `to_ptr` folder.
// CHECK-LABEL: @test_to_ptr_0
// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
@@ -71,3 +86,36 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge
%res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
return %res : !ptr.ptr<#ptr.generic_space>
}
+
+// Check the folding of `from_ptr -> to_ptr` chains.
+// CHECK-LABEL: @test_to_ptr_2
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK: return %[[PTR]]
+ %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %ptr2 = ptr.to_ptr %mrf1 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %res = ptr.to_ptr %mrf2 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ return %res : !ptr.ptr<#ptr.generic_space>
+}
+
+// Check the folding of chains with different metadata.
+// CHECK-LABEL: @test_cast_chain_folding
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>
+func.func @test_cast_chain_folding(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK: return %[[MEM_REF]]
+ %ptr1 = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %memrefWithOtherMd = ptr.from_ptr %ptr1 metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %ptr = ptr.to_ptr %memrefWithOtherMd : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+ // The chain can be folded: `to_ptr` is a lossless cast, so %ptr holds the
+ // same pointer value as %mr, and %mda is the metadata of %mr itself.
+ %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
>From 634d03ba963e01422a96b7046b44d37fdfbd4b12 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 10 May 2025 13:34:46 +0000
Subject: [PATCH 4/7] make folders work on cast sequences
---
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 52 +++++++++++++++++---------
1 file changed, 34 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 80fd7617c9354..c0310446e3cea 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -52,19 +52,28 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
// %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
// To:
// %val -> %v
- auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
- // Cannot fold if it's not a `to_ptr` op or the initial and final types are
- // different.
- if (!toPtr || toPtr.getPtr().getType() != getType())
- return nullptr;
- Value md = getMetadata();
- if (!md)
- return toPtr.getPtr();
- // Fold if the metadata can be verified to be equal.
- if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
- mdOp && mdOp.getPtr() == toPtr.getPtr())
- return toPtr.getPtr();
- return nullptr;
+ Value ptrLike;
+ FromPtrOp fromPtr = *this;
+ while (fromPtr != nullptr) {
+ auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
+ // Cannot fold if it's not a `to_ptr` op or the initial and final types are
+ // different.
+ if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
+ return ptrLike;
+ // Without metadata (none is needed, or `trivial_metadata` is set), fold.
+ // Otherwise fold only if the metadata comes from `get_metadata` on the
+ // `to_ptr` input; stop at the first link that can't be folded so the loop
+ // never revisits the same `from_ptr` op.
+ if (Value md = fromPtr.getMetadata()) {
+ auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+ if (!mdOp || mdOp.getPtr() != toPtr.getPtr())
+ return ptrLike;
+ }
+ ptrLike = toPtr.getPtr();
+ // Check for a sequence of casts.
+ fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike.getDefiningOp());
+ }
+ return ptrLike;
}
LogicalResult FromPtrOp::verify() {
@@ -113,11 +122,18 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
// %ptr = ptr.to_ptr %val : type -> ptr
// To:
// %ptr -> %p
- auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
- // Cannot fold if it's not a `from_ptr` op.
- if (!fromPtr)
- return nullptr;
- return fromPtr.getPtr();
+ Value ptr;
+ ToPtrOp toPtr = *this;
+ while (toPtr != nullptr) {
+ auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
+ // Cannot fold if it's not a `from_ptr` op.
+ if (!fromPtr)
+ return ptr;
+ ptr = fromPtr.getPtr();
+ // Check for chains of casts.
+ toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
+ }
+ return ptr;
}
LogicalResult ToPtrOp::verify() {
>From f090320893d026975f30edfbd3693a286be4da59 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Wed, 14 May 2025 16:44:14 +0000
Subject: [PATCH 5/7] remove trivial_metadata flag from from_ptr op
---
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 17 +++++++----------
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 15 +--------------
mlir/test/Dialect/Ptr/canonicalize.mlir | 10 +++++-----
mlir/test/Dialect/Ptr/invalid.mlir | 17 -----------------
mlir/test/Dialect/Ptr/ops.mlir | 2 +-
5 files changed, 14 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 55cc47a41d03b..37eb91fa6a338 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -33,27 +33,24 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
- The memory-space of both the `ptr` and ptr-like object must match.
- The cast is Pure (no UB and side-effect free).
- If the ptr-like object type has metadata, then the operation expects the
- metadata as an argument or expects that the flag `trivial_metadata` is set.
- If `trivial_metadata` is set, then it is assumed that the metadata can be
- reconstructed statically from the pointer-like type.
+ The optional `metadata` operand exists to provide any ptr-like metadata
+ that might be required to perform the cast.
Example:
```mlir
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
- %memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
+
+ // Cast the `%ptr` to a memref without utilizing metadata.
+ %memref = ptr.from_ptr %ptr : !ptr.ptr<0> -> memref<f32, 0>
```
}];
- let arguments = (ins Ptr_PtrType:$ptr,
- Optional<Ptr_PtrMetadata>:$metadata,
- UnitProp:$hasTrivialMetadata);
+ let arguments = (ins Ptr_PtrType:$ptr, Optional<Ptr_PtrMetadata>:$metadata);
let results = (outs PtrLikeTypeInterface:$result);
let assemblyFormat = [{
- $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
- attr-dict `:` type($ptr) `->` type($result)
+ $ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result)
}];
let hasFolder = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c0310446e3cea..ffa924b20ab59 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -61,8 +61,7 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
return ptrLike;
Value md = fromPtr.getMetadata();
- // If there's no metadata in the op, either the cast never requires metadata
- // or the op has the trivial metadata flag set, therefore fold.
+ // If there's no metadata in the op fold the op.
if (!md)
ptrLike = toPtr.getPtr();
// Fold if the metadata can be verified to be equal.
@@ -83,18 +82,6 @@ LogicalResult FromPtrOp::verify() {
return emitError()
<< "expected the input and output to have the same memory space";
}
- bool hasMD = getMetadata() != Value();
- bool hasTrivialMD = getHasTrivialMetadata();
- if (hasMD && hasTrivialMD) {
- return emitError() << "expected either a metadata argument or the "
- "`trivial_metadata` flag, not both";
- }
- if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) {
- return emitError() << "expected either a metadata argument or the "
- "`trivial_metadata` flag to be set";
- }
- if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD))
- return emitError() << "expected no metadata specification";
return success();
}
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index 2b9c8489f352e..dfc679acb2ed4 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -35,7 +35,7 @@ func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32,
// CHECK-NOT: ptr.from_ptr
// CHECK: return %[[MEM_REF]]
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
- %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %res = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
return %res : memref<f32, #ptr.generic_space>
}
@@ -82,7 +82,7 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge
// CHECK-NOT: ptr.from_ptr
// CHECK-NOT: ptr.to_ptr
// CHECK: return %[[PTR]]
- %mrf = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %mrf = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
%res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
return %res : !ptr.ptr<#ptr.generic_space>
}
@@ -94,11 +94,11 @@ func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.g
// CHECK-NOT: ptr.from_ptr
// CHECK-NOT: ptr.to_ptr
// CHECK: return %[[PTR]]
- %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %mrf0 = ptr.from_ptr %ptr0 : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
%ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
- %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %mrf1 = ptr.from_ptr %ptr1 : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
%ptr2 = ptr.to_ptr %mrf1 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
- %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %mrf2 = ptr.from_ptr %ptr2 : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
%res = ptr.to_ptr %mrf2 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
return %res : !ptr.ptr<#ptr.generic_space>
}
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index e776e0ee04f90..19fd715e5bba6 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -14,20 +14,3 @@ func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
%r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
return
}
-
-// -----
-
-/// Test `from_ptr` verifiers.
-func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
- // expected-error at +1 {{expected either a metadata argument or the `trivial_metadata` flag to be set}}
- %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
- return
-}
-
-// -----
-
-func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>, %m: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) {
- // expected-error at +1 {{expected either a metadata argument or the `trivial_metadata` flag, not both}}
- %r = ptr.from_ptr %v metadata %m trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
- return
-}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 74bff25b4f3e1..eed3272d98da9 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -24,6 +24,6 @@ func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.ge
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
%mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
%res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
- %mr0 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ %mr0 = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
return %res : memref<f32, #ptr.generic_space>
}
>From 5b1a2ad53b9ca95400531bd6ce930eda88e83446 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Fri, 6 Jun 2025 22:52:34 +0000
Subject: [PATCH 6/7] address reviewer comments
---
mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td | 5 +++--
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 14 +++++++-------
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 8 ++++++--
3 files changed, 16 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 6631b338db199..7407d74ce3a87 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -93,8 +93,9 @@ def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
let description = [{
The `ptr_metadata` type represents an opaque-view of the metadata associated
with a `ptr-like` object type.
- It's an error to get a `ptr_metadata` using `ptr-like` type with no
- metadata.
+
+ Note: It's a verification error to construct a `ptr_metadata` type using a
+ `ptr-like` type with no metadata.
Example:
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 37eb91fa6a338..1523762efc18f 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -39,11 +39,11 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
Example:
```mlir
- %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
- %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
+ %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
+ %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
// Cast the `%ptr` to a memref without utilizing metadata.
- %memref = ptr.from_ptr %ptr : !ptr.ptr<0> -> memref<f32, 0>
+ %memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
```
}];
@@ -98,8 +98,8 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
Example:
```mlir
- %x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32
- %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32
+ %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+ %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
```
}];
@@ -134,8 +134,8 @@ def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
Example:
```mlir
- %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
- %ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
+ %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %ptr1 = ptr.to_ptr %memref : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
```
}];
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index d058f6c4d9651..367aeb6ac512b 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -119,8 +119,12 @@ def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
let description = [{
A ptr-like type represents an object storing a memory address. This object
is constituted by:
- - A memory address called the base pointer. The base pointer is an
- indivisible object.
+ - A memory address called the base pointer. This pointer is treated as a
+ bag of bits without any assumed structure. The bit-width of the base
+ pointer must be a compile-time constant. However, the bit-width may remain
+ opaque or unavailable during transformations that do not depend on the
+ base pointer. Finally, it is considered indivisible in the sense that as
+ a `PtrLikeTypeInterface` value, it has no metadata.
- Optional metadata about the pointer. For example, the size of the memory
region associated with the pointer.
>From 105bcec7d423a2863c046397518d4873316f04fc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Mon, 16 Jun 2025 19:18:00 +0000
Subject: [PATCH 7/7] address reviewer comments and rebase
---
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 14 ++++++++------
mlir/test/Dialect/Ptr/canonicalize.mlir | 6 +++---
2 files changed, 11 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index ffa924b20ab59..c488144508128 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -61,13 +61,15 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
return ptrLike;
Value md = fromPtr.getMetadata();
- // If there's no metadata in the op fold the op.
- if (!md)
- ptrLike = toPtr.getPtr();
- // Fold if the metadata can be verified to be equal.
- else if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
- mdOp && mdOp.getPtr() == toPtr.getPtr())
+    // If the type has no metadata, fold.
+ if (!fromPtr.getType().hasPtrMetadata()) {
ptrLike = toPtr.getPtr();
+ } else if (md) {
+ // Fold if the metadata can be verified to be equal.
+ if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+ mdOp && mdOp.getPtr() == toPtr.getPtr())
+ ptrLike = toPtr.getPtr();
+ }
// Check for a sequence of casts.
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
: nullptr);
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index dfc679acb2ed4..e50cd1b76caf3 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -28,12 +28,12 @@ func.func @test_from_ptr_0(%mr: memref<f32, #ptr.generic_space>) -> memref<f32,
return %res : memref<f32, #ptr.generic_space>
}
+/// Check that the op doesn't fold: folding a cast to a ptr-like type with metadata requires knowing the origin of the metadata.
// CHECK-LABEL: @test_from_ptr_1
// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
- // CHECK-NOT: ptr.to_ptr
- // CHECK-NOT: ptr.from_ptr
- // CHECK: return %[[MEM_REF]]
+ // CHECK: ptr.to_ptr
+ // CHECK: ptr.from_ptr
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
%res = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
return %res : memref<f32, #ptr.generic_space>
More information about the Mlir-commits
mailing list