[flang-commits] [flang] 7c8fb2e - [mlir][acc][flang] Add genCast API to PointerLikeType (#192720)

via flang-commits flang-commits at lists.llvm.org
Mon Apr 20 07:32:18 PDT 2026


Author: Razvan Lupusoru
Date: 2026-04-20T07:32:12-07:00
New Revision: 7c8fb2ee69cdfa22d5b78b11617a38618ae83a08

URL: https://github.com/llvm/llvm-project/commit/7c8fb2ee69cdfa22d5b78b11617a38618ae83a08
DIFF: https://github.com/llvm/llvm-project/commit/7c8fb2ee69cdfa22d5b78b11617a38618ae83a08.diff

LOG: [mlir][acc][flang] Add genCast API to PointerLikeType (#192720)

Introduces a new PointerLikeType interface method, genCast, which
generates IR that converts a value to a requested result type. It is
implemented for the FIR pointer-like types (fir.ref, fir.ptr, fir.heap,
fir.llvm_ptr), memref, and the LLVM pointer type.

Added: 
    flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir
    mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir

Modified: 
    flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
    flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
    mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
    mlir/unittests/Dialect/OpenACC/CMakeLists.txt
    mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
index 01a1e19afd74b..3798d668ed547 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
@@ -53,6 +53,10 @@ struct OpenACCPointerLikeModel
                 mlir::Location loc, mlir::Value valueToStore,
                 mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
 
+  mlir::Value genCast(mlir::Type pointer, mlir::OpBuilder &builder,
+                      mlir::Location loc, mlir::Value value,
+                      mlir::Type resultType) const;
+
   bool isDeviceData(mlir::Type pointer, mlir::Value var) const;
 };
 

diff  --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index 8749a816e1b00..9a8e7e47e54d0 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -1603,6 +1603,38 @@ template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genStore(
     mlir::Value valueToStore,
     mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
 
+template <typename Ty>
+mlir::Value OpenACCPointerLikeModel<Ty>::genCast(mlir::Type pointer,
+                                                 mlir::OpBuilder &builder,
+                                                 mlir::Location loc,
+                                                 mlir::Value value,
+                                                 mlir::Type resultType) const {
+  (void)pointer;
+  if (value.getType() == resultType)
+    return value;
+
+  if (fir::ConvertOp::canBeConverted(value.getType(), resultType))
+    return fir::ConvertOp::create(builder, loc, resultType, value);
+
+  return {};
+}
+
+template mlir::Value OpenACCPointerLikeModel<fir::ReferenceType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::PointerType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::HeapType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genCast(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value value, mlir::Type resultType) const;
+
 /// Check CUDA attributes on a function argument.
 static bool hasCUDADeviceAttrOnFuncArg(mlir::BlockArgument blockArg) {
   auto *owner = blockArg.getOwner();

diff  --git a/flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir b/flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir
new file mode 100644
index 0000000000000..c0ebd8ae1fbc3
--- /dev/null
+++ b/flang/test/Fir/OpenACC/pointer-like-interface-cast.mlir
@@ -0,0 +1,68 @@
+// RUN: fir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=cast}))" 2>&1 | FileCheck %s
+
+func.func @test_fir_ref_to_memref_scalar() {
+  %0 = fir.alloca f32 {test.cast, cast_dest = memref<f32>}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.alloca f32{{.*}}
+  // CHECK: Cast result type: memref<f32>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (!fir.ref<f32>) -> memref<f32>
+  return
+}
+
+// -----
+
+func.func @test_memref_to_fir_ref_scalar() {
+  %0 = memref.alloca() {test.cast, cast_dest = !fir.ref<f32>} : memref<f32>
+  // CHECK: Successfully generated cast for operation: %{{.*}} = memref.alloca(){{.*}}
+  // CHECK: Cast result type: !fir.ref<f32>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (memref<f32>) -> !fir.ref<f32>
+  return
+}
+
+// -----
+
+func.func @test_fir_ref_identity() {
+  %0 = fir.alloca i32 {test.cast, cast_dest = !fir.ref<i32>}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.alloca i32{{.*}}
+  // CHECK: Cast result type: !fir.ref<i32>
+  return
+}
+
+// -----
+
+func.func @test_i64_to_fir_ref() {
+  %0 = arith.constant {test.cast, cast_dest = !fir.ref<i8>} 0 : i64
+  // CHECK: Successfully generated cast for operation: %{{.*}} = arith.constant{{.*}}
+  // CHECK: Cast result type: !fir.ref<i8>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (i64) -> !fir.ref<i8>
+  return
+}
+
+// -----
+
+func.func @test_index_to_fir_ptr() {
+  %0 = arith.constant {test.cast, cast_dest = !fir.ptr<i8>} 0 : index
+  // CHECK: Successfully generated cast for operation: %{{.*}} = arith.constant{{.*}}
+  // CHECK: Cast result type: !fir.ptr<i8>
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (index) -> !fir.ptr<i8>
+  return
+}
+
+// -----
+
+func.func @test_fir_heap_to_i64() {
+  %0 = fir.zero_bits !fir.heap<i8> {test.cast, cast_dest = i64}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.zero_bits{{.*}}
+  // CHECK: Cast result type: i64
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (!fir.heap<i8>) -> i64
+  return
+}
+
+// -----
+
+func.func @test_fir_llvm_ptr_to_index() {
+  %0 = fir.zero_bits !fir.llvm_ptr<i8> {test.cast, cast_dest = index}
+  // CHECK: Successfully generated cast for operation: %{{.*}} = fir.zero_bits{{.*}}
+  // CHECK: Cast result type: index
+  // CHECK: Generated: %{{.*}} = fir.convert %{{.*}} : (!fir.llvm_ptr<i8>) -> index
+  return
+}

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 3bd4e5c679659..2753056bf8a7a 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -220,6 +220,34 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
         return false;
       }]
     >,
+    InterfaceMethod<
+      /*description=*/[{
+        Generates a cast from `value` to `resultType` when the implementing
+        pointer-like type can emit a lowering for that conversion.
+
+        This is intentionally a single operation (rather than separate
+        "cast to" / "cast from" hooks): the source type is always
+        `value.getType()` and the destination is always `resultType`.
+
+        Call sites typically dispatch on the `PointerLikeType` that owns the
+        conversion: the source value's type when casting *from* a pointer-like
+        representation, or the result type when casting *to* one.
+
+        Returns the cast result on success, or an empty value if the cast is
+        unsupported (callers may then try the other endpoint's interface or
+        apply their own fallback).
+      }],
+      /*retTy=*/"::mlir::Value",
+      /*methodName=*/"genCast",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                    "::mlir::Location":$loc,
+                    "::mlir::Value":$value,
+                    "::mlir::Type":$resultType),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return {};
+      }]
+    >,
     InterfaceMethod<
       /*description=*/[{
         Returns true if the pointer points to device data.

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bd26c70eb1831..449a9b588910f 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IRMapping.h"
@@ -246,6 +247,37 @@ struct MemRefPointerLikeModel
     return true;
   }
 
+  Value genCast(Type, OpBuilder &builder, Location loc, Value value,
+                Type resultType) const {
+    if (value.getType() == resultType)
+      return value;
+
+    if (isa<BaseMemRefType>(value.getType()) &&
+        isa<BaseMemRefType>(resultType)) {
+      if (memref::CastOp::areCastCompatible(TypeRange(value.getType()),
+                                            TypeRange(resultType)))
+        return memref::CastOp::create(builder, loc, resultType, value);
+      if (memref::MemorySpaceCastOp::areCastCompatible(
+              TypeRange(value.getType()), TypeRange(resultType)))
+        return memref::MemorySpaceCastOp::create(builder, loc, resultType,
+                                                 value);
+    }
+
+    // If one side is not a memref, try the other type's `PointerLikeType`
+    // implementation (since it may be an out-of-tree reference type that
+    // we cannot generate here).
+    if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
+      if (!isa<BaseMemRefType>(resPtrLike))
+        if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
+          return v;
+    if (auto valPtrLike = dyn_cast<PointerLikeType>(value.getType()))
+      if (!isa<BaseMemRefType>(valPtrLike))
+        if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
+          return v;
+
+    return {};
+  }
+
   bool isDeviceData(Type pointer, Value var) const {
     auto memrefTy = cast<T>(pointer);
     Attribute memSpace = memrefTy.getMemorySpace();
@@ -273,6 +305,45 @@ struct LLVMPointerPointerLikeModel
     LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
     return true;
   }
+
+  Value genCast(Type, OpBuilder &builder, Location loc, Value value,
+                Type resultType) const {
+    if (value.getType() == resultType)
+      return value;
+
+    auto srcPtrTy = dyn_cast<LLVM::LLVMPointerType>(value.getType());
+    auto dstPtrTy = dyn_cast<LLVM::LLVMPointerType>(resultType);
+    if (srcPtrTy && dstPtrTy) {
+      if (srcPtrTy.getAddressSpace() != dstPtrTy.getAddressSpace())
+        return LLVM::AddrSpaceCastOp::create(builder, loc, resultType, value);
+      return value;
+    }
+
+    if (srcPtrTy && isa<IntegerType>(resultType))
+      return LLVM::PtrToIntOp::create(builder, loc, resultType, value);
+
+    if (dstPtrTy) {
+      Value intVal = value;
+      if (isa<IndexType>(value.getType()))
+        intVal = arith::IndexCastUIOp::create(builder, loc,
+                                              builder.getI64Type(), value);
+      if (isa<IntegerType>(intVal.getType()))
+        return LLVM::IntToPtrOp::create(builder, loc, resultType, intVal);
+    }
+
+    if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
+      if (!isa<LLVM::LLVMPointerType>(resPtrLike))
+        if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
+          return v;
+    if (auto valPtrLike = dyn_cast<PointerLikeType>(value.getType()))
+      if (!isa<LLVM::LLVMPointerType>(valPtrLike))
+        if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
+          return v;
+
+    return UnrealizedConversionCastOp::create(builder, loc,
+                                              TypeRange(resultType), value)
+        .getResult(0);
+  }
 };
 
 struct MemrefAddressOfGlobalModel

diff  --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir
new file mode 100644
index 0000000000000..2a0939a3020db
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-cast.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=cast}))" 2>&1 | FileCheck %s
+
+func.func @test_memref_cast_identity() {
+  %0 = memref.alloca() {test.cast, cast_dest = memref<f32>} : memref<f32>
+  // CHECK: Successfully generated cast for operation: %[[V:.*]] = memref.alloca(){{.*}} : memref<f32>
+  // CHECK: Cast result type: memref<f32>
+  return
+}
+
+// -----
+
+func.func @test_memref_cast_static_to_dynamic() {
+  %0 = memref.alloca() {test.cast, cast_dest = memref<?xf32>} : memref<4xf32>
+  // CHECK: Successfully generated cast for operation: %[[V:.*]] = memref.alloca(){{.*}} : memref<4xf32>
+  // CHECK: Cast result type: memref<?xf32>
+  // CHECK: Generated: %{{.*}} = memref.cast %[[V]] : memref<4xf32> to memref<?xf32>
+  return
+}
+
+// -----
+
+func.func @test_memref_memory_space_cast() {
+  %0 = memref.alloca() {test.cast, cast_dest = memref<4xf32, 1>} : memref<4xf32>
+  // CHECK: Successfully generated cast for operation: %[[V:.*]] = memref.alloca(){{.*}} : memref<4xf32>
+  // CHECK: Cast result type: memref<4xf32, 1>
+  // CHECK: Generated: %{{.*}} = memref.memory_space_cast %[[V]] : memref<4xf32> to memref<4xf32, 1>
+  return
+}
+
+// -----
+
+func.func @test_memref_cast_incompatible() {
+  %0 = memref.alloca() {test.cast, cast_dest = tensor<4xf32>} : memref<4xf32>
+  // CHECK: Failed to generate cast for operation: %{{.*}} = memref.alloca(){{.*}} : memref<4xf32>
+  return
+}

diff  --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
index 3ff0dc85b2152..e45fd104e7331 100644
--- a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
+++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
@@ -46,7 +46,8 @@ struct TestPointerLikeTypeInterfacePass
 
   Pass::Option<std::string> testMode{
       *this, "test-mode",
-      llvm::cl::desc("Test mode: walk, alloc, copy, free, load, or store"),
+      llvm::cl::desc(
+          "Test mode: walk, alloc, copy, free, load, store, or cast"),
       llvm::cl::init("walk")};
 
   StringRef getArgument() const override {
@@ -79,6 +80,8 @@ struct TestPointerLikeTypeInterfacePass
                    OpBuilder &builder);
   void testGenStore(Operation *op, Value result, PointerLikeType pointerType,
                     OpBuilder &builder, Value providedValue = {});
+  void testGenCast(Operation *op, Value value, Type resultType,
+                   OpBuilder &builder);
 
   struct PointerCandidate {
     Operation *op;
@@ -96,6 +99,18 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() {
   auto func = getOperation();
   OpBuilder builder(&getContext());
 
+  if (testMode == "cast") {
+    func.walk([&](Operation *op) {
+      if (!op->hasAttr("test.cast"))
+        return;
+      auto destAttr = dyn_cast_or_null<TypeAttr>(op->getAttr("cast_dest"));
+      if (!destAttr || op->getNumResults() == 0)
+        return;
+      testGenCast(op, op->getResult(0), destAttr.getValue(), builder);
+    });
+    return;
+  }
+
   if (testMode == "alloc" || testMode == "free" || testMode == "load" ||
       testMode == "store") {
     // Collect all candidates first
@@ -409,6 +424,51 @@ void TestPointerLikeTypeInterfacePass::testGenStore(Operation *op, Value result,
   }
 }
 
+void TestPointerLikeTypeInterfacePass::testGenCast(Operation *op, Value value,
+                                                   Type resultType,
+                                                   OpBuilder &builder) {
+  Location loc = op->getLoc();
+
+  OperationTracker tracker;
+  OpBuilder newBuilder(op->getContext());
+  newBuilder.setListener(&tracker);
+  newBuilder.setInsertionPointAfter(op);
+
+  PointerLikeType dispatchTy;
+  if (isa<PointerLikeType>(value.getType()))
+    dispatchTy = cast<PointerLikeType>(value.getType());
+  else if (isa<PointerLikeType>(resultType))
+    dispatchTy = cast<PointerLikeType>(resultType);
+  else {
+    llvm::errs() << "Failed genCast: neither value nor result type is "
+                    "PointerLikeType for operation: ";
+    op->print(llvm::errs());
+    llvm::errs() << "\n";
+    return;
+  }
+
+  Value castRes = dispatchTy.genCast(newBuilder, loc, value, resultType);
+
+  if (castRes) {
+    llvm::errs() << "Successfully generated cast for operation: ";
+    op->print(llvm::errs());
+    llvm::errs() << "\n";
+    llvm::errs() << "\tCast result type: ";
+    castRes.getType().print(llvm::errs());
+    llvm::errs() << "\n";
+
+    for (Operation *insertedOp : tracker.insertedOps) {
+      llvm::errs() << "\tGenerated: ";
+      insertedOp->print(llvm::errs());
+      llvm::errs() << "\n";
+    }
+  } else {
+    llvm::errs() << "Failed to generate cast for operation: ";
+    op->print(llvm::errs());
+    llvm::errs() << "\n";
+  }
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
index 17d1721b82602..7bcb652b69185 100644
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -16,6 +16,7 @@ mlir_target_link_libraries(MLIROpenACCTests
   MLIRDLTIDialect
   MLIRFuncDialect
   MLIRGPUDialect
+  MLIRLLVMDialect
   MLIRMemRefDialect
   MLIRArithDialect
   MLIROpenACCDialect

diff  --git a/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp
index c773c031e3b6e..0d5419203dbbe 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCTypeInterfacesTest.cpp
@@ -7,8 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/MLIRContext.h"
@@ -16,6 +20,7 @@
 
 using namespace mlir;
 using namespace mlir::acc;
+using namespace mlir::LLVM;
 
 namespace {
 
@@ -51,7 +56,9 @@ class OpenACCTypeInterfacesTest : public ::testing::Test {
       IntegerType::attachInterface<TestReducibleIntegerModel>(*ctx);
     });
     context.appendDialectRegistry(registry);
-    context.loadDialect<acc::OpenACCDialect, arith::ArithDialect>();
+    context
+        .loadDialect<acc::OpenACCDialect, arith::ArithDialect,
+                     memref::MemRefDialect, func::FuncDialect, LLVMDialect>();
   }
 
   MLIRContext context;
@@ -91,3 +98,161 @@ TEST_F(OpenACCTypeInterfacesTest, NonReducibleTypeReturnsNull) {
   auto reducible = dyn_cast<ReducibleType>(f32Type);
   EXPECT_TRUE(reducible == nullptr);
 }
+
+//===----------------------------------------------------------------------===//
+// PointerLikeType::genCast tests
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastMemrefIdentity) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, "cast_identity",
+                                         builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  auto memTy = MemRefType::get({}, builder.getF32Type());
+  memref::AllocaOp alloca = memref::AllocaOp::create(builder, loc, memTy);
+  Value v = alloca.getResult();
+  auto ptrLike = cast<PointerLikeType>(v.getType());
+  Value out = ptrLike.genCast(builder, loc, v, memTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out, v);
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastMemrefStaticToDynamic) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, "cast_memref",
+                                         builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  auto srcTy = MemRefType::get({4}, builder.getF32Type());
+  auto dstTy = MemRefType::get({ShapedType::kDynamic}, builder.getF32Type());
+  memref::AllocaOp alloca = memref::AllocaOp::create(builder, loc, srcTy);
+  Value v = alloca.getResult();
+  auto ptrLike = cast<PointerLikeType>(v.getType());
+  Value out = ptrLike.genCast(builder, loc, v, dstTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), dstTy);
+  ASSERT_TRUE(isa<memref::CastOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastMemrefMemorySpace) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, "cast_memref_memspace",
+                                         builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Attribute msHost = builder.getI32IntegerAttr(0);
+  Attribute msDev = builder.getI32IntegerAttr(1);
+  auto srcTy = MemRefType::get({4}, builder.getF32Type(), AffineMap(), msHost);
+  auto dstTy = MemRefType::get({4}, builder.getF32Type(), AffineMap(), msDev);
+  memref::AllocaOp alloca = memref::AllocaOp::create(builder, loc, srcTy);
+  Value v = alloca.getResult();
+  auto ptrLike = cast<PointerLikeType>(v.getType());
+  Value out = ptrLike.genCast(builder, loc, v, dstTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), dstTy);
+  ASSERT_TRUE(isa<memref::MemorySpaceCastOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMPtrAddrSpace) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, "cast_llvm_addrspace",
+                                         builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Type ptrAs0 = LLVMPointerType::get(&context, 0);
+  Type ptrAs3 = LLVMPointerType::get(&context, 3);
+  Value v = UndefOp::create(builder, loc, ptrAs0);
+  auto ptrLike = cast<PointerLikeType>(ptrAs0);
+  Value out = ptrLike.genCast(builder, loc, v, ptrAs3);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), ptrAs3);
+  ASSERT_TRUE(isa<AddrSpaceCastOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMPtrSameAddrSpaceNoOp) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, "cast_llvm_same_as",
+                                         builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Type ptrTy = LLVMPointerType::get(&context, 2);
+  Value v = UndefOp::create(builder, loc, ptrTy);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, ptrTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out, v);
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMPtrToI64) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, "cast_llvm_ptrtoint",
+                                         builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Type ptrTy = LLVMPointerType::get(&context, 0);
+  Type i64Ty = builder.getI64Type();
+  Value v = UndefOp::create(builder, loc, ptrTy);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, i64Ty);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), i64Ty);
+  ASSERT_TRUE(isa<PtrToIntOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMIntToPtrFromI64) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn = func::FuncOp::create(builder, loc, "cast_llvm_inttoptr",
+                                         builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Type ptrTy = LLVMPointerType::get(&context, 0);
+  Value v = arith::ConstantIntOp::create(builder, loc, builder.getI64Type(), 0);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, ptrTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), ptrTy);
+  ASSERT_TRUE(isa<IntToPtrOp>(out.getDefiningOp()));
+}
+
+TEST_F(OpenACCTypeInterfacesTest, PointerLikeGenCastLLVMIntToPtrFromIndex) {
+  Location loc = UnknownLoc::get(&context);
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder builder(module->getBodyRegion());
+  func::FuncOp fn =
+      func::FuncOp::create(builder, loc, "cast_llvm_index_inttoptr",
+                           builder.getFunctionType({}, {}));
+  Block *block = fn.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Type ptrTy = LLVMPointerType::get(&context, 0);
+  Value v = arith::ConstantIndexOp::create(builder, loc, 0);
+  auto ptrLike = cast<PointerLikeType>(ptrTy);
+  Value out = ptrLike.genCast(builder, loc, v, ptrTy);
+  ASSERT_TRUE(out);
+  EXPECT_EQ(out.getType(), ptrTy);
+  auto intToPtr = dyn_cast<IntToPtrOp>(out.getDefiningOp());
+  ASSERT_TRUE(intToPtr);
+  EXPECT_TRUE(isa<arith::IndexCastUIOp>(intToPtr.getArg().getDefiningOp()));
+}


        


More information about the flang-commits mailing list