[Mlir-commits] [mlir] [mlir][core|ptr] Add `PtrLikeTypeInterface` and casting ops to the `ptr` dialect (PR #137469)

Fabian Mora llvmlistbot at llvm.org
Sun Apr 27 06:34:19 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/137469

>From 953761977293abcd98437152a106d18ed533d78c Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sat, 26 Apr 2025 19:25:54 +0000
Subject: [PATCH 1/3] [mlir][core|ptr] Add `PtrLikeTypeInterface` and casting
 ops to the `ptr` dialect

This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types.
This interface is defined as:

```
A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. The base pointer is an
  indivisible object.
- Optional metadata about the pointer. For example, the size of the memory
  region associated with the pointer.

Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
  pointer is considered opaque.
```

This patch adds this interface to `!ptr.ptr` and the `memref` type.

Furthermore, this patch adds the ops and the type necessary to handle casting between `!ptr.ptr`
and ptr-like types.

First, it defines the `!ptr.ptr_metadata` type, an opaque type representing the metadata
of a ptr-like type. The rationale behind adding this type is that, at a high level, the
metadata of a type like `memref` cannot be specified, as its structure is tied to its
lowering.

The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The
concrete structure of the metadata is only known when the op is lowered.

Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations, which allow
casting back and forth between `!ptr.ptr` and ptr-like types.

```mlir
func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
  return %res : memref<f32, #ptr.generic_space>
}
```
---
 .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 49 +++++++++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    | 99 +++++++++++++++++++
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 49 +++++++++
 mlir/include/mlir/IR/BuiltinTypes.h           | 18 +++-
 mlir/include/mlir/IR/BuiltinTypes.td          |  2 +
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        | 75 ++++++++++++++
 mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp          | 12 +++
 mlir/lib/IR/BuiltinTypes.cpp                  | 14 +++
 mlir/test/Dialect/Ptr/canonicalize.mlir       | 58 +++++++++++
 mlir/test/Dialect/Ptr/invalid.mlir            | 33 +++++++
 mlir/test/Dialect/Ptr/ops.mlir                | 10 ++
 11 files changed, 418 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Ptr/invalid.mlir

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 73b2a0857cef3..6631b338db199 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
 
 def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     MemRefElementTypeInterface,
+    PtrLikeTypeInterface,
     VectorElementTypeInterface,
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
       "areCompatible", "getIndexBitwidth", "verifyEntries",
@@ -63,6 +64,54 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
       return $_get(memorySpace.getContext(), memorySpace);
     }]>
   ];
+  let extraClassDeclaration = [{
+    // `PtrLikeTypeInterface` interface methods.
+    /// Returns `Type()` as this pointer type is opaque.
+    Type getElementType() const {
+      return Type();
+    }
+    /// Clones the pointer with specified memory space or returns failure
+    /// if an `elementType` was specified or if the memory space doesn't
+    /// implement `MemorySpaceAttrInterface`.
+    FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
+      std::optional<Type> elementType) const {
+      if (elementType)
+        return failure();
+      if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
+        return cast<PtrLikeTypeInterface>(get(ms));
+      return failure();
+    }
+    /// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
+    bool hasPtrMetadata() const {
+      return false;
+    }
+  }];
+}
+
+def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
+  let summary = "Pointer metadata type";
+  let description = [{
+    The `ptr_metadata` type represents an opaque-view of the metadata associated
+    with a `ptr-like` object type.
+    It's an error to form a `ptr_metadata` type from a ptr-like type with no
+    metadata.
+
+    Example:
+
+    ```mlir
+    // The metadata associated with a `memref` type.
+    !ptr.ptr_metadata<memref<f32>>
+    ```
+  }];
+  let parameters = (ins "PtrLikeTypeInterface":$type);
+  let assemblyFormat = "`<` $type `>`";
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "PtrLikeTypeInterface":$ptrLike), [{
+      return $_get(ptrLike.getContext(), ptrLike);
+    }]>
+  ];
+  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 791b95ad3559e..8ad475c41c8d3 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,75 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
 
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
+    Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
+            "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+  ]> {
+  let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
+  let description = [{
+    The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
+    important to note that:
+    - The ptr-like object cannot be a `!ptr.ptr`.
+    - The memory-space of both the `ptr` and ptr-like object must match.
+    - The cast is side-effect free.
+
+    If the ptr-like object type has metadata, then the operation expects the
+    metadata as an argument or expects that the flag `trivial_metadata` is set.
+    If `trivial_metadata` is set, then it is assumed that the metadata can be
+    reconstructed statically from the pointer-like type.
+
+    Example:
+
+    ```mlir
+    %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
+    %memref0 = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
+    %memref1 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
+    ```
+  }];
+
+  let arguments = (ins Ptr_PtrType:$ptr,
+                       Optional<Ptr_PtrMetadata>:$metadata,
+                       UnitProp:$hasTrivialMetadata);
+  let results = (outs PtrLikeTypeInterface:$result);
+  let assemblyFormat = [{
+    $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
+    attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
+    Pure, TypesMatchWith<"metadata type", "ptr", "result",
+            "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+  ]> {
+  let summary = "SSA value representing pointer metadata.";
+  let description = [{
+    The `get_metadata` operation produces an opaque value that encodes the
+    metadata of the ptr-like type.
+
+    Example:
+
+    ```mlir
+    %metadata = ptr.get_metadata %memref : memref<?x?xf32>
+    ```
+  }];
+
+  let arguments = (ins PtrLikeTypeInterface:$ptr);
+  let results = (outs Ptr_PtrMetadata:$result);
+  let assemblyFormat = [{
+    $ptr attr-dict `:` type($ptr)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // PtrAddOp
 //===----------------------------------------------------------------------===//
@@ -52,6 +121,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
+  let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
+  let description = [{
+    The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
+    important to note that:
+    - The ptr-like object cannot be a `!ptr.ptr`.
+    - The memory-space of both the `ptr` and ptr-like object must match.
+    - The cast is side-effect free.
+
+    Example:
+
+    ```mlir
+    %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
+    %ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
+    ```
+  }];
+
+  let arguments = (ins PtrLikeTypeInterface:$ptr);
+  let results = (outs Ptr_PtrType:$result);
+  let assemblyFormat = [{
+    $ptr attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TypeOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..d058f6c4d9651 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PtrLikeTypeInterface
+//===----------------------------------------------------------------------===//
+
+def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    A ptr-like type represents an object storing a memory address. This object
+    is constituted by:
+    - A memory address called the base pointer. The base pointer is an
+      indivisible object.
+    - Optional metadata about the pointer. For example, the size of the memory
+      region associated with the pointer.
+
+    Furthermore, all ptr-like types have two properties:
+    - The memory space associated with the address held by the pointer.
+    - An optional element type. If the element type is not specified, the
+      pointer is considered opaque.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space of this ptr-like type.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      Returns the element type of this ptr-like type. Note: this method can
+      return `::mlir::Type()`, in which case the pointer is considered opaque.
+    }],
+    "::mlir::Type", "getElementType">,
+    InterfaceMethod<[{
+      Returns whether this ptr-like type has non-empty metadata.
+    }],
+    "bool", "hasPtrMetadata">,
+    InterfaceMethod<[{
+      Returns a clone of this type with the given memory space and element type,
+      or `failure` if the type cannot be cloned with the specified arguments.
+      If the pointer is opaque and `elementType` is not `std::nullopt`, the
+      method will return `failure`.
+
+      If no `elementType` is provided and ptr is not opaque, the `elementType`
+      of this type is used.
+    }],
+    "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
+      "::mlir::Attribute":$memorySpace,
+      "::std::optional<::mlir::Type>":$elementType
+    )>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..86ec5c43970b1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
 /// Note: This class attaches the ShapedType trait to act as a mixin to
 ///       provide many useful utility functions. This inheritance has no effect
 ///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
+class BaseMemRefType : public Type,
+                       public PtrLikeTypeInterface::Trait<BaseMemRefType>,
+                       public ShapedType::Trait<BaseMemRefType> {
 public:
   using Type::Type;
 
@@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                            Type elementType) const;
 
+  /// Clone this type with the given memory space and element type. If the
+  /// provided element type is `std::nullopt`, the current element type of the
+  /// type is used.
+  FailureOr<PtrLikeTypeInterface>
+  clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;
+
   // Make sure that base class overloads are visible.
   using ShapedType::Trait<BaseMemRefType>::clone;
 
@@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   /// New `Attribute getMemorySpace()` method should be used instead.
   unsigned getMemorySpaceAsInt() const;
 
+  /// Returns true: memref types always carry non-empty ptr metadata.
+  bool hasPtrMetadata() const { return true; }
+
   /// Allow implicit conversion to ShapedType.
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  /// Allow implicit conversion to PtrLikeTypeInterface.
+  operator PtrLikeTypeInterface() const {
+    return llvm::cast<PtrLikeTypeInterface>(*this);
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..9ad24e45c8315 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
+    PtrLikeTypeInterface,
     ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference to a region of memory";
@@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
+    PtrLikeTypeInterface,
     ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c21783011452f..80fd7617c9354 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -41,6 +41,54 @@ void PtrDialect::initialize() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
+  // Folds a `to_ptr` -> `from_ptr` round trip back to the original value:
+  //   %ptr = ptr.to_ptr %v : type -> ptr
+  //   (%mda = ptr.get_metadata %v : type)?
+  //   %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type  ==>  %v
+  // Requires `%v` to have the result type and any metadata operand to be
+  // `get_metadata` of `%v`; without one, the type reconstructs the metadata.
+  Value source;
+  if (auto toPtr = getPtr().getDefiningOp<ToPtrOp>())
+    source = toPtr.getPtr();
+  if (!source || source.getType() != getType())
+    return nullptr;
+  Value md = getMetadata();
+  if (!md)
+    return source;
+  auto mdOp = md.getDefiningOp<GetMetadataOp>();
+  // Only fold when the metadata provably belongs to `%v` itself.
+  if (mdOp && mdOp.getPtr() == source)
+    return source;
+  return nullptr;
+}
+
+LogicalResult FromPtrOp::verify() {
+  PtrLikeTypeInterface resTy = getType();
+  if (isa<PtrType>(resTy))
+    return emitError() << "the result type cannot be `!ptr.ptr`";
+  if (resTy.getMemorySpace() != getPtr().getType().getMemorySpace())
+    return emitError()
+           << "expected the input and output to have the same memory space";
+  // At most one metadata source; one is required iff the type has metadata.
+  const bool hasOperand = static_cast<bool>(getMetadata());
+  const bool isTrivial = getHasTrivialMetadata();
+  if (hasOperand && isTrivial)
+    return emitError() << "expected either a metadata argument or the "
+                          "`trivial_metadata` flag, not both";
+  const bool specified = hasOperand || isTrivial;
+  if (resTy.hasPtrMetadata() && !specified)
+    return emitError() << "expected either a metadata argument or the "
+                          "`trivial_metadata` flag to be set";
+  if (!resTy.hasPtrMetadata() && specified)
+    return emitError() << "expected no metadata specification";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // PtrAddOp
 //===----------------------------------------------------------------------===//
@@ -55,6 +103,33 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
+  // Folds a `from_ptr` -> `to_ptr` round trip back to the original pointer:
+  //   %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
+  //   %ptr = ptr.to_ptr %val : type -> ptr  ==>  %p
+  // Whatever metadata `from_ptr` attached is irrelevant here; the fold
+  // returns the `from_ptr` pointer operand unconditionally.
+  auto fromPtr = getPtr().getDefiningOp<FromPtrOp>();
+  if (fromPtr)
+    return fromPtr.getPtr();
+  // Any other producer of the ptr-like value cannot be looked through.
+  return nullptr;
+}
+
+LogicalResult ToPtrOp::verify() {
+  PtrLikeTypeInterface srcTy = getPtr().getType();
+  if (isa<PtrType>(srcTy))
+    return emitError() << "the input value cannot be of type `!ptr.ptr`";
+  if (srcTy.getMemorySpace() == getType().getMemorySpace())
+    return success(); // The cast must preserve the memory space.
+  return emitError()
+         << "expected the input and output to have the same memory space";
+}
+
 //===----------------------------------------------------------------------===//
 // TypeOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index cab9ca11e679e..7ad2a6bc4c80b 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// Pointer metadata
+//===----------------------------------------------------------------------===//
+
+LogicalResult PtrMetadataType::verify(
+    function_ref<InFlightDiagnostic()> emitError, PtrLikeTypeInterface type) {
+  // Only ptr-like types that carry metadata admit an opaque metadata view.
+  if (type.hasPtrMetadata())
+    return success();
+  return emitError() << "the ptr-like type has no metadata";
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..3032b68c1fdd4 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -376,6 +376,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
   return builder;
 }
 
+FailureOr<PtrLikeTypeInterface>
+BaseMemRefType::clonePtrWith(Attribute memorySpace,
+                             std::optional<Type> elementType) const {
+  Type ty = elementType.value_or(getElementType());
+  if (!ty || !isValidElementType(ty) ||
+      !detail::isSupportedMemorySpace(memorySpace))
+    return failure(); // Report instead of asserting inside `get`.
+  if (llvm::isa<UnrankedMemRefType>(*this))
+    return cast<PtrLikeTypeInterface>(UnrankedMemRefType::get(ty, memorySpace));
+  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
+  builder.setElementType(ty).setMemorySpace(memorySpace);
+  return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
+}
+
 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
                                  Type elementType) const {
   return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index ad363d554f247..837f364242beb 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -13,3 +13,61 @@ func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.gene
   %res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
   return %res0 : !ptr.ptr<#ptr.generic_space>
 }
+
+/// Tests the `from_ptr` folder.
+// CHECK-LABEL: @test_from_ptr_0
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_0(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.get_metadata
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_from_ptr_1
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Check that the ops cannot be folded because the metadata cannot be guaranteed to be the same.
+// CHECK-LABEL: @test_from_ptr_2
+func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+  // CHECK: ptr.to_ptr
+  // CHECK: ptr.from_ptr
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Tests the the `to_ptr` folder.
+// CHECK-LABEL: @test_to_ptr_0
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_0(%ptr: !ptr.ptr<#ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> !ptr.ptr<#ptr.generic_space> {
+  // CHECK: return %[[PTR]]
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK-NOT: ptr.to_ptr
+  %mrf = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return %res : !ptr.ptr<#ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_to_ptr_1
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>)
+func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK: return %[[PTR]]
+  %mrf = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return %res : !ptr.ptr<#ptr.generic_space>
+}
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
new file mode 100644
index 0000000000000..e776e0ee04f90
--- /dev/null
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+/// Test `to_ptr` verifiers.
+func.func @invalid_to_ptr(%v: memref<f32, 0>) {
+  // expected-error@+1 {{expected the input and output to have the same memory space}}
+  %r = ptr.to_ptr %v : memref<f32, 0> -> !ptr.ptr<#ptr.generic_space>
+  return
+}
+
+// -----
+
+func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+  // expected-error@+1 {{the input value cannot be of type `!ptr.ptr`}}
+  %r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return
+}
+
+// -----
+
+/// Test `from_ptr` verifiers.
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
+  // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag to be set}}
+  %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return
+}
+
+// -----
+
+func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>, %m: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) {
+  // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag, not both}}
+  %r = ptr.from_ptr %v metadata %m trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space> 
+  return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index d763ea221944b..74bff25b4f3e1 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -17,3 +17,13 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
   %res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
   return %res : !ptr.ptr<#ptr.generic_space>
 }
+
+/// Check cast ops assembly.
+// CHECK-LABEL: @cast_ops
+func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %mr0 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}

>From 40acb7bd0b9f1cd383a37fb014756ce69fc90609 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sun, 27 Apr 2025 07:29:15 -0400
Subject: [PATCH 2/3] Update mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 8ad475c41c8d3..55cc47a41d03b 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -31,7 +31,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
     important to note that:
     - The ptr-like object cannot be a `!ptr.ptr`.
     - The memory-space of both the `ptr` and ptr-like object must match.
-    - The cast is side-effect free.
+    - The cast is Pure (no UB and side-effect free).
 
     If the ptr-like object type has metadata, then the operation expects the
     metadata as an argument or expects that the flag `trivial_metadata` is set.

>From 9c5c7b0cac841ab411ec765d684a44e8bea4d9dd Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 27 Apr 2025 13:33:46 +0000
Subject: [PATCH 3/3] add tests for chains of casts

---
 mlir/test/Dialect/Ptr/canonicalize.mlir | 48 +++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index 837f364242beb..2b9c8489f352e 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -49,6 +49,21 @@ func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_m
   return %res : memref<f32, #ptr.generic_space>
 }
 
+// Check the folding of `to_ptr -> from_ptr` chains.
+// CHECK-LABEL: @test_from_ptr_3
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_3(%mr0: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %mda = ptr.get_metadata %mr0 : memref<f32, #ptr.generic_space>
+  %ptr0 = ptr.to_ptr %mr0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf0 = ptr.from_ptr %ptr0 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf1 = ptr.from_ptr %ptr1 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %mrf1 : memref<f32, #ptr.generic_space>
+}
+
 /// Tests the the `to_ptr` folder.
 // CHECK-LABEL: @test_to_ptr_0
 // CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
@@ -71,3 +86,36 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge
   %res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
   return %res : !ptr.ptr<#ptr.generic_space>
 }
+
+// Check the folding of `from_ptr -> to_ptr` chains.
+// CHECK-LABEL: @test_to_ptr_2
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK: return %[[PTR]]
+  %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %ptr2 = ptr.to_ptr %mrf1 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  %res = ptr.to_ptr %mrf2 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  return %res : !ptr.ptr<#ptr.generic_space>
+}
+
+// Check the folding of chains with different metadata.
+// CHECK-LABEL: @test_cast_chain_folding
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>
+func.func @test_cast_chain_folding(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+   %ptr1 = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+   %memrefWithOtherMd = ptr.from_ptr %ptr1 metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+   %ptr = ptr.to_ptr %memrefWithOtherMd : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+   %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+   // The chain can be folded because the ptr always has the same value (since
+   // `to_ptr` is a lossless cast) and %mda comes from the original memref.
+   %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+   return %res : memref<f32, #ptr.generic_space>
+}



More information about the Mlir-commits mailing list