[Mlir-commits] [mlir] [mlir][core|ptr] Add `PtrLikeTypeInterface` and casting ops to the `ptr` dialect (PR #137469)

Fabian Mora llvmlistbot at llvm.org
Fri Jun 6 15:54:26 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/137469

>From e4715591651e869ff6b7a32c2791efde2c573dcf Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 26 Apr 2025 19:25:54 +0000
Subject: [PATCH 1/6] [mlir][core|ptr] Add `PtrLikeTypeInterface` and casting
 ops to the `ptr` dialect

This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types.
This interface is defined as:

```
A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. The base pointer is an
  indivisible object.
- Optional metadata about the pointer. For example, the size of the memory
  region associated with the pointer.

Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
  pointer is considered opaque.
```
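As a rough illustration of this contract, and of the `clonePtrWith` method the interface adds, here is a standalone C++ model. It is a sketch under stated assumptions, not the MLIR API: the `PtrLike` struct and the `makeOpaquePtr`, `makeMemRefLike` and `clonePtrWith` free functions are hypothetical, types and memory spaces are modeled as plain strings, and the `MemorySpaceAttrInterface` check that `!ptr.ptr` performs is omitted.

```cpp
#include <optional>
#include <string>
#include <utility>

// Hypothetical toy model of a ptr-like type (not the MLIR C++ API).
struct PtrLike {
  std::string memorySpace;                // Memory space of the address.
  std::optional<std::string> elementType; // Empty means the pointer is opaque.
  bool hasMetadata;                       // Metadata beyond the base pointer?
};

// Models `!ptr.ptr<space>`: opaque and without metadata.
PtrLike makeOpaquePtr(std::string space) {
  return PtrLike{std::move(space), std::nullopt, false};
}

// Models a memref-like type: it has an element type and carries metadata.
PtrLike makeMemRefLike(std::string elementType, std::string space) {
  return PtrLike{std::move(space), std::move(elementType), true};
}

// Mirrors the documented `clonePtrWith` contract:
// - an opaque pointer cannot be given an element type (failure);
// - if no element type is provided, the current one is kept.
std::optional<PtrLike> clonePtrWith(const PtrLike &ty, std::string memorySpace,
                                    std::optional<std::string> elementType) {
  if (!ty.elementType && elementType)
    return std::nullopt;
  PtrLike result = ty;
  result.memorySpace = std::move(memorySpace);
  if (elementType)
    result.elementType = std::move(elementType);
  return result;
}
```

Returning `std::optional` here plays the role of `FailureOr` in the real interface method.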

This patch adds this interface to `!ptr.ptr` and to the builtin `memref` and unranked `memref` types.

Furthermore, this patch adds the necessary ops and a type to handle casting between `!ptr.ptr`
and ptr-like types.

First, it defines the `!ptr.ptr_metadata` type, an opaque type representing the metadata
of a ptr-like type. The rationale for adding this type is that, at a high level, the
metadata of a type like `memref` cannot be specified, as its structure is tied to its
lowering.

The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The
concrete structure of the metadata is only known when the op is lowered.
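The metadata rules enforced by the verifiers in this patch (`PtrMetadataType::verify` and the metadata checks in `FromPtrOp::verify`) can be summarized with a small standalone C++ sketch. The function names and boolean parameters are hypothetical stand-ins for the real types and op accessors, and the other `from_ptr` checks (the result must not be `!ptr.ptr`, memory spaces must match) are left out.

```cpp
#include <optional>
#include <string>

// Hypothetical sketch (not the MLIR API) of the metadata verification rules.
// Each function returns an error message, or std::nullopt if the IR is valid.

// Mirrors PtrMetadataType::verify: `!ptr.ptr_metadata<T>` is only valid when
// the ptr-like type T carries metadata.
std::optional<std::string> verifyPtrMetadataType(bool typeHasPtrMetadata) {
  if (!typeHasPtrMetadata)
    return "the ptr-like type has no metadata";
  return std::nullopt;
}

// Mirrors the metadata part of FromPtrOp::verify, checked in the same order.
std::optional<std::string> verifyFromPtrMetadata(bool resultHasPtrMetadata,
                                                 bool hasMetadataOperand,
                                                 bool hasTrivialMetadataFlag) {
  if (hasMetadataOperand && hasTrivialMetadataFlag)
    return "expected either a metadata argument or the `trivial_metadata` "
           "flag, not both";
  bool specified = hasMetadataOperand || hasTrivialMetadataFlag;
  if (resultHasPtrMetadata && !specified)
    return "expected either a metadata argument or the `trivial_metadata` "
           "flag to be set";
  if (!resultHasPtrMetadata && specified)
    return "expected no metadata specification";
  return std::nullopt;
}
```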

Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations, which cast back
and forth between `!ptr.ptr` and ptr-like types.

```mlir
func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
  return %res : memref<f32, #ptr.generic_space>
}
```
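The function above is the round trip that the `ptr.from_ptr` folder added in this patch collapses back to `%mr`. The single-step folders can be modeled in C++ as follows, assuming a hypothetical toy IR in which each `Value` records the op that produced it and its operands; this is an illustrative sketch, not the MLIR API.

```cpp
#include <string>

// Hypothetical toy IR (not the MLIR API): each value records its producer.
enum class OpKind { Arg, ToPtr, GetMetadata, FromPtr };

struct Value {
  OpKind kind;                     // Op that produced this value.
  std::string type;                // Printed type, e.g. "memref<f32>".
  const Value *src = nullptr;      // to_ptr/get_metadata operand, or from_ptr ptr.
  const Value *metadata = nullptr; // Optional from_ptr metadata operand.
};

// Mirrors FromPtrOp::fold from this patch:
//   from_ptr(to_ptr(%v), get_metadata(%v)) -> %v   (types must match)
//   from_ptr(to_ptr(%v))                   -> %v   (no metadata operand)
const Value *foldFromPtr(const Value &fromPtr) {
  const Value *toPtr = fromPtr.src;
  if (!toPtr || toPtr->kind != OpKind::ToPtr || !toPtr->src ||
      toPtr->src->type != fromPtr.type)
    return nullptr;
  const Value *md = fromPtr.metadata;
  if (!md)
    return toPtr->src;
  // Only fold when the metadata provably comes from the original value.
  if (md->kind == OpKind::GetMetadata && md->src == toPtr->src)
    return toPtr->src;
  return nullptr;
}

// Mirrors ToPtrOp::fold: to_ptr(from_ptr(%p, ...)) -> %p.
const Value *foldToPtr(const Value &toPtr) {
  const Value *fromPtr = toPtr.src;
  if (!fromPtr || fromPtr->kind != OpKind::FromPtr)
    return nullptr;
  return fromPtr->src;
}
```

The asymmetry is deliberate: `to_ptr` discards metadata, so it can always fold through a `from_ptr`, while `from_ptr` must not fold when the metadata operand may differ from that of the original value. A later commit in this series extends the folders to chains of casts.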
---
 .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 49 +++++++++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    | 99 +++++++++++++++++++
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 49 +++++++++
 mlir/include/mlir/IR/BuiltinTypes.h           | 18 +++-
 mlir/include/mlir/IR/BuiltinTypes.td          |  2 +
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        | 75 ++++++++++++++
 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp          | 12 +++
 mlir/lib/IR/BuiltinTypes.cpp                  | 14 +++
 mlir/test/Dialect/Ptr/canonicalize.mlir       | 58 +++++++++++
 mlir/test/Dialect/Ptr/invalid.mlir            | 33 +++++++
 mlir/test/Dialect/Ptr/ops.mlir                | 10 ++
 11 files changed, 418 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Ptr/invalid.mlir

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 73b2a0857cef3..6631b338db199 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
 
 def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     MemRefElementTypeInterface,
+    PtrLikeTypeInterface,
     VectorElementTypeInterface,
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
       "areCompatible", "getIndexBitwidth", "verifyEntries",
@@ -63,6 +64,54 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
       return $_get(memorySpace.getContext(), memorySpace);
     }]>
   ];
+  let extraClassDeclaration = [{
+    // `PtrLikeTypeInterface` interface methods.
+    /// Returns `Type()` as this pointer type is opaque.
+    Type getElementType() const {
+      return Type();
+    }
+    /// Clones the pointer with specified memory space or returns failure
+    /// if an `elementType` was specified or if the memory space doesn't
+    /// implement `MemorySpaceAttrInterface`.
+    FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
+      std::optional<Type> elementType) const {
+      if (elementType)
+        return failure();
+      if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
+        return cast<PtrLikeTypeInterface>(get(ms));
+      return failure();
+    }
+    /// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
+    bool hasPtrMetadata() const {
+      return false;
+    }
+  }];
+}
+
+def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
+  let summary = "Pointer metadata type";
+  let description = [{
+    The `ptr_metadata` type represents an opaque-view of the metadata associated
+    with a `ptr-like` object type.
+    It's an error to get a `ptr_metadata` using `ptr-like` type with no
+    metadata.
+
+    Example:
+
+    ```mlir
+    // The metadata associated with a `memref` type.
+    !ptr.ptr_metadata<memref<f32>>
+    ```
+  }];
+  let parameters = (ins "PtrLikeTypeInterface":$type);
+  let assemblyFormat = "`<` $type `>`";
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "PtrLikeTypeInterface":$ptrLike), [{
+      return $_get(ptrLike.getContext(), ptrLike);
+    }]>
+  ];
+  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 791b95ad3559e..8ad475c41c8d3 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,75 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
 
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
+    Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
+            "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+  ]> {
+  let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
+  let description = [{
+    The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
+    important to note that:
+    - The ptr-like object cannot be a `!ptr.ptr`.
+    - The memory-space of both the `ptr` and ptr-like object must match.
+    - The cast is side-effect free.
+
+    If the ptr-like object type has metadata, then the operation expects the
+    metadata as an argument or expects that the flag `trivial_metadata` is set.
+    If `trivial_metadata` is set, then it is assumed that the metadata can be
+    reconstructed statically from the pointer-like type.
+
+    Example:
+
+    ```mlir
+    %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
+    %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
+    %memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
+    ```
+  }];
+
+  let arguments = (ins Ptr_PtrType:$ptr,
+                       Optional<Ptr_PtrMetadata>:$metadata,
+                       UnitProp:$hasTrivialMetadata);
+  let results = (outs PtrLikeTypeInterface:$result);
+  let assemblyFormat = [{
+    $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
+    attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
+    Pure, TypesMatchWith<"metadata type", "ptr", "result",
+            "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+  ]> {
+  let summary = "SSA value representing pointer metadata.";
+  let description = [{
+    The `get_metadata` operation produces an opaque value that encodes the
+    metadata of the ptr-like type.
+
+    Example:
+
+    ```mlir
+    %metadata = ptr.get_metadata %memref : memref<?x?xf32>
+    ```
+  }];
+
+  let arguments = (ins PtrLikeTypeInterface:$ptr);
+  let results = (outs Ptr_PtrMetadata:$result);
+  let assemblyFormat = [{
+    $ptr attr-dict `:` type($ptr)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // PtrAddOp
 //===----------------------------------------------------------------------===//
@@ -52,6 +121,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
+  let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
+  let description = [{
+    The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
+    important to note that:
+    - The ptr-like object cannot be a `!ptr.ptr`.
+    - The memory-space of both the `ptr` and ptr-like object must match.
+    - The cast is side-effect free.
+
+    Example:
+
+    ```mlir
+    %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
+    %ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
+    ```
+  }];
+
+  let arguments = (ins PtrLikeTypeInterface:$ptr);
+  let results = (outs Ptr_PtrType:$result);
+  let assemblyFormat = [{
+    $ptr attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TypeOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..d058f6c4d9651 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PtrLikeTypeInterface
+//===----------------------------------------------------------------------===//
+
+def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    A ptr-like type represents an object storing a memory address. This object
+    is constituted by:
+    - A memory address called the base pointer. The base pointer is an 
+      indivisible object.
+    - Optional metadata about the pointer. For example, the size of the  memory
+      region associated with the pointer.
+
+    Furthermore, all ptr-like types have two properties:
+    - The memory space associated with the address held by the pointer.
+    - An optional element type. If the element type is not specified, the
+      pointer is considered opaque.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space of this ptr-like type.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      Returns the element type of this ptr-like type. Note: this method can
+      return `::mlir::Type()`, in which case the pointer is considered opaque.
+    }],
+    "::mlir::Type", "getElementType">,
+    InterfaceMethod<[{
+      Returns whether this ptr-like type has non-empty metadata.
+    }],
+    "bool", "hasPtrMetadata">,
+    InterfaceMethod<[{
+      Returns a clone of this type with the given memory space and element type,
+      or `failure` if the type cannot be cloned with the specified arguments.
+      If the pointer is opaque and `elementType` is not `std::nullopt` the
+      method will return `failure`.
+
+      If no `elementType` is provided and ptr is not opaque, the `elementType`
+      of this type is used.
+    }],
+    "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
+      "::mlir::Attribute":$memorySpace,
+      "::std::optional<::mlir::Type>":$elementType
+    )>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..86ec5c43970b1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
 /// Note: This class attaches the ShapedType trait to act as a mixin to
 ///       provide many useful utility functions. This inheritance has no effect
 ///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
+class BaseMemRefType : public Type,
+                       public PtrLikeTypeInterface::Trait<BaseMemRefType>,
+                       public ShapedType::Trait<BaseMemRefType> {
 public:
   using Type::Type;
 
@@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                            Type elementType) const;
 
+  /// Clone this type with the given memory space and element type. If the
+  /// provided element type is `std::nullopt`, the current element type of the
+  /// type is used.
+  FailureOr<PtrLikeTypeInterface>
+  clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;
+
   // Make sure that base class overloads are visible.
   using ShapedType::Trait<BaseMemRefType>::clone;
 
@@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   /// New `Attribute getMemorySpace()` method should be used instead.
   unsigned getMemorySpaceAsInt() const;
 
+  /// Returns that this ptr-like object has non-empty ptr metadata.
+  bool hasPtrMetadata() const { return true; }
+
   /// Allow implicit conversion to ShapedType.
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  /// Allow implicit conversion to PtrLikeTypeInterface.
+  operator PtrLikeTypeInterface() const {
+    return llvm::cast<PtrLikeTypeInterface>(*this);
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..9ad24e45c8315 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
+    PtrLikeTypeInterface,
     ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference to a region of memory";
@@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
+    PtrLikeTypeInterface,
     ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c21783011452f..80fd7617c9354 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -41,6 +41,54 @@ void PtrDialect::initialize() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
+  // Fold the pattern:
+  // %ptr = ptr.to_ptr %v : type -> ptr
+  // (%mda = ptr.get_metadata %v : type)?
+  // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
+  // To:
+  // %val -> %v
+  auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
+  // Cannot fold if it's not a `to_ptr` op or the initial and final types are
+  // different.
+  if (!toPtr || toPtr.getPtr().getType() != getType())
+    return nullptr;
+  Value md = getMetadata();
+  if (!md)
+    return toPtr.getPtr();
+  // Fold if the metadata can be verified to be equal.
+  if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+      mdOp && mdOp.getPtr() == toPtr.getPtr())
+    return toPtr.getPtr();
+  return nullptr;
+}
+
+LogicalResult FromPtrOp::verify() {
+  if (isa<PtrType>(getType()))
+    return emitError() << "the result type cannot be `!ptr.ptr`";
+  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+    return emitError()
+           << "expected the input and output to have the same memory space";
+  }
+  bool hasMD = getMetadata() != Value();
+  bool hasTrivialMD = getHasTrivialMetadata();
+  if (hasMD && hasTrivialMD) {
+    return emitError() << "expected either a metadata argument or the "
+                          "`trivial_metadata` flag, not both";
+  }
+  if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) {
+    return emitError() << "expected either a metadata argument or the "
+                          "`trivial_metadata` flag to be set";
+  }
+  if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD))
+    return emitError() << "expected no metadata specification";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // PtrAddOp
 //===----------------------------------------------------------------------===//
@@ -55,6 +103,33 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
+  // Fold the pattern:
+  // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
+  // %ptr = ptr.to_ptr %val : type -> ptr
+  // To:
+  // %ptr -> %p
+  auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
+  // Cannot fold if it's not a `from_ptr` op.
+  if (!fromPtr)
+    return nullptr;
+  return fromPtr.getPtr();
+}
+
+LogicalResult ToPtrOp::verify() {
+  if (isa<PtrType>(getPtr().getType()))
+    return emitError() << "the input value cannot be of type `!ptr.ptr`";
+  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+    return emitError()
+           << "expected the input and output to have the same memory space";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TypeOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index cab9ca11e679e..7ad2a6bc4c80b 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// Pointer metadata
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+PtrMetadataType::verify(function_ref<InFlightDiagnostic()> emitError,
+                        PtrLikeTypeInterface type) {
+  if (!type.hasPtrMetadata())
+    return emitError() << "the ptr-like type has no metadata";
+  return success();
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..97bab479c79bf 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -376,6 +376,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
   return builder;
 }
 
+FailureOr<PtrLikeTypeInterface>
+BaseMemRefType::clonePtrWith(Attribute memorySpace,
+                             std::optional<Type> elementType) const {
+  Type eTy = elementType ? *elementType : getElementType();
+  if (llvm::dyn_cast<UnrankedMemRefType>(*this))
+    return cast<PtrLikeTypeInterface>(
+        UnrankedMemRefType::get(eTy, memorySpace));
+
+  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
+  builder.setElementType(eTy);
+  builder.setMemorySpace(memorySpace);
+  return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
+}
+
 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
                                  Type elementType) const {
   return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index ad363d554f247..837f364242beb 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -13,3 +13,61 @@ func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.gene
   %res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
   return %res0 : !ptr.ptr<#ptr.generic_space>
 }
+
+/// Tests the the `from_ptr` folder.
+// CHECK-LABEL: @test_from_ptr_0
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_0(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.get_metadata
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_from_ptr_1
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Check that the ops cannot be folded because the metadata cannot be guaranteed to be the same.
+// CHECK-LABEL: @test_from_ptr_2
+func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+  // CHECK: ptr.to_ptr
+  // CHECK: ptr.from_ptr
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Tests the the `to_ptr` folder.
+// CHECK-LABEL: @test_to_ptr_0
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_0(%ptr: !ptr.ptr<#ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> !ptr.ptr<#ptr.generic_space> {
+  // CHECK: return %[[PTR]]
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK-NOT: ptr.to_ptr
+  %mrf = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return %res : !ptr.ptr<#ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_to_ptr_1
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>)
+func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK: return %[[PTR]]
+  %mrf = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return %res : !ptr.ptr<#ptr.generic_space>
+}
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
new file mode 100644
index 0000000000000..e776e0ee04f90
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+/// Test `to_ptr` verifiers.
+func.func @invalid_to_ptr(%v: memref<f32, 0>) {
+  // expected-error@+1 {{expected the input and output to have the same memory space}}
+  %r = ptr.to_ptr %v : memref<f32, 0> -> !ptr.ptr<#ptr.generic_space>
+  return
+}
+
+// -----
+
+func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+  // expected-error@+1 {{the input value cannot be of type `!ptr.ptr`}}
+  %r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return
+}
+
+// -----
+
+/// Test `from_ptr` verifiers.
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+  // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag to be set}}
+  %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return
+}
+
+// -----
+
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>, %m: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) {
+  // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag, not both}}
+  %r = ptr.from_ptr %v metadata %m trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space> 
+  return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index d763ea221944b..74bff25b4f3e1 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -17,3 +17,13 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
   %res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
   return %res : !ptr.ptr<#ptr.generic_space>
 }
+
+/// Check cast ops assembly.
+// CHECK-LABEL: @cast_ops
+func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %mr0 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}

>From f8b05815c694b8fddc82b3728436fef797a79938 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sun, 27 Apr 2025 07:29:15 -0400
Subject: [PATCH 2/6] Update mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 8ad475c41c8d3..55cc47a41d03b 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -31,7 +31,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
     important to note that:
     - The ptr-like object cannot be a `!ptr.ptr`.
     - The memory-space of both the `ptr` and ptr-like object must match.
-    - The cast is side-effect free.
+    - The cast is Pure (no UB and side-effect free).
 
     If the ptr-like object type has metadata, then the operation expects the
     metadata as an argument or expects that the flag `trivial_metadata` is set.

>From 602d2649da4d15799853156b2a0ee521b89bbd95 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 27 Apr 2025 13:33:46 +0000
Subject: [PATCH 3/6] add tests for chains of casts

---
 mlir/test/Dialect/Ptr/canonicalize.mlir | 48 +++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index 837f364242beb..2b9c8489f352e 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -49,6 +49,21 @@ func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_m
   return %res : memref<f32, #ptr.generic_space>
 }
 
+// Check the folding of `to_ptr -> from_ptr` chains.
+// CHECK-LABEL: @test_from_ptr_3
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_3(%mr0: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %mda = ptr.get_metadata %mr0 : memref<f32, #ptr.generic_space>
+  %ptr0 = ptr.to_ptr %mr0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf0 = ptr.from_ptr %ptr0 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf1 = ptr.from_ptr %ptr1 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %mrf1 : memref<f32, #ptr.generic_space>
+}
+
 /// Tests the the `to_ptr` folder.
 // CHECK-LABEL: @test_to_ptr_0
 // CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
@@ -71,3 +86,36 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge
   %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
   return %res : !ptr.ptr<#ptr.generic_space>
 }
+
+// Check the folding of `from_ptr -> to_ptr` chains.
+// CHECK-LABEL: @test_to_ptr_2
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK: return %[[PTR]]
+  %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %ptr2 = ptr.to_ptr %mrf1 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %res = ptr.to_ptr %mrf2 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return %res : !ptr.ptr<#ptr.generic_space>
+}
+
+// Check the folding of chains with different metadata.
+// CHECK-LABEL: @test_cast_chain_folding
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>
+func.func @test_cast_chain_folding(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+   %ptr1 = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+   %memrefWithOtherMd = ptr.from_ptr %ptr1 metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+   %ptr = ptr.to_ptr %memrefWithOtherMd : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+   %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+   // The chain can be folded because: the ptr always has the same value because
+   // `to_ptr` is a loss-less cast and %mda comes from the original memref.
+   %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+   return %res : memref<f32, #ptr.generic_space>
+}

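The reasoning in the comment of `@test_cast_chain_folding` can be checked on a value-level model. The sketch below is a hypothetical C++ model, not MLIR code: `PtrLike`, `Metadata`, `toPtr`, `getMetadata`, `fromPtr`, and `castChain` are illustrative names that stand in for a memref, its opaque metadata, and the three ops. It assumes only what the test comment states: `to_ptr` keeps the base pointer bits unchanged.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for memref metadata (e.g. sizes and strides). Its real
// structure is only fixed when the ops are lowered.
struct Metadata {
  std::int64_t size = 0;
  bool operator==(const Metadata &other) const { return size == other.size; }
};

// Hypothetical model of a ptr-like value: a base pointer plus metadata.
struct PtrLike {
  std::uintptr_t base = 0;
  Metadata md;
  bool operator==(const PtrLike &other) const {
    return base == other.base && md == other.md;
  }
};

// Models `ptr.to_ptr`: returns the base pointer unchanged, drops the metadata.
std::uintptr_t toPtr(const PtrLike &value) { return value.base; }

// Models `ptr.get_metadata`: returns the metadata of a ptr-like value.
Metadata getMetadata(const PtrLike &value) { return value.md; }

// Models `ptr.from_ptr ... metadata %md`: attaches metadata to a base pointer.
PtrLike fromPtr(std::uintptr_t base, const Metadata &md) { return {base, md}; }

// Mirrors @test_cast_chain_folding step by step.
PtrLike castChain(const PtrLike &mr, const Metadata &md) {
  std::uintptr_t ptr1 = toPtr(mr);
  PtrLike memrefWithOtherMd = fromPtr(ptr1, md);
  std::uintptr_t ptr = toPtr(memrefWithOtherMd);  // same bits as `ptr1`
  Metadata mda = getMetadata(mr);                 // metadata of the original
  return fromPtr(ptr, mda);
}
```

Whatever metadata the intermediate value carries, the final `fromPtr` rebuilds a value equal to `mr`, which is why the whole chain may be replaced by `%mr`. Reattaching any other metadata would not give back `mr`.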
>From 103808075430ec54e0221906db4822ad881e88e2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 10 May 2025 13:34:46 +0000
Subject: [PATCH 4/6] make folders work on cast sequences

---
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 52 +++++++++++++++++---------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 80fd7617c9354..c0310446e3cea 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -52,19 +52,28 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
   // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
   // To:
   // %val -> %v
-  auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
-  // Cannot fold if it's not a `to_ptr` op or the initial and final types are
-  // different.
-  if (!toPtr || toPtr.getPtr().getType() != getType())
-    return nullptr;
-  Value md = getMetadata();
-  if (!md)
-    return toPtr.getPtr();
-  // Fold if the metadata can be verified to be equal.
-  if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
-      mdOp && mdOp.getPtr() == toPtr.getPtr())
-    return toPtr.getPtr();
-  return nullptr;
+  Value ptrLike;
+  FromPtrOp fromPtr = *this;
+  while (fromPtr != nullptr) {
+    auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
+    // Cannot fold if it's not a `to_ptr` op or the initial and final types are
+    // different.
+    if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
+      return ptrLike;
+    Value md = fromPtr.getMetadata();
+    // If there's no metadata in the op, either the cast never requires metadata
+    // or the op has the trivial metadata flag set, therefore fold.
+    if (!md)
+      ptrLike = toPtr.getPtr();
+    // Fold if the metadata can be verified to be equal.
+    else if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+             mdOp && mdOp.getPtr() == toPtr.getPtr())
+      ptrLike = toPtr.getPtr();
+    // Check for a sequence of casts.
+    fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
+                                                  : nullptr);
+  }
+  return ptrLike;
 }
 
 LogicalResult FromPtrOp::verify() {
@@ -113,11 +122,18 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
   // %ptr = ptr.to_ptr %val : type -> ptr
   // To:
   // %ptr -> %p
-  auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
-  // Cannot fold if it's not a `from_ptr` op.
-  if (!fromPtr)
-    return nullptr;
-  return fromPtr.getPtr();
+  Value ptr;
+  ToPtrOp toPtr = *this;
+  while (toPtr != nullptr) {
+    auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
+    // Cannot fold if it's not a `from_ptr` op.
+    if (!fromPtr)
+      return ptr;
+    ptr = fromPtr.getPtr();
+    // Check for chains of casts.
+    toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
+  }
+  return ptr;
 }
 
 LogicalResult ToPtrOp::verify() {

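The chain-walking logic of both folders can be sketched outside MLIR. This is a hypothetical toy model, not the MLIR API: `Node`, `Kind`, `foldFromPtr`, and `foldToPtr` are illustrative names, every op has a single result (so a node doubles as its value), and types are compared as strings. It applies the same rules as the patch (fold `from_ptr(to_ptr(x))` to `x` when the types match and the metadata is either absent or `get_metadata` of `x`; fold `to_ptr(from_ptr(p))` to `p`), but it exits on the first step that cannot be folded.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy IR: each Node is an op with exactly one result.
enum class Kind { Arg, ToPtr, FromPtr, GetMetadata };

struct Node {
  Kind kind;
  std::string type;                // result type, compared textually
  const Node *ptr = nullptr;       // `ptr` operand (unused for `Arg`)
  const Node *metadata = nullptr;  // optional `metadata` operand of `from_ptr`
};

// Sketch of the `FromPtrOp::fold` walk: returns the furthest value the chain
// folds to, or nullptr if the first step does not fold.
const Node *foldFromPtr(const Node *fromPtr) {
  const Node *folded = nullptr;
  while (fromPtr && fromPtr->kind == Kind::FromPtr) {
    const Node *toPtr = fromPtr->ptr;
    // Stop unless the input is a `to_ptr` whose source has the result type.
    if (!toPtr || toPtr->kind != Kind::ToPtr ||
        toPtr->ptr->type != fromPtr->type)
      return folded;
    const Node *md = fromPtr->metadata;
    // No metadata operand: fold. Otherwise fold only if the metadata is
    // `get_metadata` of the same value that `to_ptr` consumed.
    bool mdMatches =
        !md || (md->kind == Kind::GetMetadata && md->ptr == toPtr->ptr);
    if (!mdMatches)
      return folded;  // stop instead of revisiting the same op
    folded = toPtr->ptr;
    fromPtr = folded;  // keep going if the folded value is a `from_ptr`
  }
  return folded;
}

// Sketch of the `ToPtrOp::fold` walk: peels `to_ptr(from_ptr(p))` pairs.
const Node *foldToPtr(const Node *toPtr) {
  const Node *folded = nullptr;
  while (toPtr && toPtr->kind == Kind::ToPtr) {
    const Node *fromPtr = toPtr->ptr;
    if (!fromPtr || fromPtr->kind != Kind::FromPtr)
      return folded;
    folded = fromPtr->ptr;
    toPtr = folded;  // keep going if the folded value is a `to_ptr`
  }
  return folded;
}
```

Exiting early keeps termination obvious. As far as we can tell, in the `FromPtrOp::fold` loop above a metadata mismatch after the first iteration leaves `ptrLike` unchanged, so `fromPtr` is reassigned to the same op again; that path may be worth double-checking.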
>From ae9b69cabc28ff7c13ad089bc6e11f0558fac0b9 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Wed, 14 May 2025 16:44:14 +0000
Subject: [PATCH 5/6] remove trivial_metadata flag from from_ptr op

---
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 17 +++++++----------
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp     | 15 +--------------
 mlir/test/Dialect/Ptr/canonicalize.mlir    | 10 +++++-----
 mlir/test/Dialect/Ptr/invalid.mlir         | 17 -----------------
 mlir/test/Dialect/Ptr/ops.mlir             |  2 +-
 5 files changed, 14 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 55cc47a41d03b..37eb91fa6a338 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -33,27 +33,24 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
     - The memory-space of both the `ptr` and ptr-like object must match.
     - The cast is Pure (no UB and side-effect free).
 
-    If the ptr-like object type has metadata, then the operation expects the
-    metadata as an argument or expects that the flag `trivial_metadata` is set.
-    If `trivial_metadata` is set, then it is assumed that the metadata can be
-    reconstructed statically from the pointer-like type.
+    The optional `metadata` operand exists to provide any ptr-like metadata
+    that might be required to perform the cast.
 
     Example:
 
     ```mlir
     %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
     %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
-    %memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
+  
+    // Cast the `%ptr` to a memref without utilizing metadata.
+    %memref = ptr.from_ptr %ptr : !ptr.ptr<0> -> memref<f32, 0>
     ```
   }];
 
-  let arguments = (ins Ptr_PtrType:$ptr,
-                       Optional<Ptr_PtrMetadata>:$metadata,
-                       UnitProp:$hasTrivialMetadata);
+  let arguments = (ins Ptr_PtrType:$ptr, Optional<Ptr_PtrMetadata>:$metadata);
   let results = (outs PtrLikeTypeInterface:$result);
   let assemblyFormat = [{
-    $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
-    attr-dict `:` type($ptr) `->` type($result)
+    $ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result)
   }];
   let hasFolder = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c0310446e3cea..ffa924b20ab59 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -61,8 +61,7 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
     if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
       return ptrLike;
     Value md = fromPtr.getMetadata();
-    // If there's no metadata in the op, either the cast never requires metadata
-    // or the op has the trivial metadata flag set, therefore fold.
+    // If the op has no metadata operand, fold it.
     if (!md)
       ptrLike = toPtr.getPtr();
     // Fold if the metadata can be verified to be equal.
@@ -83,18 +82,6 @@ LogicalResult FromPtrOp::verify() {
     return emitError()
            << "expected the input and output to have the same memory space";
   }
-  bool hasMD = getMetadata() != Value();
-  bool hasTrivialMD = getHasTrivialMetadata();
-  if (hasMD && hasTrivialMD) {
-    return emitError() << "expected either a metadata argument or the "
-                          "`trivial_metadata` flag, not both";
-  }
-  if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) {
-    return emitError() << "expected either a metadata argument or the "
-                          "`trivial_metadata` flag to be set";
-  }
-  if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD))
-    return emitError() << "expected no metadata specification";
   return success();
 }
 
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index 2b9c8489f352e..dfc679acb2ed4 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -35,7 +35,7 @@ func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32,
   // CHECK-NOT: ptr.from_ptr
   // CHECK: return %[[MEM_REF]]
   %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
-  %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %res = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
   return %res : memref<f32, #ptr.generic_space>
 }
 
@@ -82,7 +82,7 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge
   // CHECK-NOT: ptr.from_ptr
   // CHECK-NOT: ptr.to_ptr
   // CHECK: return %[[PTR]]
-  %mrf = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %mrf = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
   %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
   return %res : !ptr.ptr<#ptr.generic_space>
 }
@@ -94,11 +94,11 @@ func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.g
   // CHECK-NOT: ptr.from_ptr
   // CHECK-NOT: ptr.to_ptr
   // CHECK: return %[[PTR]]
-  %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %mrf0 = ptr.from_ptr %ptr0 : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
   %ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
-  %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %mrf1 = ptr.from_ptr %ptr1 : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
   %ptr2 = ptr.to_ptr %mrf1 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
-  %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %mrf2 = ptr.from_ptr %ptr2 : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
   %res = ptr.to_ptr %mrf2 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
   return %res : !ptr.ptr<#ptr.generic_space>
 }
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index e776e0ee04f90..19fd715e5bba6 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -14,20 +14,3 @@ func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
   %r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
   return
 }
-
-// -----
-
-/// Test `from_ptr` verifiers.
-func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
-  // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag to be set}}
-  %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
-  return
-}
-
-// -----
-
-func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>, %m: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) {
-  // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag, not both}}
-  %r = ptr.from_ptr %v metadata %m trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space> 
-  return
-}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 74bff25b4f3e1..eed3272d98da9 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -24,6 +24,6 @@ func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.ge
   %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
   %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
   %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
-  %mr0 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %mr0 = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
   return %res : memref<f32, #ptr.generic_space>
 }

>From b54eded6393d01228c79d7ca336c3511a02f34e2 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Fri, 6 Jun 2025 22:52:34 +0000
Subject: [PATCH 6/6] address reviewer comments

---
 mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td |  5 +++--
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td     | 14 +++++++-------
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td  |  8 ++++++--
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 6631b338db199..7407d74ce3a87 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -93,8 +93,9 @@ def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
   let description = [{
     The `ptr_metadata` type represents an opaque-view of the metadata associated
     with a `ptr-like` object type.
-    It's an error to get a `ptr_metadata` using `ptr-like` type with no
-    metadata.
+
+    Note: It's a verification error to construct a `ptr_metadata` type using a
+    `ptr-like` type with no metadata.
 
     Example:
 
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 37eb91fa6a338..1523762efc18f 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -39,11 +39,11 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
     Example:
 
     ```mlir
-    %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
-    %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
+    %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
+    %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
   
     // Cast the `%ptr` to a memref without utilizing metadata.
-    %memref = ptr.from_ptr %ptr : !ptr.ptr<0> -> memref<f32, 0>
+    %memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
     ```
   }];
 
@@ -98,8 +98,8 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
     Example:
 
     ```mlir
-    %x_off  = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32
-    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32
+    %x_off  = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
     ```
   }];
 
@@ -134,8 +134,8 @@ def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
     Example:
 
     ```mlir
-    %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
-    %ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
+    %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+    %ptr1 = ptr.to_ptr %memref : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
     ```
   }];
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index d058f6c4d9651..367aeb6ac512b 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -119,8 +119,12 @@ def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
   let description = [{
     A ptr-like type represents an object storing a memory address. This object
     is constituted by:
-    - A memory address called the base pointer. The base pointer is an 
-      indivisible object.
+    - A memory address called the base pointer. This pointer is treated as a
+      bag of bits without any assumed structure. The bit-width of the base
+      pointer must be a compile-time constant. However, the bit-width may remain
+      opaque or unavailable during transformations that do not depend on the
+      base pointer. Finally, it is considered indivisible in the sense that as
+      a `PtrLikeTypeInterface` value, it has no metadata.
     - Optional metadata about the pointer. For example, the size of the  memory
       region associated with the pointer.
 

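The structure described by the updated interface text can be pictured with a small model. This is a hypothetical C++17 sketch, not part of MLIR: `BasePtr`, `PtrLikeValue`, and the `Metadata` template parameter are illustrative names. The base pointer is an unstructured bag of bits whose width is a template parameter (a compile-time constant), and the metadata is optional.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <optional>

// Hypothetical: the base pointer is a fixed-width bag of bits with no
// assumed internal structure.
template <std::size_t BitWidth>
using BasePtr = std::bitset<BitWidth>;

// Hypothetical ptr-like value: a base pointer plus optional metadata.
// `Metadata` is a placeholder for whatever a concrete type (e.g. a memref)
// carries; it is empty for a type like `!ptr.ptr`.
template <std::size_t BitWidth, typename Metadata>
struct PtrLikeValue {
  BasePtr<BitWidth> base;
  std::optional<Metadata> metadata;

  // The base pointer on its own carries no metadata.
  BasePtr<BitWidth> basePointer() const { return base; }
  bool hasMetadata() const { return metadata.has_value(); }
};
```

The model does not capture that a transformation which never inspects the base pointer may leave its concrete bit-width opaque.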