[flang-commits] [flang] [mlir] [Flang][OpenMP] Support for lowering Linear clause to mlir (PR #111085)

via flang-commits flang-commits at lists.llvm.org
Thu Oct 3 19:49:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: None (harishch4)

<details>
<summary>Changes</summary>

This adds support for lowering the LINEAR clause to MLIR on DO, SIMD, and DO SIMD constructs.

---
Full diff: https://github.com/llvm/llvm-project/pull/111085.diff


6 Files Affected:

- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+30) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-5) 
- (removed) flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90 (-14) 
- (added) flang/test/Lower/OpenMP/linear-clause.f90 (+59) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+1-1) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a4d2524bccf5c3..52203045e72a57 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -875,6 +875,36 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+/// Lower LINEAR clauses: one (variable address, step value) pair per item.
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<omp::clause::Linear>(
+      [&](const omp::clause::Linear &clause, const parser::CharBlock &) {
+        // Tuple element 2 is the optional ref/val/uval modifier. It is not
+        // lowered yet, so report it instead of silently dropping it.
+        if (std::get<2>(clause.t))
+          TODO(converter.getCurrentLocation(), "LINEAR clause modifier");
+        const auto &objects = std::get<omp::ObjectList>(clause.t);
+        if (objects.empty())
+          return;
+        // Use the simple or complex step if present (lowered only once);
+        // otherwise default to a step of 1.
+        const auto &simpleStep = std::get<0>(clause.t);
+        const auto &stepExpr = simpleStep ? simpleStep : std::get<1>(clause.t);
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Value step =
+            stepExpr ? fir::getBase(converter.genExprValue(*stepExpr, stmtCtx))
+                     : firOpBuilder.createIntegerConstant(
+                           converter.getCurrentLocation(),
+                           firOpBuilder.getI32Type(), 1);
+        for (const omp::Object &object : objects) {
+          semantics::Symbol *sym = object.sym();
+          result.linearVars.push_back(converter.getSymbolAddress(*sym));
+          result.linearStepVars.push_back(step);
+        }
+      });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 0c8e7bd47ab5a6..202d0f89add6a0 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -107,6 +107,7 @@ class ClauseProcessor {
       llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
       llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 60c83586e468b6..2beaca3049ee28 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1132,13 +1132,12 @@ static void genSimdClauses(lower::AbstractConverter &converter,
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processAligned(clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
+  cp.processLinear(clauseOps);
   cp.processNontemporal(clauseOps);
   cp.processOrder(clauseOps);
   cp.processReduction(loc, clauseOps);
   cp.processSafelen(clauseOps);
   cp.processSimdlen(clauseOps);
-
-  cp.processTODO<clause::Linear>(loc, llvm::omp::Directive::OMPD_simd);
 }
 
 static void genSingleClauses(lower::AbstractConverter &converter,
@@ -1300,14 +1299,14 @@ static void genWsloopClauses(
    llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processLinear(clauseOps);
  cp.processNowait(clauseOps);
  cp.processOrder(clauseOps);
  cp.processOrdered(clauseOps);
  cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
  cp.processSchedule(stmtCtx, clauseOps);

-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }

 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90 b/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90
deleted file mode 100644
index 4caf12a0169c42..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90
+++ /dev/null
@@ -1,14 +0,0 @@
-! This test checks lowering of OpenMP do simd linear() pragma
-
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-subroutine testDoSimdLinear(int_array)
-        integer :: int_array(*)
-!CHECK: not yet implemented: Unhandled clause LINEAR in SIMD construct
-!$omp do simd linear(int_array)
-        do index_ = 1, 10
-        end do
-!$omp end do simd
-
-end subroutine testDoSimdLinear
-
diff --git a/flang/test/Lower/OpenMP/linear-clause.f90 b/flang/test/Lower/OpenMP/linear-clause.f90
new file mode 100644
index 00000000000000..5d3f7f553716eb
--- /dev/null
+++ b/flang/test/Lower/OpenMP/linear-clause.f90
@@ -0,0 +1,72 @@
+! This test checks lowering of OpenMP linear() clause
+
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %s -o - 2>&1 | FileCheck %s
+
+! CHECK-LABEL: func.func @_QPtestdolinear() {
+subroutine testDoLinear()
+   implicit none
+   integer :: i
+   integer :: A(10)
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
+!CHECK: %[[A:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "a", uniq_name = "_QFtestdolinearEa"}
+!CHECK: %[[S2:.*]] = fir.shape %[[C10]] : (index) -> !fir.shape<1>
+!CHECK: %[[A_DECL:.*]]:2 = hlfir.declare %[[A]](%[[S2]]) {uniq_name = "_QFtestdolinearEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+!CHECK: %[[C2:.*]] = arith.constant 2 : i32
+!CHECK: omp.wsloop linear(%[[A_DECL]]#1 = %[[C2]] : !fir.ref<!fir.array<10xi32>>) {
+!$omp do linear(A:2)
+   do i = 1, 10
+      A(i) = i
+   end do
+!$omp end do
+end subroutine testDoLinear
+
+! CHECK-LABEL: func.func @_QPtestsimdlinear() {
+subroutine testSimdLinear()
+   implicit none
+   integer :: i
+   integer :: A(10)
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
+!CHECK: %[[A:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "a", uniq_name = "_QFtestsimdlinearEa"}
+!CHECK: %[[S2:.*]] = fir.shape %[[C10]] : (index) -> !fir.shape<1>
+!CHECK: %[[A_DECL:.*]]:2 = hlfir.declare %[[A]](%[[S2]]) {uniq_name = "_QFtestsimdlinearEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+!CHECK: %[[C2:.*]] = arith.constant 2 : i32
+!CHECK:   omp.simd linear(%[[A_DECL]]#1 = %[[C2]] : !fir.ref<!fir.array<10xi32>>) {
+!$omp simd linear(A:2)
+   do i = 1, 10
+      A(i) = i
+   end do
+!$omp end simd
+end subroutine testSimdLinear
+
+! CHECK-LABEL: func.func @_QPtestdosimdlinear() {
+subroutine testDoSimdLinear()
+   implicit none
+   integer :: i
+   integer :: A(10)
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
+!CHECK: %[[A:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "a", uniq_name = "_QFtestdosimdlinearEa"}
+!CHECK: %[[S2:.*]] = fir.shape %[[C10]] : (index) -> !fir.shape<1>
+!CHECK: %[[A_DECL:.*]]:2 = hlfir.declare %[[A]](%[[S2]]) {uniq_name = "_QFtestdosimdlinearEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+!CHECK: %[[C2:.*]] = arith.constant 2 : i32
+!CHECK: omp.wsloop {
+!CHECK:   omp.simd linear(%[[A_DECL]]#1 = %[[C2]] : !fir.ref<!fir.array<10xi32>>) {
+!$omp do simd linear(A:2)
+   do i = 1, 10
+      A(i) = i
+   end do
+!$omp end do simd
+end subroutine testDoSimdLinear
+
+! Without an explicit step, the step defaults to the i32 constant 1.
+! CHECK-LABEL: func.func @_QPtestdefaultstep() {
+subroutine testDefaultStep()
+   implicit none
+   integer :: i, j
+!CHECK: %[[J_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtestdefaultstepEj"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[C1:.*]] = arith.constant 1 : i32
+!CHECK: omp.wsloop linear(%[[J_DECL]]#1 = %[[C1]] : !fir.ref<i32>) {
+!$omp do linear(j)
+   do i = 1, 10
+   end do
+!$omp end do
+end subroutine testDefaultStep
+
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d516c8d9e0be6c..85aa03c7b54cd8 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2016,7 +2016,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
   // privateSyms, reductionVars, reductionByref, reductionSyms.
   SimdOp::build(builder, state, clauses.alignedVars,
                 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
-                /*linear_vars=*/{}, /*linear_step_vars=*/{},
+                clauses.linearVars, clauses.linearStepVars,
                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
                 /*private_vars=*/{}, /*private_syms=*/nullptr,
                 /*reduction_vars=*/{}, /*reduction_byref=*/nullptr,

``````````

</details>


https://github.com/llvm/llvm-project/pull/111085


More information about the flang-commits mailing list