[Mlir-commits] [mlir] [mlir][ptr] Add ConstantOp with NullAttr and AddressAttr support (PR #157347)

Fabian Mora llvmlistbot at llvm.org
Sun Sep 14 08:30:53 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/157347

>From 966d174801809e733017f6940ed46ffbe2a78df0 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 7 Sep 2025 15:14:41 +0000
Subject: [PATCH 1/3] [mlir][ptr] Add ConstantOp with NullAttr and AddressAttr
 support

This patch introduces the `ptr.constant` operation. It also adds the `NullAttr`
and `AddressAttr` attributes for representing null pointers and raw integer
addresses.

It also implements LLVM IR translation for `ptr.constant` with `#ptr.null` or
`#ptr.address` attributes.

Finally, it extends `FieldParser` to support parsing `llvm::APInt` values.

Example:
```mlir
llvm.func @constant_address_op() ->
    !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>,
                  !ptr.ptr<#llvm.address_space<1>>,
                  !ptr.ptr<#llvm.address_space<2>>)> {
  %0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
  %1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
  %2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>>
  %3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  %4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  %5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  %6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
}
```
Result of translation to LLVM IR:
```llvm
define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() {
  ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) }
}
```
---
 .../mlir/Dialect/Ptr/IR/PtrAttrDefs.td        | 60 ++++++++++++++++++-
 mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h   |  6 ++
 mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td    | 39 +++++++++---
 mlir/include/mlir/IR/DialectImplementation.h  |  7 ++-
 mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp        |  6 ++
 .../Dialect/Ptr/PtrToLLVMIRTranslation.cpp    | 52 ++++++++++++++++
 mlir/test/Dialect/Ptr/ops.mlir                | 24 +++++++-
 mlir/test/Target/LLVMIR/ptr.mlir              | 36 +++++++++--
 8 files changed, 211 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
index 4542f57a62d79..bec97e9aa1b90 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
@@ -22,6 +22,34 @@ class Ptr_Attr<string name, string attrMnemonic,
   let mnemonic = attrMnemonic;
 }
 
+//===----------------------------------------------------------------------===//
+// AddressAttr
+//===----------------------------------------------------------------------===//
+
+def Ptr_AddressAttr : Ptr_Attr<"Address", "address", [
+    DeclareAttrInterfaceMethods<TypedAttrInterface>
+  ]> {
+  let summary = "Address attribute";
+  let description = [{
+    The `address` attribute represents a raw memory address.
+
+    Example:
+
+    ```mlir
+      #ptr.address<0x1000> : !ptr.ptr<#ptr.generic_space>
+    ```
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type,
+                        APIntParameter<"">:$value);
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "PtrType":$type,
+                                        "const llvm::APInt &":$value), [{
+      return $_get(type.getContext(), type, value);
+    }]>
+  ];
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // GenericSpaceAttr
 //===----------------------------------------------------------------------===//
@@ -37,16 +65,42 @@ def Ptr_GenericSpaceAttr :
     - Load and store operations are always valid, regardless of the type.
     - Atomic operations are always valid, regardless of the type.
     - Cast operations to `generic_space` are always valid.
-  
+
     Example:
 
     ```mlir
-    #ptr.generic_space
+      #ptr.generic_space
     ```
   }];
   let assemblyFormat = "";
 }
 
+//===----------------------------------------------------------------------===//
+// NullAttr
+//===----------------------------------------------------------------------===//
+
+def Ptr_NullAttr : Ptr_Attr<"Null", "null", [
+    DeclareAttrInterfaceMethods<TypedAttrInterface>
+  ]> {
+  let summary = "Null pointer attribute";
+  let description = [{
+    The `null` attribute represents a null pointer.
+
+    Example:
+
+    ```mlir
+      #ptr.null : !ptr.ptr<#ptr.generic_space>
+    ```
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type);
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "PtrType":$type), [{
+      return $_get(type.getContext(), type);
+    }]>
+  ];
+  let assemblyFormat = "";
+}
+
 //===----------------------------------------------------------------------===//
 // SpecAttr
 //===----------------------------------------------------------------------===//
@@ -62,7 +116,7 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
      - [Optional] index: bitwidth that should be used when performing index
      computations for the type. Setting the field to `kOptionalSpecValue`, means
      the field is optional.
-    
+
     Furthermore, the attribute will verify that all present values are divisible
     by 8 (number of bits in a byte), and that `preferred` > `abi`.
 
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
index bb01ceaaeea54..c252f9efd0471 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
@@ -21,6 +21,12 @@
 #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
 #include "mlir/Dialect/Ptr/IR/PtrEnums.h"
 
+namespace mlir {
+namespace ptr {
+class PtrType;
+} // namespace ptr
+} // namespace mlir
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 3ac12978b947c..468a3004d5c62 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -36,7 +36,7 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
     /*cppType=*/"::mlir::ShapedType">;
 
 // A ptr-like type, either scalar or shaped type with value semantics.
-def Ptr_PtrLikeType : 
+def Ptr_PtrLikeType :
   AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
 
 // An int-like type, either scalar or shaped type with value semantics.
@@ -57,6 +57,31 @@ def Ptr_Mask1DType :
 def Ptr_Ptr1DType :
   Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
 
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ConstantOp : Pointer_Op<"constant", [
+    ConstantLike, Pure, AllTypesMatch<["value", "result"]>
+  ]> {
+  let summary = "Pointer constant operation";
+  let description = [{
+    The `constant` operation produces a pointer constant. The attribute must be
+    a typed attribute of pointer type.
+
+    Example:
+
+    ```mlir
+    // Create a null pointer
+    %null = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
+    ```
+  }];
+  let arguments = (ins TypedAttrInterface:$value);
+  let results = (outs Ptr_PtrType:$result);
+  let assemblyFormat = "attr-dict $value";
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // FromPtrOp
 //===----------------------------------------------------------------------===//
@@ -81,7 +106,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
     ```mlir
     %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
     %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
-  
+
     // Cast the `%ptr` to a memref without utilizing metadata.
     %memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
     ```
@@ -361,13 +386,13 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
     // Scalar base and offset
     %x_off  = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
     %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
-    
+
     // Shaped base with scalar offset
     %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
-    
+
     // Scalar base with shaped offset
     %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
-    
+
     // Both base and offset are shaped
     %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
     ```
@@ -382,7 +407,7 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
   }];
   let hasFolder = 1;
   let extraClassDeclaration = [{
-    /// `ViewLikeOp::getViewSource` method. 
+    /// `ViewLikeOp::getViewSource` method.
     Value getViewSource() { return getBase(); }
 
     /// Returns the ptr type of the operation.
@@ -418,7 +443,7 @@ def Ptr_ScatterOp : Pointer_Op<"scatter", [
     // Scatter values to multiple memory locations
     ptr.scatter %value, %ptrs, %mask :
       vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
-    
+
     // Scatter with alignment
     ptr.scatter %value, %ptrs, %mask alignment = 8 :
       vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index f45b88dc6deca..0b4f91cd750b8 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -103,10 +103,11 @@ struct FieldParser<
 
 /// Parse any integer.
 template <typename IntT>
-struct FieldParser<IntT,
-                   std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
+struct FieldParser<IntT, std::enable_if_t<(std::is_integral<IntT>::value ||
+                                           std::is_same_v<IntT, llvm::APInt>),
+                                          IntT>> {
   static FailureOr<IntT> parse(AsmParser &parser) {
-    IntT value = 0;
+    IntT value{};
     if (parser.parseInteger(value))
       return failure();
     return value;
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 284c998690170..f0209af8a1ca3 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -56,6 +56,12 @@ verifyAlignment(std::optional<int64_t> alignment,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
+
 //===----------------------------------------------------------------------===//
 // FromPtrOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index d777667022a98..11b921de21596 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -300,6 +300,55 @@ convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Convert ptr.constant operation
+static LogicalResult
+convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation) {
+  // Convert result type to LLVM type
+  llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
+      moduleTranslation.convertType(constantOp.getResult().getType()));
+  if (!resultType)
+    return constantOp.emitError("Expected a valid pointer type");
+
+  llvm::Value *result = nullptr;
+
+  TypedAttr value = constantOp.getValue();
+  if (isa<ptr::NullAttr>(value)) {
+    // Create a null pointer constant
+    result = llvm::ConstantPointerNull::get(resultType);
+  } else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
+    // Create an integer constant and convert it to pointer
+    llvm::APInt addressValue = addressAttr.getValue();
+
+    // Determine the integer type width based on the target's pointer size
+    const llvm::DataLayout &dataLayout =
+        moduleTranslation.getLLVMModule()->getDataLayout();
+    unsigned pointerSizeInBits =
+        dataLayout.getPointerSizeInBits(resultType->getAddressSpace());
+
+    // Resize to the pointer width; warn only if significant bits are dropped.
+    if (addressValue.getBitWidth() != pointerSizeInBits) {
+      if (addressValue.getActiveBits() > pointerSizeInBits) {
+        constantOp.emitWarning()
+            << "Truncating address value to fit pointer size";
+      }
+      addressValue = addressValue.getBitWidth() < pointerSizeInBits
+                         ? addressValue.zext(pointerSizeInBits)
+                         : addressValue.trunc(pointerSizeInBits);
+    }
+
+    // Create integer constant and convert to pointer
+    llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
+    llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
+    result = builder.CreateIntToPtr(intValue, resultType);
+  } else {
+    return constantOp.emitError("Unsupported constant attribute type");
+  }
+
+  moduleTranslation.mapValue(constantOp.getResult(), result);
+  return success();
+}
+
 /// Implementation of the dialect interface that converts operations belonging
 /// to the `ptr` dialect to LLVM IR.
 class PtrDialectLLVMIRTranslationInterface
@@ -314,6 +363,9 @@ class PtrDialectLLVMIRTranslationInterface
                    LLVM::ModuleTranslation &moduleTranslation) const final {
 
     return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+        .Case([&](ConstantOp constantOp) {
+          return convertConstantOp(constantOp, builder, moduleTranslation);
+        })
         .Case([&](PtrAddOp ptrAddOp) {
           return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
         })
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 51e5ac3ae691d..7b2254185f57c 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -114,7 +114,7 @@ func.func @masked_store_ops_tensor(%value: tensor<8xi64>, %ptr: !ptr.ptr<#ptr.ge
 }
 
 /// Test operations with LLVM address space
-func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, 
+func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
                            %mask: vector<4xi1>, %value: vector<4xf32>, %passthrough: vector<4xf32>) -> vector<4xf32> {
   // Gather from shared memory (address space 3)
   %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf32>
@@ -189,3 +189,25 @@ func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.gener
   %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
   return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
 }
+
+/// Test constant operations with null pointer
+func.func @constant_null_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>) {
+  %null_generic = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
+  %null_as1 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<1>>
+  return %null_generic, %null_as1 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>
+}
+
+/// Test constant operations with address values
+func.func @constant_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>) {
+  %addr_0 = ptr.constant #ptr.address<0> : !ptr.ptr<#ptr.generic_space>
+  %addr_1000 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
+  %addr_deadbeef = ptr.constant #ptr.address<0xDEADBEEF> : !ptr.ptr<#llvm.address_space<3>>
+  return %addr_0, %addr_1000, %addr_deadbeef : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>
+}
+
+/// Test constant operations with large address values
+func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>) {
+  %addr_max32 = ptr.constant #ptr.address<0xFFFFFFFF> : !ptr.ptr<#ptr.generic_space>
+  %addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
+  return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 9b99dd8e3a3eb..2fa794130ec52 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -41,10 +41,10 @@ llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<
   %2 = ptr.type_offset i16 : i32
   %3 = ptr.type_offset i32 : i32
   %4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
-  %5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)> 
-  %6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)> 
-  %7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)> 
-  %8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)> 
+  %5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
+  %6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
+  %7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
+  %8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
   llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
 }
 
@@ -194,7 +194,7 @@ llvm.func @scatter_ops_i64(%value: vector<8xi64>, %ptrs: vector<8x!ptr.ptr<#llvm
 // CHECK-NEXT:   call void @llvm.masked.store.v4f64.p3(<4 x double> %[[VALUE_F64]], ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]])
 // CHECK-NEXT:   ret void
 // CHECK-NEXT: }
-llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, 
+llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
                                           %mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
   // Test with shared memory address space (3) and f64 elements
   %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
@@ -255,3 +255,29 @@ llvm.func @llvm_ops_with_ptr_nvvm_values(%arg0: !llvm.ptr) {
   llvm.store %1, %arg0 : !ptr.ptr<#nvvm.memory_space<global>>, !llvm.ptr
   llvm.return
 }
+
+// CHECK-LABEL: define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() {
+// CHECK-NEXT: ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) }
+llvm.func @constant_address_op() ->
+    !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>,
+                  !ptr.ptr<#llvm.address_space<1>>,
+                  !ptr.ptr<#llvm.address_space<2>>)> {
+  %0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
+  %1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
+  %2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>>
+  %3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
+  %4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
+  %5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
+  %6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
+  llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
+}
+
+// Test gep folders.
+// CHECK-LABEL: define ptr @ptr_add_cst() {
+// CHECK-NEXT:   ret ptr inttoptr (i64 42 to ptr)
+llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
+  %off = llvm.mlir.constant(42 : i32) : i32
+  %ptr = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
+  %res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
+  llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
+}

>From ac65fb653188479707e9249ce8ec29589a2a9c8e Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sun, 14 Sep 2025 11:15:15 -0400
Subject: [PATCH 2/3] Update mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
index bec97e9aa1b90..78006d2dec40d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
@@ -31,7 +31,7 @@ def Ptr_AddressAttr : Ptr_Attr<"Address", "address", [
   ]> {
   let summary = "Address attribute";
   let description = [{
-    The `address` attribute represents a raw memory address.
+    The `address` attribute represents a raw memory address, expressed in bytes.
 
     Example:
 

>From 95c9aa7cd249afd48efbc517ac0a4dc2245495a7 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Sun, 14 Sep 2025 15:29:41 +0000
Subject: [PATCH 3/3] address reviewer comments

---
 .../Dialect/Ptr/PtrToLLVMIRTranslation.cpp    | 112 +++++++++---------
 1 file changed, 58 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 11b921de21596..7e610cd42e931 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -29,7 +29,7 @@ namespace {
 
 /// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
 static llvm::AtomicOrdering
-convertAtomicOrdering(ptr::AtomicOrdering ordering) {
+translateAtomicOrdering(ptr::AtomicOrdering ordering) {
   switch (ordering) {
   case ptr::AtomicOrdering::not_atomic:
     return llvm::AtomicOrdering::NotAtomic;
@@ -49,10 +49,10 @@ convertAtomicOrdering(ptr::AtomicOrdering ordering) {
   llvm_unreachable("Unknown atomic ordering");
 }
 
-/// Convert ptr.ptr_add operation
+/// Translate ptr.ptr_add operation to LLVM IR.
 static LogicalResult
-convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
-                LLVM::ModuleTranslation &moduleTranslation) {
+translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
   llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
 
@@ -83,18 +83,19 @@ convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.load operation
-static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
-                                   LLVM::ModuleTranslation &moduleTranslation) {
+/// Translate ptr.load operation to LLVM IR.
+static LogicalResult
+translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
   if (!ptr)
     return loadOp.emitError("Failed to lookup pointer operand");
 
-  // Convert result type to LLVM type
+  // Translate result type to LLVM type
   llvm::Type *resultType =
       moduleTranslation.convertType(loadOp.getValue().getType());
   if (!resultType)
-    return loadOp.emitError("Failed to convert result type");
+    return loadOp.emitError("Failed to translate result type");
 
   // Create the load instruction.
   llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
@@ -102,7 +103,7 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
       resultType, ptr, alignment, loadOp.getVolatile_());
 
   // Set op flags and metadata.
-  loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
+  loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering()));
   // Set sync scope if specified
   if (loadOp.getSyncscope().has_value()) {
     llvm::LLVMContext &ctx = builder.getContext();
@@ -135,10 +136,10 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.store operation
+/// Translate ptr.store operation to LLVM IR.
 static LogicalResult
-convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
-               LLVM::ModuleTranslation &moduleTranslation) {
+translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
   llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
 
@@ -151,7 +152,7 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
       builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
 
   // Set op flags and metadata.
-  storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
+  storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering()));
   // Set sync scope if specified
   if (storeOp.getSyncscope().has_value()) {
     llvm::LLVMContext &ctx = builder.getContext();
@@ -178,21 +179,21 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.type_offset operation
+/// Translate ptr.type_offset operation to LLVM IR.
 static LogicalResult
-convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
-                    LLVM::ModuleTranslation &moduleTranslation) {
-  // Convert the element type to LLVM type
+translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  // Translate the element type to LLVM type
   llvm::Type *elementType =
       moduleTranslation.convertType(typeOffsetOp.getElementType());
   if (!elementType)
-    return typeOffsetOp.emitError("Failed to convert the element type");
+    return typeOffsetOp.emitError("Failed to translate the element type");
 
-  // Convert result type
+  // Translate result type
   llvm::Type *resultType =
       moduleTranslation.convertType(typeOffsetOp.getResult().getType());
   if (!resultType)
-    return typeOffsetOp.emitError("Failed to convert the result type");
+    return typeOffsetOp.emitError("Failed to translate the result type");
 
   // Use GEP with null pointer to compute type size/offset.
   llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
@@ -204,10 +205,10 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.gather operation
+/// Translate ptr.gather operation to LLVM IR.
 static LogicalResult
-convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
-                LLVM::ModuleTranslation &moduleTranslation) {
+translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
   llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
   llvm::Value *passthrough =
@@ -216,11 +217,11 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
   if (!ptrs || !mask || !passthrough)
     return gatherOp.emitError("Failed to lookup operands");
 
-  // Convert result type to LLVM type.
+  // Translate result type to LLVM type.
   llvm::Type *resultType =
       moduleTranslation.convertType(gatherOp.getResult().getType());
   if (!resultType)
-    return gatherOp.emitError("Failed to convert result type");
+    return gatherOp.emitError("Failed to translate result type");
 
   // Get the alignment.
   llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
@@ -233,10 +234,10 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.masked_load operation
+/// Translate ptr.masked_load operation to LLVM IR.
 static LogicalResult
-convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
-                    LLVM::ModuleTranslation &moduleTranslation) {
+translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
+                      LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
   llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
   llvm::Value *passthrough =
@@ -245,11 +246,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
   if (!ptr || !mask || !passthrough)
     return maskedLoadOp.emitError("Failed to lookup operands");
 
-  // Convert result type to LLVM type.
+  // Translate result type to LLVM type.
   llvm::Type *resultType =
       moduleTranslation.convertType(maskedLoadOp.getResult().getType());
   if (!resultType)
-    return maskedLoadOp.emitError("Failed to convert result type");
+    return maskedLoadOp.emitError("Failed to translate result type");
 
   // Get the alignment.
   llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
@@ -262,10 +263,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.masked_store operation
+/// Translate ptr.masked_store operation to LLVM IR.
 static LogicalResult
-convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
-                     LLVM::ModuleTranslation &moduleTranslation) {
+translateMaskedStoreOp(MaskedStoreOp maskedStoreOp,
+                       llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
   llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
   llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
@@ -281,10 +283,10 @@ convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.scatter operation
+/// Translate ptr.scatter operation to LLVM IR.
 static LogicalResult
-convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
-                 LLVM::ModuleTranslation &moduleTranslation) {
+translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
   llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
   llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
@@ -300,11 +302,11 @@ convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Convert ptr.constant operation
+/// Translate ptr.constant operation to LLVM IR.
 static LogicalResult
-convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
-                  LLVM::ModuleTranslation &moduleTranslation) {
-  // Convert result type to LLVM type
+translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation) {
+  // Translate result type to LLVM type
   llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
       moduleTranslation.convertType(constantOp.getResult().getType()));
   if (!resultType)
@@ -317,7 +319,7 @@ convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
     // Create a null pointer constant
     result = llvm::ConstantPointerNull::get(resultType);
   } else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
-    // Create an integer constant and convert it to pointer
+    // Create an integer constant and translate it to pointer
     llvm::APInt addressValue = addressAttr.getValue();
 
     // Determine the integer type width based on the target's pointer size
@@ -337,7 +339,7 @@ convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
                          : addressValue.trunc(pointerSizeInBits);
     }
 
-    // Create integer constant and convert to pointer
+    // Create integer constant and translate to pointer
     llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
     llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
     result = builder.CreateIntToPtr(intValue, resultType);
@@ -349,7 +351,7 @@ convertConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Implementation of the dialect interface that converts operations belonging
+/// Implementation of the dialect interface that translates operations belonging
 /// to the `ptr` dialect to LLVM IR.
 class PtrDialectLLVMIRTranslationInterface
     : public LLVMTranslationDialectInterface {
@@ -364,32 +366,34 @@ class PtrDialectLLVMIRTranslationInterface
 
     return llvm::TypeSwitch<Operation *, LogicalResult>(op)
         .Case([&](ConstantOp constantOp) {
-          return convertConstantOp(constantOp, builder, moduleTranslation);
+          return translateConstantOp(constantOp, builder, moduleTranslation);
         })
         .Case([&](PtrAddOp ptrAddOp) {
-          return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
+          return translatePtrAddOp(ptrAddOp, builder, moduleTranslation);
         })
         .Case([&](LoadOp loadOp) {
-          return convertLoadOp(loadOp, builder, moduleTranslation);
+          return translateLoadOp(loadOp, builder, moduleTranslation);
         })
         .Case([&](StoreOp storeOp) {
-          return convertStoreOp(storeOp, builder, moduleTranslation);
+          return translateStoreOp(storeOp, builder, moduleTranslation);
         })
         .Case([&](TypeOffsetOp typeOffsetOp) {
-          return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
+          return translateTypeOffsetOp(typeOffsetOp, builder,
+                                       moduleTranslation);
         })
         .Case<GatherOp>([&](GatherOp gatherOp) {
-          return convertGatherOp(gatherOp, builder, moduleTranslation);
+          return translateGatherOp(gatherOp, builder, moduleTranslation);
         })
         .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
-          return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation);
+          return translateMaskedLoadOp(maskedLoadOp, builder,
+                                       moduleTranslation);
         })
         .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
-          return convertMaskedStoreOp(maskedStoreOp, builder,
-                                      moduleTranslation);
+          return translateMaskedStoreOp(maskedStoreOp, builder,
+                                        moduleTranslation);
         })
         .Case<ScatterOp>([&](ScatterOp scatterOp) {
-          return convertScatterOp(scatterOp, builder, moduleTranslation);
+          return translateScatterOp(scatterOp, builder, moduleTranslation);
         })
         .Default([&](Operation *op) {
           return op->emitError("Translation for operation '")



More information about the Mlir-commits mailing list