[flang-commits] [flang] Allow do concurrent inside cuf kernel directive (PR #127693)

Zhen Wang via flang-commits flang-commits at lists.llvm.org
Tue Feb 18 20:46:17 PST 2025


https://github.com/wangzpgi updated https://github.com/llvm/llvm-project/pull/127693

From fdeb3fa2b2f15179a6745d261f3b8697881bdb10 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 18 Feb 2025 12:21:02 -0800
Subject: [PATCH 1/2] Allow do concurrent inside cuf kernel directive

---
 flang/lib/Lower/Bridge.cpp            | 165 +++++++++++++++++++-------
 flang/test/Lower/CUDA/cuda-doconc.cuf |  20 ++++
 2 files changed, 145 insertions(+), 40 deletions(-)
 create mode 100644 flang/test/Lower/CUDA/cuda-doconc.cuf

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 36e58e456dea3..61dd9f0797fc9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3074,50 +3074,135 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     llvm::SmallVector<mlir::Value> ivValues;
     Fortran::lower::pft::Evaluation *loopEval =
         &getEval().getFirstNestedEvaluation();
-    for (unsigned i = 0; i < nestedLoops; ++i) {
-      const Fortran::parser::LoopControl *loopControl;
-      mlir::Location crtLoc = loc;
-      if (i == 0) {
-        loopControl = &*outerDoConstruct->GetLoopControl();
-        crtLoc =
-            genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
-      } else {
-        auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
-        assert(doCons && "expect do construct");
-        loopControl = &*doCons->GetLoopControl();
-        crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+    bool isDoConcurrent = outerDoConstruct->IsDoConcurrent();
+    if (isDoConcurrent) {
+      // Handle DO CONCURRENT
+      locs.push_back(
+          genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct)));
+      const Fortran::parser::LoopControl *loopControl =
+          &*outerDoConstruct->GetLoopControl();
+      const auto &concurrent =
+          std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);
+
+      if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
+               .empty())
+        TODO(loc, "DO CONCURRENT with locality spec");
+
+      const auto &concurrentHeader =
+          std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
+      const auto &controls =
+          std::get<std::list<Fortran::parser::ConcurrentControl>>(
+              concurrentHeader.t);
+
+      for (const auto &control : controls) {
+        auto lb = fir::getBase(genExprValue(
+            *Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx));
+        auto ub = fir::getBase(genExprValue(
+            *Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx));
+        mlir::Value step;
+
+        if (const auto &expr =
+                std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
+                    control.t)) {
+          step = fir::getBase(
+              genExprValue(*Fortran::semantics::GetExpr(*expr), stmtCtx));
+        } else {
+          step = builder->create<mlir::arith::ConstantIndexOp>(
+              loc, 1); // Use index type directly
+        }
+
+        // Ensure lb, ub, and step are of index type using fir.convert
+        auto indexType = builder->getIndexType();
+        if (lb.getType() != indexType) {
+          lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
+        }
+        if (ub.getType() != indexType) {
+          ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
+        }
+        if (step.getType() != indexType) {
+          step = builder->create<fir::ConvertOp>(loc, indexType, step);
+        }
+
+        lbs.push_back(lb);
+        ubs.push_back(ub);
+        steps.push_back(step);
+
+        const auto &name = std::get<Fortran::parser::Name>(control.t);
+
+        // Handle induction variable
+        mlir::Value ivValue = getSymbolAddress(*name.symbol);
+        std::size_t ivTypeSize = name.symbol->size();
+        if (ivTypeSize == 0)
+          llvm::report_fatal_error("unexpected induction variable size");
+        mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8);
+
+        if (!ivValue) {
+          // DO CONCURRENT induction variables are not mapped yet since they are
+          // local to the DO CONCURRENT scope.
+          mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+          builder->setInsertionPointToStart(builder->getAllocaBlock());
+          ivValue = builder->createTemporaryAlloc(
+              loc, ivTy, toStringRef(name.symbol->name()));
+          builder->restoreInsertionPoint(insPt);
+        }
+
+        // Create the hlfir.declare operation using the symbol's name
+        auto declareOp = builder->create<hlfir::DeclareOp>(
+            loc, ivValue, toStringRef(name.symbol->name()));
+        ivValue = declareOp.getResult(0);
+
+        // Bind the symbol to the declared variable
+        bindSymbol(*name.symbol, ivValue);
+        ivValues.push_back(ivValue);
+        ivTypes.push_back(ivTy);
+        ivLocs.push_back(loc);
       }
+    } else {
+      for (unsigned i = 0; i < nestedLoops; ++i) {
+        const Fortran::parser::LoopControl *loopControl;
+        mlir::Location crtLoc = loc;
+        if (i == 0) {
+          loopControl = &*outerDoConstruct->GetLoopControl();
+          crtLoc = genLocation(
+              Fortran::parser::FindSourceLocation(outerDoConstruct));
+        } else {
+          auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
+          assert(doCons && "expect do construct");
+          loopControl = &*doCons->GetLoopControl();
+          crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+        }
+
+        locs.push_back(crtLoc);
 
-      locs.push_back(crtLoc);
-
-      const Fortran::parser::LoopControl::Bounds *bounds =
-          std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
-      assert(bounds && "Expected bounds on the loop construct");
-
-      Fortran::semantics::Symbol &ivSym =
-          bounds->name.thing.symbol->GetUltimate();
-      ivValues.push_back(getSymbolAddress(ivSym));
-
-      lbs.push_back(builder->createConvert(
-          crtLoc, idxTy,
-          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
-                                    stmtCtx))));
-      ubs.push_back(builder->createConvert(
-          crtLoc, idxTy,
-          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
-                                    stmtCtx))));
-      if (bounds->step)
-        steps.push_back(builder->createConvert(
+        const Fortran::parser::LoopControl::Bounds *bounds =
+            std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+        assert(bounds && "Expected bounds on the loop construct");
+
+        Fortran::semantics::Symbol &ivSym =
+            bounds->name.thing.symbol->GetUltimate();
+        ivValues.push_back(getSymbolAddress(ivSym));
+
+        lbs.push_back(builder->createConvert(
             crtLoc, idxTy,
             fir::getBase(genExprValue(
-                *Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
-      else // If `step` is not present, assume it is `1`.
-        steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
-
-      ivTypes.push_back(idxTy);
-      ivLocs.push_back(crtLoc);
-      if (i < nestedLoops - 1)
-        loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+                *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
+        ubs.push_back(builder->createConvert(
+            crtLoc, idxTy,
+            fir::getBase(genExprValue(
+                *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
+        if (bounds->step)
+          steps.push_back(builder->createConvert(
+              crtLoc, idxTy,
+              fir::getBase(genExprValue(
+                  *Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
+        else // If `step` is not present, assume it is `1`.
+          steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
+
+        ivTypes.push_back(idxTy);
+        ivLocs.push_back(crtLoc);
+        if (i < nestedLoops - 1)
+          loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+      }
     }
 
     auto op = builder->create<cuf::KernelOp>(
diff --git a/flang/test/Lower/CUDA/cuda-doconc.cuf b/flang/test/Lower/CUDA/cuda-doconc.cuf
new file mode 100644
index 0000000000000..e11688f4fe960
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-doconc.cuf
@@ -0,0 +1,20 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Check if do concurrent works inside cuf kernel directive
+
+program main
+  integer :: i, n
+  integer, managed :: a(3)
+  a(:) = -1
+  n = 3
+  n = n - 1
+  !$cuf kernel do
+  do concurrent(i=1:n)
+    a(i) = 1
+  end do
+end
+
+! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "main"} {
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.kernel<<<*, *>>>
+! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>

From 17f277e0238310c685216a8c79f4dd2be8958125 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 18 Feb 2025 20:44:43 -0800
Subject: [PATCH 2/2] addressing comments and add test

---
 flang/lib/Lower/Bridge.cpp            | 20 ++++++--------------
 flang/test/Lower/CUDA/cuda-doconc.cuf | 25 ++++++++++++++++++++++---
 2 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 61dd9f0797fc9..f8df44c9ec66f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3074,8 +3074,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     llvm::SmallVector<mlir::Value> ivValues;
     Fortran::lower::pft::Evaluation *loopEval =
         &getEval().getFirstNestedEvaluation();
-    bool isDoConcurrent = outerDoConstruct->IsDoConcurrent();
-    if (isDoConcurrent) {
+    if (outerDoConstruct->IsDoConcurrent()) {
       // Handle DO CONCURRENT
       locs.push_back(
           genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct)));
@@ -3103,25 +3102,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
         if (const auto &expr =
                 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
-                    control.t)) {
+                    control.t))
           step = fir::getBase(
               genExprValue(*Fortran::semantics::GetExpr(*expr), stmtCtx));
-        } else {
+        else
           step = builder->create<mlir::arith::ConstantIndexOp>(
               loc, 1); // Use index type directly
-        }
 
         // Ensure lb, ub, and step are of index type using fir.convert
         auto indexType = builder->getIndexType();
-        if (lb.getType() != indexType) {
-          lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
-        }
-        if (ub.getType() != indexType) {
-          ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
-        }
-        if (step.getType() != indexType) {
-          step = builder->create<fir::ConvertOp>(loc, indexType, step);
-        }
+        lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
+        ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
+        step = builder->create<fir::ConvertOp>(loc, indexType, step);
 
         lbs.push_back(lb);
         ubs.push_back(ub);
diff --git a/flang/test/Lower/CUDA/cuda-doconc.cuf b/flang/test/Lower/CUDA/cuda-doconc.cuf
index e11688f4fe960..22db1fa36fe47 100644
--- a/flang/test/Lower/CUDA/cuda-doconc.cuf
+++ b/flang/test/Lower/CUDA/cuda-doconc.cuf
@@ -2,7 +2,7 @@
 
 ! Check if do concurrent works inside cuf kernel directive
 
-program main
+subroutine doconc1
   integer :: i, n
   integer, managed :: a(3)
   a(:) = -1
@@ -14,7 +14,26 @@ program main
   end do
 end
 
-! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "main"} {
-! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: func.func @_QPdoconc1() {
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: cuf.kernel<<<*, *>>>
 ! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>
+
+subroutine doconc2
+  integer :: i, j, m, n
+  integer, managed :: a(2, 4)
+  m = 2
+  n = 4
+  a(:,:) = -1
+  !$cuf kernel do
+  do concurrent(i=1:m,j=1:n)
+    a(i,j) = i+j
+  end do
+end
+
+! CHECK: func.func @_QPdoconc2() {
+! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.kernel<<<*, *>>>
+! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref<i32>
+! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref<i32> 



More information about the flang-commits mailing list