[flang-commits] [flang] [flang][cuda] Handle gpu.return in AbstractResult pass (PR #119035)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Mon Dec 9 13:32:30 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/119035
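
In short: the first patch stops running the AbstractResult pass directly on gpu.module and instead nests it on the func.func and gpu.func operations inside the GPU module; the second patch teaches the pass to rewrite gpu.return the same way it already rewrites func.return. A minimal sketch of the new scheduling shape, assuming the usual flang/MLIR headers (the standalone driver function here is hypothetical; the real code is addNestedPassToNest in Pipelines.cpp below):

#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/PassManager.h"

// Nest the abstract-result rewrite on both kinds of functions that can
// appear inside a gpu.module, so gpu.func bodies are visited too.
static void scheduleAbstractResult(mlir::PassManager &pm) {
  mlir::OpPassManager &gpuModPM = pm.nest<mlir::gpu::GPUModuleOp>();
  gpuModPM.addNestedPass<mlir::func::FuncOp>(fir::createAbstractResultOpt());
  gpuModPM.addNestedPass<mlir::gpu::GPUFuncOp>(fir::createAbstractResultOpt());
}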

From df364862d7f0700525c00156fa4c532e19088bfe Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 6 Dec 2024 13:30:23 -0800
Subject: [PATCH 1/2] [flang][cuda] Change how abstract result pass is
 scheduled on func.func and gpu.func

---
 flang/lib/Optimizer/Passes/Pipelines.cpp      | 11 +++++++--
 .../Optimizer/Transforms/AbstractResult.cpp   | 15 ++++--------
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |  9 ++-----
 .../test/Driver/mlir-debug-pass-pipeline.f90  | 19 +++++++--------
 flang/test/Driver/mlir-pass-pipeline.f90      | 24 +++++++------------
 flang/test/Fir/basic-program.fir              | 24 +++++++------------
 6 files changed, 41 insertions(+), 61 deletions(-)

diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 0743fb60aa847a..ff79c811541c44 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -16,8 +16,14 @@ namespace fir {
 void addNestedPassToAllTopLevelOperations(mlir::PassManager &pm,
                                           PassConstructor ctor) {
   addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
-                     mlir::omp::PrivateClauseOp, fir::GlobalOp,
-                     mlir::gpu::GPUModuleOp>(pm, ctor);
+                     mlir::omp::PrivateClauseOp, fir::GlobalOp>(pm, ctor);
+}
+
+template <typename NestOpTy>
+void addNestedPassToNest(mlir::PassManager &pm, PassConstructor ctor) {
+  mlir::OpPassManager &nestPM = pm.nest<NestOpTy>();
+  nestPM.addNestedPass<mlir::func::FuncOp>(ctor());
+  nestPM.addNestedPass<mlir::gpu::GPUFuncOp>(ctor());
 }
 
 void addNestedPassToAllTopLevelOperationsConditionally(
@@ -266,6 +272,7 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
                                          llvm::StringRef inputFilename) {
   fir::addBoxedProcedurePass(pm);
   addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOpt);
+  addNestedPassToNest<mlir::gpu::GPUModuleOp>(pm, fir::createAbstractResultOpt);
   fir::addCodeGenRewritePass(
       pm, (config.DebugInfo != llvm::codegenoptions::NoDebugInfo));
   fir::addExternalNameConversionPass(pm, config.Underscoring);
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 2eca349110f3af..2ed66cc83eefb5 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -460,17 +460,10 @@ class AbstractResultOpt
     const bool shouldBoxResult = this->passResultAsBox.getValue();
 
     mlir::TypeSwitch<mlir::Operation *, void>(op)
-        .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
-          runOnSpecificOperation(op, shouldBoxResult, patterns, target);
-        })
-        .Case<mlir::gpu::GPUModuleOp>([&](auto op) {
-          auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(*op);
-          for (auto funcOp : gpuMod.template getOps<mlir::func::FuncOp>())
-            runOnSpecificOperation(funcOp, shouldBoxResult, patterns, target);
-          for (auto gpuFuncOp : gpuMod.template getOps<mlir::gpu::GPUFuncOp>())
-            runOnSpecificOperation(gpuFuncOp, shouldBoxResult, patterns,
-                                   target);
-        });
+        .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(
+            [&](auto op) {
+              runOnSpecificOperation(op, shouldBoxResult, patterns, target);
+            });
 
 // Convert the calls and, if needed, the ReturnOp in the function body.
     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 1f09e7ad4c2f5a..5520d750e2ce1c 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -17,14 +17,12 @@
 ! CHECK-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! CHECK-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
-! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! CHECK-NEXT: 'fir.global' Pipeline
 ! CHECK-NEXT:   CharacterConversion
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   ArrayValueCopy
 ! CHECK-NEXT:   CharacterConversion
-! CHECK-NEXT: 'gpu.module' Pipeline
-! CHECK-NEXT:   CharacterConversion
 ! CHECK-NEXT: 'omp.declare_reduction' Pipeline
 ! CHECK-NEXT:   CharacterConversion
 ! CHECK-NEXT: 'omp.private' Pipeline
@@ -50,16 +48,13 @@
 ! CHECK-NEXT: PolymorphicOpConversion
 ! CHECK-NEXT: AssumedRankOpConversion
 
-! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! CHECK-NEXT: 'fir.global' Pipeline
 ! CHECK-NEXT:   StackReclaim
 ! CHECK-NEXT:   CFGConversion
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   StackReclaim
 ! CHECK-NEXT:   CFGConversion
-! CHECK-NEXT: 'gpu.module' Pipeline
-! CHECK-NEXT:   StackReclaim
-! CHECK-NEXT:   CFGConversion
 ! CHECK-NEXT: 'omp.declare_reduction' Pipeline
 ! CHECK-NEXT:   StackReclaim
 ! CHECK-NEXT:   CFGConversion
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 4326953421e4bd..edc6f59b0ad7c9 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -28,13 +28,11 @@
 ! ALL: Pass statistics report
 
 ! ALL: Fortran::lower::VerifierPass
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT: 'fir.global' Pipeline
 ! ALL-NEXT:   InlineElementals
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   InlineElementals
-! ALL-NEXT: 'gpu.module' Pipeline
-! ALL-NEXT:   InlineElementals
 ! ALL-NEXT: 'omp.declare_reduction' Pipeline
 ! ALL-NEXT:   InlineElementals
 ! ALL-NEXT: 'omp.private' Pipeline
@@ -51,14 +49,12 @@
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT: 'fir.global' Pipeline
 ! ALL-NEXT:   CharacterConversion
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   ArrayValueCopy
 ! ALL-NEXT:   CharacterConversion
-! ALL-NEXT: 'gpu.module' Pipeline
-! ALL-NEXT:   CharacterConversion
 ! ALL-NEXT: 'omp.declare_reduction' Pipeline
 ! ALL-NEXT:   CharacterConversion
 ! ALL-NEXT: 'omp.private' Pipeline
@@ -82,16 +78,13 @@
 ! ALL-NEXT: PolymorphicOpConversion
 ! ALL-NEXT: AssumedRankOpConversion
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT:   'fir.global' Pipeline
 ! ALL-NEXT:     StackReclaim
 ! ALL-NEXT:     CFGConversion
 ! ALL-NEXT:   'func.func' Pipeline
 ! ALL-NEXT:     StackReclaim
 ! ALL-NEXT:     CFGConversion
-! ALL-NEXT:   'gpu.module' Pipeline
-! ALL-NEXT:     StackReclaim
-! ALL-NEXT:     CFGConversion
 ! ALL-NEXT:   'omp.declare_reduction' Pipeline
 ! ALL-NEXT:     StackReclaim
 ! ALL-NEXT:     CFGConversion
@@ -112,7 +105,11 @@
 ! ALL-NEXT:   'func.func' Pipeline
 ! ALL-NEXT:     AbstractResultOpt
 ! ALL-NEXT:   'gpu.module' Pipeline
-! ALL-NEXT:     AbstractResultOpt
+! ALL-NEXT:   Pipeline Collection : ['func.func', 'gpu.func'] 
+! ALL-NEXT:   'func.func' Pipeline 
+! ALL-NEXT:   AbstractResultOpt
+! ALL-NEXT:   'gpu.func' Pipeline 
+! ALL-NEXT:   AbstractResultOpt
 ! ALL-NEXT:   'omp.declare_reduction' Pipeline
 ! ALL-NEXT:     AbstractResultOpt
 ! ALL-NEXT:   'omp.private' Pipeline
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 6ffdbb0234e856..b30affe691b840 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -16,16 +16,13 @@
 
 ! ALL: Fortran::lower::VerifierPass
 ! O2-NEXT: Canonicalizer
-! ALL:     Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! ALL:     Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT:'fir.global' Pipeline
 ! O2-NEXT:   SimplifyHLFIRIntrinsics
 ! ALL:       InlineElementals
 ! ALL-NEXT:'func.func' Pipeline
 ! O2-NEXT:   SimplifyHLFIRIntrinsics
 ! ALL:       InlineElementals
-! ALL-NEXT:'gpu.module' Pipeline
-! O2-NEXT:   SimplifyHLFIRIntrinsics
-! ALL:       InlineElementals
 ! ALL-NEXT:'omp.declare_reduction' Pipeline
 ! O2-NEXT:   SimplifyHLFIRIntrinsics
 ! ALL:       InlineElementals
@@ -36,13 +33,11 @@
 ! O2-NEXT: CSE
 ! O2-NEXT: (S) {{.*}} num-cse'd
 ! O2-NEXT: (S) {{.*}} num-dce'd
-! O2-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! O2-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! O2-NEXT: 'fir.global' Pipeline
 ! O2-NEXT:   OptimizedBufferization
 ! O2-NEXT: 'func.func' Pipeline
 ! O2-NEXT:   OptimizedBufferization
-! O2-NEXT: 'gpu.module' Pipeline
-! O2-NEXT:   OptimizedBufferization
 ! O2-NEXT: 'omp.declare_reduction' Pipeline
 ! O2-NEXT:   OptimizedBufferization
 ! O2-NEXT: 'omp.private' Pipeline
@@ -59,14 +54,12 @@
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT: 'fir.global' Pipeline
 ! ALL-NEXT:   CharacterConversion
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   ArrayValueCopy
 ! ALL-NEXT:   CharacterConversion
-! ALL-NEXT: 'gpu.module' Pipeline
-! ALL-NEXT:   CharacterConversion
 ! ALL-NEXT: 'omp.declare_reduction' Pipeline
 ! ALL-NEXT:   CharacterConversion
 ! ALL-NEXT: 'omp.private' Pipeline
@@ -93,16 +86,13 @@
 ! ALL-NEXT: AssumedRankOpConversion
 ! O2-NEXT:  AddAliasTags
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT:    'fir.global' Pipeline
 ! ALL-NEXT:      StackReclaim
 ! ALL-NEXT:      CFGConversion
 ! ALL-NEXT:    'func.func' Pipeline
 ! ALL-NEXT:      StackReclaim
 ! ALL-NEXT:      CFGConversion
-! ALL-NEXT:   'gpu.module' Pipeline
-! ALL-NEXT:      StackReclaim
-! ALL-NEXT:      CFGConversion
 ! ALL-NEXT:   'omp.declare_reduction' Pipeline
 ! ALL-NEXT:      StackReclaim
 ! ALL-NEXT:      CFGConversion
@@ -124,7 +114,11 @@
 ! ALL-NEXT:  'func.func' Pipeline
 ! ALL-NEXT:    AbstractResultOpt
 ! ALL-NEXT:  'gpu.module' Pipeline
-! ALL-NEXT:    AbstractResultOpt
+! ALL-NEXT:   Pipeline Collection : ['func.func', 'gpu.func'] 
+! ALL-NEXT:   'func.func' Pipeline 
+! ALL-NEXT:   AbstractResultOpt
+! ALL-NEXT:   'gpu.func' Pipeline 
+! ALL-NEXT:   AbstractResultOpt
 ! ALL-NEXT:  'omp.declare_reduction' Pipeline
 ! ALL-NEXT:    AbstractResultOpt
 ! ALL-NEXT:  'omp.private' Pipeline
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 50b91ce340b3a6..d2788008c3893e 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -17,16 +17,13 @@ func.func @_QQmain() {
 // PASSES: Pass statistics report
 
 // PASSES:        Canonicalizer
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 // PASSES-NEXT: 'fir.global' Pipeline
 // PASSES-NEXT:   SimplifyHLFIRIntrinsics
 // PASSES-NEXT:   InlineElementals
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   SimplifyHLFIRIntrinsics
 // PASSES-NEXT:   InlineElementals
-// PASSES-NEXT: 'gpu.module' Pipeline
-// PASSES-NEXT:   SimplifyHLFIRIntrinsics
-// PASSES-NEXT:   InlineElementals
 // PASSES-NEXT: 'omp.declare_reduction' Pipeline
 // PASSES-NEXT:   SimplifyHLFIRIntrinsics
 // PASSES-NEXT:   InlineElementals
@@ -37,13 +34,11 @@ func.func @_QQmain() {
 // PASSES-NEXT:   CSE
 // PASSES-NEXT:    (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:    (S) 0 num-dce'd - Number of operations DCE'd
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 // PASSES-NEXT: 'fir.global' Pipeline
 // PASSES-NEXT:    OptimizedBufferization
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:    OptimizedBufferization
-// PASSES-NEXT: 'gpu.module' Pipeline
-// PASSES-NEXT:    OptimizedBufferization
 // PASSES-NEXT: 'omp.declare_reduction' Pipeline
 // PASSES-NEXT:    OptimizedBufferization
 // PASSES-NEXT: 'omp.private' Pipeline
@@ -57,14 +52,12 @@ func.func @_QQmain() {
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 // PASSES-NEXT: 'fir.global' Pipeline
 // PASSES-NEXT:   CharacterConversion
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   ArrayValueCopy
 // PASSES-NEXT:   CharacterConversion
-// PASSES-NEXT: 'gpu.module' Pipeline
-// PASSES-NEXT:   CharacterConversion
 // PASSES-NEXT: 'omp.declare_reduction' Pipeline
 // PASSES-NEXT:   CharacterConversion
 // PASSES-NEXT: 'omp.private' Pipeline
@@ -91,16 +84,13 @@ func.func @_QQmain() {
 // PASSES-NEXT: AssumedRankOpConversion
 // PASSES-NEXT: AddAliasTags
 
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'gpu.module', 'omp.declare_reduction', 'omp.private']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 // PASSES-NEXT: 'fir.global' Pipeline
 // PASSES-NEXT:   StackReclaim
 // PASSES-NEXT:   CFGConversion
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   StackReclaim
 // PASSES-NEXT:   CFGConversion
-// PASSES-NEXT: 'gpu.module' Pipeline
-// PASSES-NEXT:   StackReclaim
-// PASSES-NEXT:   CFGConversion
 // PASSES-NEXT: 'omp.declare_reduction' Pipeline
 // PASSES-NEXT:   StackReclaim
 // PASSES-NEXT:   CFGConversion
@@ -122,7 +112,11 @@ func.func @_QQmain() {
 // PASSES-NEXT:  'func.func' Pipeline
 // PASSES-NEXT:    AbstractResultOpt
 // PASSES-NEXT:  'gpu.module' Pipeline
-// PASSES-NEXT:    AbstractResultOpt
+// PASSES-NEXT:   Pipeline Collection : ['func.func', 'gpu.func'] 
+// PASSES-NEXT:   'func.func' Pipeline 
+// PASSES-NEXT:   AbstractResultOpt
+// PASSES-NEXT:   'gpu.func' Pipeline 
+// PASSES-NEXT:   AbstractResultOpt
 // PASSES-NEXT:  'omp.declare_reduction' Pipeline
 // PASSES-NEXT:    AbstractResultOpt
 // PASSES-NEXT:  'omp.private' Pipeline

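The second patch below removes the duplication between the func.return and gpu.return rewrites by hoisting the pattern body into a helper templated on the return op. A condensed sketch of that shape, showing only the plain store-through-hidden-argument case (the name rewriteReturnLike is hypothetical; the real helper, processReturnLikeOp in the patch, also handles C_PTR results and loads from local storage):

#include "flang/Optimizer/Dialect/FIROps.h"
#include "mlir/IR/PatternMatch.h"

// Works for any single-operand return-like op: store the abstract result
// through the hidden argument, then rebuild the same kind of terminator
// (func.return stays func.return, gpu.return stays gpu.return).
template <typename OpTy>
static mlir::LogicalResult rewriteReturnLike(OpTy ret, mlir::Value newArg,
                                             mlir::PatternRewriter &rewriter) {
  rewriter.setInsertionPoint(ret);
  rewriter.create<fir::StoreOp>(ret.getLoc(), ret.getOperand(0), newArg);
  rewriter.replaceOpWithNewOp<OpTy>(ret); // operand-less return
  return mlir::success();
}
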
From 72a892909a7ce898423cea709f674956dc2121d1 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 6 Dec 2024 13:42:06 -0800
Subject: [PATCH 2/2] [flang][cuda] Handle gpu.return in AbstractResult pass

---
 .../Optimizer/Transforms/AbstractResult.cpp   | 123 +++++++++++-------
 flang/test/Fir/CUDA/cuda-abstract-result.mlir |  37 ++++++
 2 files changed, 111 insertions(+), 49 deletions(-)
 create mode 100644 flang/test/Fir/CUDA/cuda-abstract-result.mlir

diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 2ed66cc83eefb5..b0327cc10e9de6 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -234,6 +234,60 @@ class SaveResultOpConversion
   }
 };
 
+template <typename OpTy>
+static mlir::LogicalResult
+processReturnLikeOp(OpTy ret, mlir::Value newArg,
+                    mlir::PatternRewriter &rewriter) {
+  auto loc = ret.getLoc();
+  rewriter.setInsertionPoint(ret);
+  mlir::Value resultValue = ret.getOperand(0);
+  fir::LoadOp resultLoad;
+  mlir::Value resultStorage;
+  // Identify result local storage.
+  if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
+    resultLoad = load;
+    resultStorage = load.getMemref();
+    // The result alloca may be behind a fir.declare, if any.
+    if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
+      resultStorage = declare.getMemref();
+  }
+  // Replace old local storage with new storage argument, unless
+  // the derived type is C_PTR/C_FUN_PTR, in which case the return
+  // type is updated to return void* (no new argument is passed).
+  if (fir::isa_builtin_cptr_type(resultValue.getType())) {
+    auto module = ret->template getParentOfType<mlir::ModuleOp>();
+    FirOpBuilder builder(rewriter, module);
+    mlir::Value cptr = resultValue;
+    if (resultLoad) {
+      // Replace whole derived type load by component load.
+      cptr = resultLoad.getMemref();
+      rewriter.setInsertionPoint(resultLoad);
+    }
+    mlir::Value newResultValue =
+        fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
+    newResultValue = builder.createConvert(
+        loc, getVoidPtrType(ret.getContext()), newResultValue);
+    rewriter.setInsertionPoint(ret);
+    rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
+  } else if (resultStorage) {
+    resultStorage.replaceAllUsesWith(newArg);
+    rewriter.replaceOpWithNewOp<OpTy>(ret);
+  } else {
+    // The result storage may have been optimized out by a memory-to-register
+    // pass; this is possible for fir.box results or fir.record types with no
+    // length parameters. Simply store the result in the result storage at
+    // the return point.
+    rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
+    rewriter.replaceOpWithNewOp<OpTy>(ret);
+  }
+  // Delete the old result local storage if it is unused.
+  if (resultStorage)
+    if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
+      if (alloc->use_empty())
+        rewriter.eraseOp(alloc);
+  return mlir::success();
+}
+
 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -242,55 +296,23 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
   llvm::LogicalResult
   matchAndRewrite(mlir::func::ReturnOp ret,
                   mlir::PatternRewriter &rewriter) const override {
-    auto loc = ret.getLoc();
-    rewriter.setInsertionPoint(ret);
-    mlir::Value resultValue = ret.getOperand(0);
-    fir::LoadOp resultLoad;
-    mlir::Value resultStorage;
-    // Identify result local storage.
-    if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
-      resultLoad = load;
-      resultStorage = load.getMemref();
-      // The result alloca may be behind a fir.declare, if any.
-      if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
-        resultStorage = declare.getMemref();
-    }
-    // Replace old local storage with new storage argument, unless
-    // the derived type is C_PTR/C_FUN_PTR, in which case the return
-    // type is updated to return void* (no new argument is passed).
-    if (fir::isa_builtin_cptr_type(resultValue.getType())) {
-      auto module = ret->getParentOfType<mlir::ModuleOp>();
-      FirOpBuilder builder(rewriter, module);
-      mlir::Value cptr = resultValue;
-      if (resultLoad) {
-        // Replace whole derived type load by component load.
-        cptr = resultLoad.getMemref();
-        rewriter.setInsertionPoint(resultLoad);
-      }
-      mlir::Value newResultValue =
-          fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
-      newResultValue = builder.createConvert(
-          loc, getVoidPtrType(ret.getContext()), newResultValue);
-      rewriter.setInsertionPoint(ret);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
-          ret, mlir::ValueRange{newResultValue});
-    } else if (resultStorage) {
-      resultStorage.replaceAllUsesWith(newArg);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
-    } else {
-      // The result storage may have been optimized out by a memory to
-      // register pass, this is possible for fir.box results, or fir.record
-      // with no length parameters. Simply store the result in the result
-      // storage. at the return point.
-      rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
-    }
-    // Delete result old local storage if unused.
-    if (resultStorage)
-      if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
-        if (alloc->use_empty())
-          rewriter.eraseOp(alloc);
-    return mlir::success();
+    return processReturnLikeOp(ret, newArg, rewriter);
+  }
+
+private:
+  mlir::Value newArg;
+};
+
+class GPUReturnOpConversion
+    : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
+      : OpRewritePattern(context), newArg{newArg} {}
+  llvm::LogicalResult
+  matchAndRewrite(mlir::gpu::ReturnOp ret,
+                  mlir::PatternRewriter &rewriter) const override {
+    return processReturnLikeOp(ret, newArg, rewriter);
   }
 
 private:
@@ -373,6 +395,9 @@ class AbstractResultOpt
         patterns.insert<ReturnOpConversion>(context, newArg);
         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
             [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
+        patterns.insert<GPUReturnOpConversion>(context, newArg);
+        target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
+            [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
         assert(func.getFunctionType() ==
                getNewFunctionType(funcTy, shouldBoxResult));
       } else {
diff --git a/flang/test/Fir/CUDA/cuda-abstract-result.mlir b/flang/test/Fir/CUDA/cuda-abstract-result.mlir
new file mode 100644
index 00000000000000..8c59487ca5cd5c
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-abstract-result.mlir
@@ -0,0 +1,37 @@
+// RUN: fir-opt -pass-pipeline='builtin.module(gpu.module(gpu.func(abstract-result)))' %s | FileCheck %s
+
+gpu.module @test {
+ gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg1: !fir.ref<f32>) -> !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> {
+    %c1_i32 = arith.constant 1 : i32
+    %18 = fir.dummy_scope : !fir.dscope
+    %19 = fir.declare %arg0 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Ea"} : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.dscope) -> !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    %20 = fir.declare %arg1 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Eb"} : (!fir.ref<f32>, !fir.dscope) -> !fir.ref<f32>
+    %21 = fir.alloca !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> {bindc_name = "c", uniq_name = "_QMinterval_mFtest1Ec"}
+    %22 = fir.declare %21 {uniq_name = "_QMinterval_mFtest1Ec"} : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>) -> !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    %23 = fir.alloca i32 {bindc_name = "warpsize", uniq_name = "_QMcudadeviceECwarpsize"}
+    %24 = fir.declare %23 {uniq_name = "_QMcudadeviceECwarpsize"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %25 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %26 = fir.coordinate_of %19, %25 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    %27 = fir.load %20 : !fir.ref<f32>
+    %28 = arith.negf %27 fastmath<contract> : f32
+    %29 = fir.load %26 : !fir.ref<f32>
+    %30 = fir.call @__fadd_rd(%29, %28) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
+    %31 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %32 = fir.coordinate_of %22, %31 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    fir.store %30 to %32 : !fir.ref<f32>
+    %33 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %34 = fir.coordinate_of %19, %33 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    %35 = fir.load %20 : !fir.ref<f32>
+    %36 = arith.negf %35 fastmath<contract> : f32
+    %37 = fir.load %34 : !fir.ref<f32>
+    %38 = fir.call @__fadd_ru(%37, %36) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
+    %39 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %40 = fir.coordinate_of %22, %39 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    fir.store %38 to %40 : !fir.ref<f32>
+    %41 = fir.load %22 : !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    gpu.return %41 : !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+  }
+}
+
+// CHECK: gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg1: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg2: !fir.ref<f32>) {
+// CHECK: gpu.return{{$}}
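
In C++ terms, the rewrite these CHECK lines verify is roughly the following (illustrative analogy only, not generated code): the by-value derived-type result becomes a hidden leading reference argument and the function returns nothing.

struct Interval { float inf, sup; };

// Before: derived-type result returned by value.
static Interval test1Before(const Interval &a, float b) {
  return {a.inf - b, a.sup - b};
}

// After: hidden result reference prepended and the return type becomes
// void; the value return turns into a store through the hidden argument,
// which is why the rewritten gpu.func gains an argument and ends in a
// bare gpu.return.
static void test1After(Interval &result, const Interval &a, float b) {
  result = {a.inf - b, a.sup - b};
}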


