[flang-commits] [flang] [flang][cuda] Remove CUFDeviceAddressOpConversion from CUFOpConversion (PR #177213)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Jan 21 10:22:56 PST 2026


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/177213

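The CUFDeviceAddressOpConversion pattern has been moved to CUFOpConversionLate, so it is removed from CUFOpConversion. cuf::DeviceAddressOp is now marked legal in CUFOpConversion, and the affected tests additionally run --cuf-convert-late.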

From b847d0ba3e8727e5e35eeb22d6fd34e4b5eb12ba Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 21 Jan 2026 10:22:12 -0800
Subject: [PATCH] [flang][cuda] Remove CUFDeviceAddressOpConversion from
 CUFOpConversion

---
 .../Transforms/CUDA/CUFOpConversion.cpp       | 50 ++-----------------
 flang/test/Fir/CUDA/cuda-data-transfer.fir    |  2 +-
 flang/test/Fir/CUDA/cuda-global-addr.mlir     |  2 +-
 flang/test/Fir/CUDA/cuda-launch.fir           |  2 +-
 4 files changed, 6 insertions(+), 50 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
index e81714aa63bba..313c358f92792 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
@@ -73,49 +73,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
   return val;
 }
 
-struct CUFDeviceAddressOpConversion
-    : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
-                               const mlir::SymbolTable &symtab)
-      : OpRewritePattern(context), symTab{symtab} {}
-
-  mlir::LogicalResult
-  matchAndRewrite(cuf::DeviceAddressOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-    if (auto global = symTab.lookup<fir::GlobalOp>(
-            op.getHostSymbol().getRootReference().getValue())) {
-      auto mod = op->getParentOfType<mlir::ModuleOp>();
-      mlir::Location loc = op.getLoc();
-      auto hostAddr = fir::AddrOfOp::create(
-          rewriter, loc, fir::ReferenceType::get(global.getType()),
-          op.getHostSymbol());
-      fir::FirOpBuilder builder(rewriter, mod);
-      mlir::func::FuncOp callee =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
-                                                                     builder);
-      auto fTy = callee.getFunctionType();
-      mlir::Value conv =
-          createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
-      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-      mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, conv, sourceFile, sourceLine)};
-      auto call = fir::CallOp::create(rewriter, loc, callee, args);
-      mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
-                                         call->getResult(0));
-      rewriter.replaceOp(op, addr.getDefiningOp());
-      return success();
-    }
-    return failure();
-  }
-
-private:
-  const mlir::SymbolTable &symTab;
-};
-
 struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -559,6 +516,7 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
                            mlir::gpu::GPUDialect>();
     target.addLegalOp<cuf::StreamCastOp>();
+    target.addLegalOp<cuf::DeviceAddressOp>();
     cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
                                             patterns);
     if (allocationConversion)
@@ -606,12 +564,10 @@ void cuf::populateCUFToFIRConversionPatterns(
   patterns.insert<CUFSyncDescriptorOpConversion>(patterns.getContext());
   patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
                                                &dl, &converter);
-  patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
-      patterns.getContext(), symtab);
+  patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
 }
 
 void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
                                            mlir::RewritePatternSet &patterns) {
-  patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
-      patterns.getContext(), symtab);
+  patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
 }
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b247fce44df3d..b6486a9ddecdd 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt --cuf-convert %s | FileCheck %s
+// RUN: fir-opt --cuf-convert --cuf-convert-late %s | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
 
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index ae88af3d3c16c..ca1d562b75a9b 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -1,4 +1,4 @@
-// RUN: fir-opt --split-input-file --cuf-convert %s | FileCheck %s
+// RUN: fir-opt --split-input-file --cuf-convert --cuf-convert-late %s | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
 fir.global @_QMmod1Eadev {data_attr = #cuf.cuda<device>} : !fir.array<10xi32> {
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index 1e8036e628650..92db6ecaadc45 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt --split-input-file --cuf-convert %s | FileCheck %s
+// RUN: fir-opt --split-input-file --cuf-convert --cuf-convert-late %s | FileCheck %s
 
 
 module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {



More information about the flang-commits mailing list