[Mlir-commits] [mlir] [mlir][ptr] Add int_to_ptr && ptr_to_int ops (PR #190527)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 5 07:29:41 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: justLi (enjustli)
<details>
<summary>Changes</summary>
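1. Add the `ptr.int_to_ptr` op.
2. Add the `ptr.ptr_to_int` op.
3. Move the common `ptr` props and type constraints (`AlignmentProp`, `Ptr_PtrLikeType`, `Ptr_IntLikeType`, ...) from PtrOps.td to PtrDialect.td.
4. Allow lowering `ptr.from_ptr` to LLVM when no metadata operand is given.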
---
Full diff: https://github.com/llvm/llvm-project/pull/190527.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td (+39)
- (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td (+47-39)
- (modified) mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp (+26-17)
- (modified) mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir (+57)
- (modified) mlir/test/Dialect/Ptr/ops.mlir (+11)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index c98df5775195a..009c979bfc9bf 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -127,6 +127,45 @@ def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
let genVerifyDecl = 1;
}
+//===----------------------------------------------------------------------===//
+// Common props
+//===----------------------------------------------------------------------===//
+
+def AlignmentProp : OptionalProp<I64Prop>;
+
+//===----------------------------------------------------------------------===//
+// Common types
+//===----------------------------------------------------------------------===//
+
+// A shaped value type with value semantics and rank.
+class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
+ ShapedContainerType<allowedTypes,
+ /*containerPred=*/And<[HasValueSemanticsPred] # preds>,
+ /*descr=*/[{A shaped type with value semantics and rank.}],
+ /*cppType=*/"::mlir::ShapedType">;
+
+// A ptr-like type, either scalar or shaped type with value semantics.
+def Ptr_PtrLikeType :
+ AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
+
+// An int-like type, either scalar or shaped type with value semantics.
+def Ptr_IntLikeType : AnyTypeOf<[
+ Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
+ AnySignlessIntegerOrIndex
+]>;
+
+// A shaped value type of rank 1 of any element type.
+def Ptr_Any1DType :
+ Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `i1` element type.
+def Ptr_Mask1DType :
+ Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `!ptr.ptr` element type.
+def Ptr_Ptr1DType :
+ Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
+
//===----------------------------------------------------------------------===//
// Base address operation definition.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index e14f64330c294..8af224247e6e3 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -18,45 +18,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
-//===----------------------------------------------------------------------===//
-// Common props
-//===----------------------------------------------------------------------===//
-
-def AlignmentProp : OptionalProp<I64Prop>;
-
-//===----------------------------------------------------------------------===//
-// Common types
-//===----------------------------------------------------------------------===//
-
-// A shaped value type with value semantics and rank.
-class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
- ShapedContainerType<allowedTypes,
- /*containerPred=*/And<[HasValueSemanticsPred] # preds>,
- /*descr=*/[{A shaped type with value semantics and rank.}],
- /*cppType=*/"::mlir::ShapedType">;
-
-// A ptr-like type, either scalar or shaped type with value semantics.
-def Ptr_PtrLikeType :
- AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
-
-// An int-like type, either scalar or shaped type with value semantics.
-def Ptr_IntLikeType :AnyTypeOf<[
- Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
- AnySignlessIntegerOrIndex
-]>;
-
-// A shaped value type of rank 1 of any element type.
-def Ptr_Any1DType :
- Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
-
-// A shaped value type of rank 1 of `i1` element type.
-def Ptr_Mask1DType :
- Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
-
-// A shaped value type of rank 1 of `i1` element type.
-def Ptr_Ptr1DType :
- Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
-
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -645,4 +606,51 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
}];
}
+//===----------------------------------------------------------------------===//
+// IntToPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_IntToPtrOp : Pointer_Op<"int_to_ptr", [Pure, SameOperandsAndResultShape]> {
+ let summary = "Integer to pointer cast operation";
+ let description = [{
+ The `int_to_ptr` operation casts an integer or index value to a pointer; a shaped operand is cast element-wise and must have the same shape as the result.
+
+ Example:
+ ```mlir
+ // Cast an integer to a pointer in the generic address space
+ %ptr = ptr.int_to_ptr %int : i64 -> !ptr.ptr<#ptr.generic_space>
+ // Cast a tensor of integers to a tensor of pointers in the generic address space
+ %ptr_tensor = ptr.int_to_ptr %int_tensor : tensor<10xi64> -> tensor<10x!ptr.ptr<#ptr.generic_space>>
+ // Cast a vector of integers to a vector of pointers in the generic address space
+ %ptr_vec = ptr.int_to_ptr %int_vec : vector<10xi64> -> vector<10x!ptr.ptr<#ptr.generic_space>>
+ ```
+ }];
+ let arguments = (ins Ptr_IntLikeType:$arg);
+ let results = (outs Ptr_PtrLikeType:$res);
+ let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($res)";
+}
+
+//===----------------------------------------------------------------------===//
+// PtrToIntOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrToIntOp : Pointer_Op<"ptr_to_int", [Pure, SameOperandsAndResultShape]> {
+ let summary = "Pointer to integer cast operation";
+ let description = [{
+ The `ptr_to_int` operation casts a pointer value to an integer or index; a shaped operand is cast element-wise and must have the same shape as the result.
+
+ Example:
+ ```mlir
+ // Cast a pointer in the generic address space to an integer
+ %int = ptr.ptr_to_int %ptr : !ptr.ptr<#ptr.generic_space> -> i64
+ // Cast a tensor of pointers in the generic address space to a tensor of integers
+ %int_tensor = ptr.ptr_to_int %ptr_tensor : tensor<10x!ptr.ptr<#ptr.generic_space>> -> tensor<10xi64>
+ // Cast a vector of pointers in the generic address space to a vector of integers
+ %int_vec = ptr.ptr_to_int %ptr_vec : vector<10x!ptr.ptr<#ptr.generic_space>> -> vector<10xi64>
+ ```
+ }];
+ let arguments = (ins Ptr_PtrLikeType:$arg);
+ let results = (outs Ptr_IntLikeType:$res);
+ let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($res)";
+}
#endif // PTR_OPS
diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
index 01199155ade39..587a177cd9bf7 100644
--- a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
+++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -143,10 +144,19 @@ LogicalResult FromPtrOpConversion::matchAndRewrite(
if (!mTy)
return rewriter.notifyMatchFailure(op, "Expected memref result type");
- if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
- return rewriter.notifyMatchFailure(
- op, "Can convert only memrefs with metadata");
- }
+  bool hasMetadata = static_cast<bool>(op.getMetadata());
+  // Without metadata, a dynamic offset, size or stride cannot be recovered and
+  // would be left as poison in the descriptor, so reject such memrefs before
+  // any IR is created (patterns must not modify IR when they fail).
+  if (!hasMetadata) {
+    SmallVector<int64_t> staticStrides;
+    int64_t staticOffset = 0;
+    if (!mTy.hasStaticShape() ||
+        failed(mTy.getStridesAndOffset(staticStrides, staticOffset)) ||
+        ShapedType::isDynamic(staticOffset) ||
+        llvm::any_of(staticStrides, ShapedType::isDynamic))
+      return rewriter.notifyMatchFailure(
+          op, "Expected metadata for dynamic offset, sizes or strides");
+  }
// Convert the result type
Type descriptorTy = getTypeConverter()->convertType(mTy);
@@ -167,9 +177,11 @@ LogicalResult FromPtrOpConversion::matchAndRewrite(
auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
// Set the allocated and aligned pointers.
- desc.setAllocatedPtr(
- rewriter, loc,
- LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getMetadata(), 0));
+ desc.setAllocatedPtr(rewriter, loc,
+ hasMetadata
+ ? LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getMetadata(), 0)
+ : adaptor.getPtr());
desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
// Extract metadata from the passed struct.
@@ -177,9 +189,11 @@ LogicalResult FromPtrOpConversion::matchAndRewrite(
// Set dynamic offset if needed.
if (offset == ShapedType::kDynamic) {
- Value offsetValue = LLVM::ExtractValueOp::create(
- rewriter, loc, adaptor.getMetadata(), fieldIdx++);
- desc.setOffset(rewriter, loc, offsetValue);
+ if (hasMetadata) {
+ Value offsetValue = LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setOffset(rewriter, loc, offsetValue);
+ }
} else {
desc.setConstantOffset(rewriter, loc, offset);
}
@@ -187,9 +201,11 @@ LogicalResult FromPtrOpConversion::matchAndRewrite(
// Set dynamic sizes if needed.
for (auto [i, dim] : llvm::enumerate(shape)) {
if (dim == ShapedType::kDynamic) {
- Value sizeValue = LLVM::ExtractValueOp::create(
- rewriter, loc, adaptor.getMetadata(), fieldIdx++);
- desc.setSize(rewriter, loc, i, sizeValue);
+ if (hasMetadata) {
+ Value sizeValue = LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setSize(rewriter, loc, i, sizeValue);
+ }
} else {
desc.setConstantSize(rewriter, loc, i, dim);
}
@@ -198,9 +214,11 @@ LogicalResult FromPtrOpConversion::matchAndRewrite(
// Set dynamic strides if needed.
for (auto [i, stride] : llvm::enumerate(strides)) {
if (stride == ShapedType::kDynamic) {
- Value strideValue = LLVM::ExtractValueOp::create(
- rewriter, loc, adaptor.getMetadata(), fieldIdx++);
- desc.setStride(rewriter, loc, i, strideValue);
+ if (hasMetadata) {
+ Value strideValue = LLVM::ExtractValueOp::create(
+ rewriter, loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setStride(rewriter, loc, i, strideValue);
+ }
} else {
desc.setConstantStride(rewriter, loc, i, stride);
}
@@ -433,7 +451,10 @@ void mlir::ptr::populatePtrToLLVMConversionPatterns(
// Add conversion patterns.
patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
- ToPtrOpConversion, TypeOffsetOpConversion>(converter);
+ ToPtrOpConversion, TypeOffsetOpConversion,
+ VectorConvertToLLVMPattern<ptr::IntToPtrOp, LLVM::IntToPtrOp>,
+ VectorConvertToLLVMPattern<ptr::PtrToIntOp, LLVM::PtrToIntOp>>(
+ converter);
}
void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
index 5128fd8ccb265..c6f453b8c85b2 100644
--- a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
+++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
@@ -316,3 +316,60 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
%3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
return %3 : !ptr.ptr<#ptr.generic_space>
}
+
+// Tests from_ptr without metadata on a statically-shaped memref: %arg0 is used as both the allocated and the aligned pointer, and offset/sizes/strides are constants.
+// CHECK-LABEL: llvm.func @test_from_ptr_pure(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[MLIR_1]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_2:.*]] = llvm.mlir.constant(10 : index) : i64
+// CHECK: %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[MLIR_2]], %[[INSERTVALUE_2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_3:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK: %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_3]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_4:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK: %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_5:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[MLIR_5]], %[[INSERTVALUE_5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.return %[[INSERTVALUE_6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: }
+func.func @test_from_ptr_pure(%arg0: !ptr.ptr<#ptr.generic_space>) -> memref<10x20xf32, #ptr.generic_space> {
+ %0 = ptr.from_ptr %arg0 : <#ptr.generic_space> -> memref<10x20xf32, #ptr.generic_space>
+ return %0 : memref<10x20xf32, #ptr.generic_space>
+}
+
+// Tests lowering ptr.int_to_ptr to llvm.inttoptr for scalar and vector operands.
+// CHECK-LABEL: llvm.func @test_int_to_ptr(
+// CHECK-SAME: %[[ARG0:.*]]: i64,
+// CHECK-SAME: %[[ARG1:.*]]: vector<10xi64>) -> !llvm.struct<(ptr, vector<10x!llvm.ptr>)> {
+// CHECK: %[[INTTOPTR_0:.*]] = llvm.inttoptr %[[ARG0]] : i64 to !llvm.ptr
+// CHECK: %[[INTTOPTR_1:.*]] = llvm.inttoptr %[[ARG1]] : vector<10xi64> to vector<10x!llvm.ptr>
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, vector<10x!llvm.ptr>)>
+// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[INTTOPTR_0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, vector<10x!llvm.ptr>)>
+// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[INTTOPTR_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, vector<10x!llvm.ptr>)>
+// CHECK: llvm.return %[[INSERTVALUE_1]] : !llvm.struct<(ptr, vector<10x!llvm.ptr>)>
+// CHECK: }
+func.func @test_int_to_ptr(%arg0: i64, %arg1: vector<10xi64>) -> (!ptr.ptr<#ptr.generic_space>, vector<10x!ptr.ptr<#ptr.generic_space>>) {
+ %0 = ptr.int_to_ptr %arg0 : i64 -> !ptr.ptr<#ptr.generic_space>
+ %1 = ptr.int_to_ptr %arg1 : vector<10xi64> -> vector<10x!ptr.ptr<#ptr.generic_space>>
+ return %0, %1 : !ptr.ptr<#ptr.generic_space>, vector<10x!ptr.ptr<#ptr.generic_space>>
+}
+
+// Tests lowering ptr.ptr_to_int to llvm.ptrtoint for scalar and vector operands.
+// CHECK-LABEL: llvm.func @test_ptr_to_int(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:.*]]: vector<10x!llvm.ptr>) -> !llvm.struct<(i64, vector<10xi64>)> {
+// CHECK: %[[PTRTOINT_0:.*]] = llvm.ptrtoint %[[ARG0]] : !llvm.ptr to i64
+// CHECK: %[[PTRTOINT_1:.*]] = llvm.ptrtoint %[[ARG1]] : vector<10x!llvm.ptr> to vector<10xi64>
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(i64, vector<10xi64>)>
+// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[PTRTOINT_0]], %[[MLIR_0]][0] : !llvm.struct<(i64, vector<10xi64>)>
+// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[PTRTOINT_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(i64, vector<10xi64>)>
+// CHECK: llvm.return %[[INSERTVALUE_1]] : !llvm.struct<(i64, vector<10xi64>)>
+// CHECK: }
+func.func @test_ptr_to_int(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: vector<10x!ptr.ptr<#ptr.generic_space>>) -> (i64, vector<10xi64>) {
+ %0 = ptr.ptr_to_int %arg0 : !ptr.ptr<#ptr.generic_space> -> i64
+ %1 = ptr.ptr_to_int %arg1 : vector<10x!ptr.ptr<#ptr.generic_space>> -> vector<10xi64>
+ return %0, %1 : i64, vector<10xi64>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 0a906ad559e21..136a5327a00e7 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -20,6 +20,17 @@ func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.ge
return %res : memref<f32, #ptr.generic_space>
}
+/// Check int_to_ptr and ptr_to_int ops assembly round-trips for scalar, vector, and tensor operands.
+func.func @ptr_int_ops(%arg0: i64, %arg1: vector<10xi64>, %arg2: tensor<10xi64>) -> (i64, vector<10xi64>, tensor<10xi64>) {
+ %ptr = ptr.int_to_ptr %arg0 : i64 -> !ptr.ptr<#ptr.generic_space>
+ %res = ptr.ptr_to_int %ptr : !ptr.ptr<#ptr.generic_space> -> i64
+ %ptr_vec = ptr.int_to_ptr %arg1 : vector<10xi64> -> vector<10x!ptr.ptr<#ptr.generic_space>>
+ %res_vec = ptr.ptr_to_int %ptr_vec : vector<10x!ptr.ptr<#ptr.generic_space>> -> vector<10xi64>
+ %ptr_tensor = ptr.int_to_ptr %arg2 : tensor<10xi64> -> tensor<10x!ptr.ptr<#ptr.generic_space>>
+ %res_tensor = ptr.ptr_to_int %ptr_tensor : tensor<10x!ptr.ptr<#ptr.generic_space>> -> tensor<10xi64>
+ return %res, %res_vec, %res_tensor : i64, vector<10xi64>, tensor<10xi64>
+}
+
/// Check load ops assembly.
func.func @load_ops(%arg0: !ptr.ptr<#ptr.generic_space>) -> (f32, f32, f32, f32, f32, i64, i32) {
%0 = ptr.load %arg0 : !ptr.ptr<#ptr.generic_space> -> f32
</details>
https://github.com/llvm/llvm-project/pull/190527
More information about the Mlir-commits
mailing list