[Mlir-commits] [flang] [mlir] [mlir][acc][flang] Add genCast API to PointerLikeType (PR #192720)
Razvan Lupusoru
llvmlistbot at llvm.org
Fri Apr 17 12:25:57 PDT 2026
https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/192720
Introduces a new PointerLikeType API, genCast, for generating IR that performs type conversions. It is implemented for FIR reference types, memref, and LLVM ptr.
>From 7f65ffc4f37a5afdc66d3ec6edb2ddf4965cd65d Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Fri, 17 Apr 2026 12:21:11 -0700
Subject: [PATCH] [mlir][acc][flang] Add genCast API to PointerLikeType
Introduces a new PointerLikeType API, genCast, for generating IR
that performs type conversions. It is implemented for FIR reference
types, memref, and LLVM ptr.
---
.../Support/FIROpenACCTypeInterfaces.h | 4 +
.../Support/FIROpenACCTypeInterfaces.cpp | 30 ++++
.../OpenACC/pointer-like-interface-cast.mlir | 68 +++++++
.../Dialect/OpenACC/OpenACCTypeInterfaces.td | 28 +++
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 73 ++++++++
.../OpenACC/pointer-like-interface-cast.mlir | 36 ++++
.../OpenACC/TestPointerLikeTypeInterface.cpp | 62 ++++++-
mlir/unittests/Dialect/OpenACC/CMakeLists.txt | 1 +
.../OpenACC/OpenACCTypeInterfacesTest.cpp | 167 +++++++++++++++++-
9 files changed, 467 insertions(+), 2 deletions(-)
create mode 100644 flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir
create mode 100644 mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
index 01a1e19afd74b..3798d668ed547 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
@@ -53,6 +53,10 @@ struct OpenACCPointerLikeModel
mlir::Location loc, mlir::Value valueToStore,
mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+ /// Implements `PointerLikeType::genCast` for FIR pointer-like types.
+ /// Returns `value` unchanged when it already has `resultType`, emits a
+ /// `fir.convert` when `fir::ConvertOp::canBeConverted` accepts the pair,
+ /// and returns an empty value otherwise (no IR is created in that case).
+ mlir::Value genCast(mlir::Type pointer, mlir::OpBuilder &builder,
+ mlir::Location loc, mlir::Value value,
+ mlir::Type resultType) const;
+
bool isDeviceData(mlir::Type pointer, mlir::Value var) const;
};
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index 8749a816e1b00..ba72e46438a82 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -1603,6 +1603,36 @@ template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genStore(
mlir::Value valueToStore,
mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+/// FIR implementation of `PointerLikeType::genCast`: a single `fir.convert`
+/// is emitted whenever `fir::ConvertOp::canBeConverted` accepts the
+/// (source, destination) pair. An identity request hands back `value` itself
+/// and an unsupported pair yields a null value without creating any IR.
+template <typename Ty>
+mlir::Value OpenACCPointerLikeModel<Ty>::genCast(
+    mlir::Type /*pointer*/, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const {
+  const mlir::Type sourceType = value.getType();
+  // Nothing to emit when the types already agree.
+  if (sourceType == resultType)
+    return value;
+  // Decline before touching the builder if fir.convert cannot express it.
+  if (!fir::ConvertOp::canBeConverted(sourceType, resultType))
+    return {};
+  return fir::ConvertOp::create(builder, loc, resultType, value);
+}
+
+// Explicit instantiations of genCast for the FIR pointer-like types.
+template mlir::Value OpenACCPointerLikeModel<fir::ReferenceType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::PointerType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::HeapType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
/// Check CUDA attributes on a function argument.
static bool hasCUDADeviceAttrOnFuncArg(mlir::BlockArgument blockArg) {
auto *owner = blockArg.getOwner();
diff --git a/flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir b/flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir
new file mode 100644
index 0000000000000..c0ebd8ae1fbc3
--- /dev/null
+++ b/flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir
@@ -0,0 +1,68 @@
+// RUN: fir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=cast}))" 2>&1 | FileCheck %s
+
+// Exercises PointerLikeType::genCast on FIR pointer-like types: every op
+// tagged with `test.cast` is cast to the type in its `cast_dest` attribute.
+
+func.func @test_fir_ref_to_memref_scalar() {
+  %0 = fir.alloca f32 {test.cast, cast_dest = memref<f32>}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.alloca f32{{.*}}
+  // CHECK: Cast result type: memref<f32>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (!fir.ref<f32>) -> memref<f32>
+  return
+}
+
+// -----
+
+func.func @test_memref_to_fir_ref_scalar() {
+  %0 = memref.alloca() {test.cast, cast_dest = !fir.ref<f32>} : memref<f32>
+  // CHECK: Successfully generated cast for operation: %{{.*}} = memref.alloca(){{.*}}
+  // CHECK: Cast result type: !fir.ref<f32>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (memref<f32>) -> !fir.ref<f32>
+  return
+}
+
+// -----
+
+func.func @test_fir_ref_identity() {
+  %0 = fir.alloca i32 {test.cast, cast_dest = !fir.ref<i32>}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.alloca i32{{.*}}
+  // CHECK: Cast result type: !fir.ref<i32>
+  // An identity cast must reuse the input value and emit no new operations.
+  // CHECK-NOT: Generated:
+  return
+}
+
+// -----
+
+func.func @test_i64_to_fir_ref() {
+  %0 = arith.constant {test.cast, cast_dest = !fir.ref<i8>} 0 : i64
+  // CHECK: Successfully generated cast for operation: %{{.*}} = arith.constant{{.*}}
+  // CHECK: Cast result type: !fir.ref<i8>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (i64) -> !fir.ref<i8>
+  return
+}
+
+// -----
+
+func.func @test_index_to_fir_ptr() {
+  %0 = arith.constant {test.cast, cast_dest = !fir.ptr<i8>} 0 : index
+  // CHECK: Successfully generated cast for operation: %{{.*}} = arith.constant{{.*}}
+  // CHECK: Cast result type: !fir.ptr<i8>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (index) -> !fir.ptr<i8>
+  return
+}
+
+// -----
+
+func.func @test_fir_heap_to_i64() {
+  %0 = fir.zero_bits !fir.heap<i8> {test.cast, cast_dest = i64}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.zero_bits{{.*}}
+  // CHECK: Cast result type: i64
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (!fir.heap<i8>) -> i64
+  return
+}
+
+// -----
+
+func.func @test_fir_llvm_ptr_to_index() {
+  %0 = fir.zero_bits !fir.llvm_ptr<i8> {test.cast, cast_dest = index}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.zero_bits{{.*}}
+  // CHECK: Cast result type: index
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (!fir.llvm_ptr<i8>) -> index
+  return
+}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 3bd4e5c679659..2753056bf8a7a 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -220,6 +220,34 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
return false;
}]
>,
+    InterfaceMethod<
+      /*description=*/[{
+        Generates a cast from `value` to `resultType` when the implementing
+        pointer-like type can emit a lowering for that conversion.
+
+        This is intentionally a single operation (rather than separate
+        "cast to" / "cast from" hooks): the source type is always
+        `value.getType()` and the destination is always `resultType`.
+
+        Call sites typically dispatch on whichever endpoint is a
+        `PointerLikeType`: the source value's type when casting *from* a
+        pointer-like representation, or `resultType` when casting *to* one.
+
+        Returns the cast result on success, or an empty value if the cast is
+        unsupported (callers may then try the other endpoint's interface or
+        apply their own fallback). An implementation that itself forwards to
+        the other endpoint's interface must ensure that endpoint cannot
+        forward straight back; otherwise the two implementations recurse
+        without bound.
+      }],
+      /*retTy=*/"::mlir::Value",
+      /*methodName=*/"genCast",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                    "::mlir::Location":$loc,
+                    "::mlir::Value":$value,
+                    "::mlir::Type":$resultType),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return {};
+      }]
+    >,
InterfaceMethod<
/*description=*/[{
Returns true if the pointer points to device data.
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bd26c70eb1831..6f49a66167f67 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
@@ -246,6 +247,38 @@ struct MemRefPointerLikeModel
return true;
}
+  /// Casts `value` to `resultType`.
+  ///
+  /// memref -> memref uses `memref.cast` (shape/layout changes) or
+  /// `memref.memory_space_cast` (memory-space changes), whichever is
+  /// cast-compatible. Any other pairing is forwarded to the non-memref
+  /// endpoint's `PointerLikeType` implementation (which may be an
+  /// out-of-tree reference type we cannot generate here). Returns `value`
+  /// for an identity cast and an empty value when no cast applies.
+  Value genCast(Type pointer, OpBuilder &builder, Location loc, Value value,
+                Type resultType) const {
+    (void)pointer;
+    Type srcType = value.getType();
+    if (srcType == resultType)
+      return value;
+
+    if (isa<BaseMemRefType>(srcType) && isa<BaseMemRefType>(resultType)) {
+      if (memref::CastOp::areCastCompatible(TypeRange(srcType),
+                                            TypeRange(resultType)))
+        return memref::CastOp::create(builder, loc, resultType, value);
+      if (memref::MemorySpaceCastOp::areCastCompatible(TypeRange(srcType),
+                                                       TypeRange(resultType)))
+        return memref::MemorySpaceCastOp::create(builder, loc, resultType,
+                                                 value);
+    }
+
+    // Forwarding may come straight back here when the other implementation
+    // also forwards (the LLVM pointer model does, e.g. for memref <->
+    // !llvm.ptr), which previously recursed until the stack overflowed.
+    // Forward at most once per thread and decline on re-entry. The direct
+    // memref casts above still run on re-entry.
+    static thread_local bool forwarding = false;
+    if (forwarding)
+      return {};
+    forwarding = true;
+
+    Value result;
+    if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType);
+        resPtrLike && !isa<BaseMemRefType>(resPtrLike))
+      result = resPtrLike.genCast(builder, loc, value, resultType);
+    if (!result) {
+      if (auto valPtrLike = dyn_cast<PointerLikeType>(srcType);
+          valPtrLike && !isa<BaseMemRefType>(valPtrLike))
+        result = valPtrLike.genCast(builder, loc, value, resultType);
+    }
+
+    forwarding = false;
+    return result;
+  }
+
bool isDeviceData(Type pointer, Value var) const {
auto memrefTy = cast<T>(pointer);
Attribute memSpace = memrefTy.getMemorySpace();
@@ -273,6 +306,46 @@ struct LLVMPointerPointerLikeModel
LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
return true;
}
+
+  /// Casts `value` to `resultType`:
+  ///  - !llvm.ptr -> !llvm.ptr: `llvm.addrspacecast` when the address spaces
+  ///    differ, otherwise `value` itself;
+  ///  - !llvm.ptr -> integer: `llvm.ptrtoint`;
+  ///  - !llvm.ptr -> index: `llvm.ptrtoint` to i64, then `arith.index_castui`;
+  ///  - integer -> !llvm.ptr: `llvm.inttoptr`;
+  ///  - index -> !llvm.ptr: `arith.index_castui` to i64, then `llvm.inttoptr`.
+  /// Index values are routed through i64 (assumes a 64-bit index; confirm
+  /// for targets with a narrower index width).
+  /// Any other pairing is forwarded to the other endpoint's `PointerLikeType`
+  /// implementation. Per the interface contract, returns an empty value when
+  /// nothing applies, so callers can choose their own fallback instead of
+  /// receiving an unresolvable `unrealized_conversion_cast`.
+  Value genCast(Type pointer, OpBuilder &builder, Location loc, Value value,
+                Type resultType) const {
+    (void)pointer;
+    Type srcType = value.getType();
+    if (srcType == resultType)
+      return value;
+
+    auto srcPtrTy = dyn_cast<LLVM::LLVMPointerType>(srcType);
+    auto dstPtrTy = dyn_cast<LLVM::LLVMPointerType>(resultType);
+    if (srcPtrTy && dstPtrTy) {
+      if (srcPtrTy.getAddressSpace() != dstPtrTy.getAddressSpace())
+        return LLVM::AddrSpaceCastOp::create(builder, loc, resultType, value);
+      return value;
+    }
+
+    if (srcPtrTy) {
+      if (isa<IntegerType>(resultType))
+        return LLVM::PtrToIntOp::create(builder, loc, resultType, value);
+      if (isa<IndexType>(resultType)) {
+        Value asInt =
+            LLVM::PtrToIntOp::create(builder, loc, builder.getI64Type(), value);
+        return arith::IndexCastUIOp::create(builder, loc, resultType, asInt);
+      }
+    }
+
+    if (dstPtrTy) {
+      Value intVal = value;
+      if (isa<IndexType>(srcType))
+        intVal = arith::IndexCastUIOp::create(builder, loc, builder.getI64Type(),
+                                              value);
+      if (isa<IntegerType>(intVal.getType()))
+        return LLVM::IntToPtrOp::create(builder, loc, resultType, intVal);
+    }
+
+    // Forwarding may come straight back here when the other implementation
+    // also forwards (the memref model does, e.g. for memref <-> !llvm.ptr),
+    // which previously recursed until the stack overflowed. Forward at most
+    // once per thread and decline on re-entry.
+    static thread_local bool forwarding = false;
+    if (forwarding)
+      return {};
+    forwarding = true;
+
+    Value result;
+    if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType);
+        resPtrLike && !isa<LLVM::LLVMPointerType>(resPtrLike))
+      result = resPtrLike.genCast(builder, loc, value, resultType);
+    if (!result) {
+      if (auto valPtrLike = dyn_cast<PointerLikeType>(srcType);
+          valPtrLike && !isa<LLVM::LLVMPointerType>(valPtrLike))
+        result = valPtrLike.genCast(builder, loc, value, resultType);
+    }
+
+    forwarding = false;
+    return result;
+  }
};
struct MemrefAddressOfGlobalModel
diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir
new file mode 100644
index 0000000000000..2a0939a3020db
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=cast}))" 2>&1 | FileCheck %s
+
+// Exercises PointerLikeType::genCast on memref types: every op tagged with
+// `test.cast` is cast to the type in its `cast_dest` attribute.
+
+func.func @test_memref_cast_identity() {
+  %0 = memref.alloca() {test.cast, cast_dest = memref<f32>} : memref<f32>
+  // CHECK: Successfully generated cast for operation: %[[V:.*]] = memref.alloca(){{.*}} : memref<f32>
+  // CHECK: Cast result type: memref<f32>
+  // An identity cast must reuse the input value and emit no new operations.
+  // CHECK-NOT: Generated:
+  return
+}
+
+// -----
+
+func.func @test_memref_cast_static_to_dynamic() {
+  %0 = memref.alloca() {test.cast, cast_dest = memref<?xf32>} : memref<4xf32>
+  // CHECK: Successfully generated cast for operation: %[[V:.*]] = memref.alloca(){{.*}} : memref<4xf32>
+  // CHECK: Cast result type: memref<?xf32>
+  // CHECK: Generated: %{{.*}} = memref.cast %[[V]] : memref<4xf32> to memref<?xf32>
+  return
+}
+
+// -----
+
+func.func @test_memref_memory_space_cast() {
+  %0 = memref.alloca() {test.cast, cast_dest = memref<4xf32, 1>} : memref<4xf32>
+  // CHECK: Successfully generated cast for operation: %[[V:.*]] = memref.alloca(){{.*}} : memref<4xf32>
+  // CHECK: Cast result type: memref<4xf32, 1>
+  // CHECK: Generated: %{{.*}} = memref.memory_space_cast %[[V]] : memref<4xf32> to memref<4xf32, 1>
+  return
+}
+
+// -----
+
+func.func @test_memref_cast_incompatible() {
+  %0 = memref.alloca() {test.cast, cast_dest = tensor<4xf32>} : memref<4xf32>
+  // CHECK: Failed to generate cast for operation: %{{.*}} = memref.alloca(){{.*}} : memref<4xf32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
index 3ff0dc85b2152..e45fd104e7331 100644
--- a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
+++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
@@ -46,7 +46,8 @@ struct TestPointerLikeTypeInterfacePass
Pass::Option<std::string> testMode{
*this, "test-mode",
- llvm::cl::desc("Test mode: walk, alloc, copy, free, load, or store"),
+ llvm::cl::desc(
+ "Test mode: walk, alloc, copy, free, load, store, or cast"),
llvm::cl::init("walk")};
StringRef getArgument() const override {
@@ -79,6 +80,8 @@ struct TestPointerLikeTypeInterfacePass
OpBuilder &builder);
void testGenStore(Operation *op, Value result, PointerLikeType pointerType,
OpBuilder &builder, Value providedValue = {});
+ void testGenCast(Operation *op, Value value, Type resultType,
+ OpBuilder &builder);
struct PointerCandidate {
Operation *op;
@@ -96,6 +99,18 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() {
auto func = getOperation();
OpBuilder builder(&getContext());
+ if (testMode == "cast") {
+ func.walk([&](Operation *op) {
+ if (!op->hasAttr("test.cast"))
+ return;
+ auto destAttr = dyn_cast_or_null<TypeAttr>(op->getAttr("cast_dest"));
+ if (!destAttr || op->getNumResults() == 0)
+ return;
+ testGenCast(op, op->getResult(0), destAttr.getValue(), builder);
+ });
+ return;
+ }
+
if (testMode == "alloc" || testMode == "free" || testMode == "load" ||
testMode == "store") {
// Collect all candidates first
@@ -409,6 +424,51 @@ void TestPointerLikeTypeInterfacePass::testGenStore(Operation *op, Value result,
}
}
+/// Runs `PointerLikeType::genCast` to cast `value` (a result of `op`) to
+/// `resultType` and reports the outcome on stderr for FileCheck, including
+/// every op the cast inserted.
+///
+/// Follows the dispatch protocol documented on the interface: the source
+/// value's type is tried first. If its implementation declines (returns an
+/// empty value), `resultType` is tried next. Only endpoints that implement
+/// `PointerLikeType` are tried.
+void TestPointerLikeTypeInterfacePass::testGenCast(Operation *op, Value value,
+                                                   Type resultType,
+                                                   OpBuilder &builder) {
+  // `builder` is intentionally unused: a dedicated builder with a listener
+  // attached lets us report exactly the ops that genCast inserts.
+  Location loc = op->getLoc();
+
+  OperationTracker tracker;
+  OpBuilder newBuilder(op->getContext());
+  newBuilder.setListener(&tracker);
+  newBuilder.setInsertionPointAfter(op);
+
+  auto srcPtrLike = dyn_cast<PointerLikeType>(value.getType());
+  auto dstPtrLike = dyn_cast<PointerLikeType>(resultType);
+  if (!srcPtrLike && !dstPtrLike) {
+    llvm::errs() << "Failed genCast: neither value nor result type is "
+                    "PointerLikeType for operation: ";
+    op->print(llvm::errs());
+    llvm::errs() << "\n";
+    return;
+  }
+
+  Value castRes;
+  if (srcPtrLike)
+    castRes = srcPtrLike.genCast(newBuilder, loc, value, resultType);
+  // Fall back to the destination's implementation, as the interface allows.
+  if (!castRes && dstPtrLike)
+    castRes = dstPtrLike.genCast(newBuilder, loc, value, resultType);
+
+  if (!castRes) {
+    llvm::errs() << "Failed to generate cast for operation: ";
+    op->print(llvm::errs());
+    llvm::errs() << "\n";
+    return;
+  }
+
+  llvm::errs() << "Successfully generated cast for operation: ";
+  op->print(llvm::errs());
+  llvm::errs() << "\n";
+  llvm::errs() << "\tCast result type: ";
+  castRes.getType().print(llvm::errs());
+  llvm::errs() << "\n";
+
+  for (Operation *insertedOp : tracker.insertedOps) {
+    llvm::errs() << "\tGenerated: ";
+    insertedOp->print(llvm::errs());
+    llvm::errs() << "\n";
+  }
+}
+
} // namespace
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
index 17d1721b82602..7bcb652b69185 100644
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -16,6 +16,7 @@ mlir_target_link_libraries(MLIROpenACCTests
MLIRDLTIDialect
MLIRFuncDialect
MLIRGPUDialect
+ MLIRLLVMDialect
MLIRMemRefDialect
MLIRArithDialect
MLIROpenACCDialect
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp
index c773c031e3b6e..521a2f3f26c16 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp
@@ -7,8 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
@@ -16,6 +20,7 @@
using namespace mlir;
using namespace mlir::acc;
+using namespace mlir::LLVM;
namespace {
@@ -51,7 +56,9 @@ class OpenACCTypeInterfacesTest : public ::testing::Test {
IntegerType::attachInterface<TestReducibleIntegerModel>(*ctx);
});
context.appendDialectRegistry(registry);
- context.loadDialect<acc::OpenACCDialect, arith::ArithDialect>();
+ context.loadDialect<acc::OpenACCDialect, arith::ArithDialect,
+ memref::MemRefDialect, func::FuncDialect,
+ LLVMDialect>();
}
MLIRContext context;
@@ -91,3 +98,161 @@ TEST_F(OpenACCTypeInterfacesTest, NonReducibleTypeReturnsNull) {
auto reducible = dyn_cast<ReducibleType>(f32Type);
EXPECT_TRUE(reducible == nullptr);
}
+
+//===----------------------------------------------------------------------===//
+// PointerLikeType::genCast tests
+//===----------------------------------------------------------------------===//
+
+/// Creates an empty `func.func @<name>()` at the start of `module`'s body and
+/// moves `builder` to the start of that function's entry block, so a test can
+/// materialize its operands there.
+static void createTestFunc(OpBuilder &builder, ModuleOp module, Location loc,
+                           StringRef name) {
+  builder.setInsertionPointToStart(module.getBody());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, name,
+                                         builder.getFunctionType({}, {}));
+  builder.setInsertionPointToStart(fn.addEntryBlock());
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastMemrefIdentity) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_identity");
+
+  auto memTy = MemRefType::get({}, builder.getF32Type());
+  Value v = memref::AllocaOp::create(builder, loc, memTy).getResult();
+  auto ptrLike = cast<PointerLikeType>(v.getType());
+  Value out = ptrLike.genCast(builder, loc, v, memTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out, v);
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastMemrefStaticToDynamic) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_memref");
+
+  auto srcTy = MemRefType::get({4}, builder.getF32Type());
+  auto dstTy = MemRefType::get({ShapedType::kDynamic}, builder.getF32Type());
+  Value v = memref::AllocaOp::create(builder, loc, srcTy).getResult();
+  auto ptrLike = cast<PointerLikeType>(v.getType());
+  Value out = ptrLike.genCast(builder, loc, v, dstTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), dstTy);
+  ASSERT_TRUE(isa<memref::CastOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastMemrefMemorySpace) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_memref_memspace");
+
+  Attribute msHost = builder.getI32IntegerAttr(0);
+  Attribute msDev = builder.getI32IntegerAttr(1);
+  auto srcTy = MemRefType::get({4}, builder.getF32Type(), AffineMap(), msHost);
+  auto dstTy = MemRefType::get({4}, builder.getF32Type(), AffineMap(), msDev);
+  Value v = memref::AllocaOp::create(builder, loc, srcTy).getResult();
+  auto ptrLike = cast<PointerLikeType>(v.getType());
+  Value out = ptrLike.genCast(builder, loc, v, dstTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), dstTy);
+  ASSERT_TRUE(isa<memref::MemorySpaceCastOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMPtrAddrSpace) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_llvm_addrspace");
+
+  Type ptrAs0 = LLVMPointerType::get(&context, 0);
+  Type ptrAs3 = LLVMPointerType::get(&context, 3);
+  Value v = UndefOp::create(builder, loc, ptrAs0);
+  auto ptrLike = cast<PointerLikeType>(ptrAs0);
+  Value out = ptrLike.genCast(builder, loc, v, ptrAs3);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), ptrAs3);
+  ASSERT_TRUE(isa<AddrSpaceCastOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMPtrSameAddrSpaceNoOp) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_llvm_same_as");
+
+  Type ptrTy = LLVMPointerType::get(&context, 2);
+  Value v = UndefOp::create(builder, loc, ptrTy);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, ptrTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out, v);
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMPtrToI64) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_llvm_ptrtoint");
+
+  Type ptrTy = LLVMPointerType::get(&context, 0);
+  Type i64Ty = builder.getI64Type();
+  Value v = UndefOp::create(builder, loc, ptrTy);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, i64Ty);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), i64Ty);
+  ASSERT_TRUE(isa<PtrToIntOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMIntToPtrFromI64) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_llvm_inttoptr");
+
+  Type ptrTy = LLVMPointerType::get(&context, 0);
+  Value v = arith::ConstantIntOp::create(builder, loc, builder.getI64Type(), 0);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, ptrTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), ptrTy);
+  ASSERT_TRUE(isa<IntToPtrOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMIntToPtrFromIndex) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(&context);
+  createTestFunc(builder, module.get(), loc, "cast_llvm_index_inttoptr");
+
+  Type ptrTy = LLVMPointerType::get(&context, 0);
+  Value v = arith::ConstantIndexOp::create(builder, loc, 0);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, ptrTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), ptrTy);
+  auto intToPtr = dyn_cast<IntToPtrOp>(out.getDefiningOp());
+  ASSERT_TRUE(intToPtr);
+  EXPECT_TRUE(isa<arith::IndexCastUIOp>(intToPtr.getArg().getDefiningOp()));
+}
More information about the Mlir-commits
mailing list