[Mlir-commits] [mlir] 2c915e3 - [mlir][OpenMP] Add if clause to OpenMP simd construct

Dominik Adamski llvmlistbot at llvm.org
Wed Jul 6 05:37:14 PDT 2022


Author: Dominik Adamski
Date: 2022-07-06T07:24:48-05:00
New Revision: 2c915e3b2627a4e03341e14b354915c58741d7ec

URL: https://github.com/llvm/llvm-project/commit/2c915e3b2627a4e03341e14b354915c58741d7ec
DIFF: https://github.com/llvm/llvm-project/commit/2c915e3b2627a4e03341e14b354915c58741d7ec.diff

LOG: [mlir][OpenMP] Add if clause to OpenMP simd construct

This patch adds the if clause to the OpenMP TableGen definition of the simd construct (omp.simdloop).

Reviewed By: peixin

Differential Revision: https://reviews.llvm.org/D128940

Signed-off-by: Dominik Adamski <dominik.adamski at amd.com>

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp
    flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
    flang/test/Lower/OpenMP/simd.f90
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 0949cb7d3aea2..9ac78b537abfa 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -507,6 +507,19 @@ static omp::ClauseProcBindKindAttr genProcBindKindAttr(
   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
 }
 
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::StatementContext &stmtCtx,
+                   const Fortran::parser::OmpClause::If *ifClause) {
+  // Lower the if-clause's scalar logical expression and convert it to i1.
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const mlir::Location loc = converter.getCurrentLocation();
+  auto &cond = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
+  mlir::Value condVal = fir::getBase(
+      converter.genExprValue(*Fortran::semantics::GetExpr(cond), stmtCtx));
+  return builder.createConvert(loc, builder.getI1Type(), condVal);
+}
+
 /* When parallel is used in a combined construct, then use this function to
  * create the parallel operation. It handles the parallel specific clauses
  * and leaves the rest for handling at the inner operations.
@@ -532,11 +545,7 @@ createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
     if (const auto &ifClause =
             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
-      auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-      mlir::Value ifVal = fir::getBase(
-          converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
-      ifClauseOperand = firOpBuilder.createConvert(
-          currentLocation, firOpBuilder.getI1Type(), ifVal);
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
     } else if (const auto &numThreadsClause =
                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
                        &clause.u)) {
@@ -585,11 +594,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   for (const auto &clause : opClauseList.v) {
     if (const auto &ifClause =
             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
-      auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-      mlir::Value ifVal = fir::getBase(
-          converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
-      ifClauseOperand = firOpBuilder.createConvert(
-          currentLocation, firOpBuilder.getI1Type(), ifVal);
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
     } else if (const auto &numThreadsClause =
                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
                        &clause.u)) {
@@ -760,9 +765,10 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   mlir::Location currentLocation = converter.getCurrentLocation();
   llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
       linearStepVars, reductionVars;
-  mlir::Value scheduleChunkClauseOperand;
+  mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
   mlir::Attribute scheduleClauseOperand, noWaitClauseOperand,
       orderedClauseOperand, orderClauseOperand;
+  Fortran::lower::StatementContext stmtCtx;
   const auto &loopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
       std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t);
 
@@ -823,11 +829,13 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
               std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
                   scheduleClause->v.t)) {
         if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) {
-          Fortran::lower::StatementContext stmtCtx;
           scheduleChunkClauseOperand =
               fir::getBase(converter.genExprValue(*expr, stmtCtx));
         }
       }
+    } else if (const auto &ifClause =
+                   std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
     }
   }
 
@@ -848,7 +856,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   if (llvm::omp::OMPD_simd == ompDirective) {
     TypeRange resultType;
     auto SimdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
-        currentLocation, resultType, lowerBound, upperBound, step);
+        currentLocation, resultType, lowerBound, upperBound, step,
+        ifClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr());
     createBodyOfOp<omp::SimdLoopOp>(SimdLoopOp, converter, currentLocation,
                                     eval, &loopOpClauseList, iv);
     return;

diff  --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 32e7bfceab959..1e2d07cea0bc1 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -181,7 +181,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
   omp.parallel  {
     %1 = fir.alloca i32 {adapt.valuebyref, pinned}
     %2 = fir.load %arg0 : !fir.ref<i32>
-    omp.simdloop (%arg2) : i32 = (%c1_i32) to (%2) step (%c1_i32)  {
+    omp.simdloop for (%arg2) : i32 = (%c1_i32) to (%2) step (%c1_i32)  {
       fir.store %arg2 to %1 : !fir.ref<i32>
       %3 = fir.load %1 : !fir.ref<i32>
       %4 = fir.convert %3 : (i32) -> i64

diff  --git a/flang/test/Lower/OpenMP/simd.f90 b/flang/test/Lower/OpenMP/simd.f90
index 1d08fb7069c90..df4489f9f20cf 100644
--- a/flang/test/Lower/OpenMP/simd.f90
+++ b/flang/test/Lower/OpenMP/simd.f90
@@ -9,7 +9,7 @@ subroutine simdloop
   ! CHECK: %[[LB:.*]] = arith.constant 1 : i32
   ! CHECK-NEXT: %[[UB:.*]] = arith.constant 9 : i32
   ! CHECK-NEXT: %[[STEP:.*]] = arith.constant 1 : i32
-  ! CHECK-NEXT: omp.simdloop (%[[I:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) { 
+  ! CHECK-NEXT: omp.simdloop for (%[[I:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
   do i=1, 9
     ! CHECK: fir.store %[[I]] to %[[LOCAL:.*]] : !fir.ref<i32>
     ! CHECK: %[[LD:.*]] = fir.load %[[LOCAL]] : !fir.ref<i32>
@@ -18,3 +18,21 @@ subroutine simdloop
   end do
   !$OMP END SIMD 
 end subroutine
+
+!CHECK-LABEL: func @_QPsimdloop_with_if_clause
+subroutine simdloop_with_if_clause(n, threshold)
+integer :: i, n, threshold  ! IF condition lowers to arith.cmpi sge -> if(...)
+  !$OMP SIMD IF( n .GE. threshold )
+  ! CHECK: %[[LB:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[UB:.*]] = fir.load %arg0
+  ! CHECK: %[[STEP:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[COND:.*]] = arith.cmpi sge
+  ! CHECK: omp.simdloop if(%[[COND]]) for (%[[I:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) inclusive  step (%[[STEP]]) {
+  do i = 1, n
+    ! CHECK: fir.store %[[I]] to %[[LOCAL:.*]] : !fir.ref<i32>
+    ! CHECK: %[[LD:.*]] = fir.load %[[LOCAL]] : !fir.ref<i32>
+    ! CHECK: fir.call @_FortranAioOutputInteger32({{.*}}, %[[LD]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  !$OMP END SIMD
+end subroutine

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 761e964644b0d..0c85e6bd09cc4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -420,13 +420,18 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
     transformed into a SIMD loop (that is, multiple iterations of the loop can 
     be executed concurrently using SIMD instructions).. The lower and upper 
     bounds specify a half-open range: the range includes the lower bound but 
-    does not include the upper bound.
+    does not include the upper bound. If the `inclusive` attribute is specified,
+    then the upper bound is also included.
 
     The body region can contain any number of blocks. The region is terminated
     by "omp.yield" instruction without operands.
+
+    When an `if` clause is present and evaluates to false, the preferred number
+    of iterations to be executed concurrently is one, regardless of whether
+    a `simdlen` clause is specified.
     ```
-    omp.simdloop (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) 
-                                      step (%c1, %c1) {
+    omp.simdloop <clauses>
+    for (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
       // block operations
       omp.yield
     }
@@ -436,9 +441,17 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
   // TODO: Add other clauses
   let arguments = (ins Variadic<IntLikeType>:$lowerBound,
              Variadic<IntLikeType>:$upperBound,
-             Variadic<IntLikeType>:$step);
+             Variadic<IntLikeType>:$step,
+             Optional<I1>:$if_expr,
+             UnitAttr:$inclusive
+     );
  
   let regions = (region AnyRegion:$region);
+  let assemblyFormat = [{
+    oilist(`if` `(` $if_expr `)`
+    ) `for` custom<LoopControl>($region, $lowerBound, $upperBound, $step,
+                                  type($step), $inclusive) attr-dict
+  }];
 
   let extraClassDeclaration = [{
     /// Returns the number of loops in the simd loop nest.

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d09ef969c2f68..96ff6b1c1414e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -570,62 +570,6 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
-//===----------------------------------------------------------------------===//
-// SimdLoopOp
-//===----------------------------------------------------------------------===//
-/// Parses an OpenMP Simd construct [2.9.3.1]
-///
-/// simdloop ::= `omp.simdloop` loop-control clause-list
-/// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
-/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
-/// steps := `step` `(`ssa-id-list`)`
-/// clause-list ::= clause clause-list | empty
-/// clause ::= TODO
-ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
-  // Parse an opening `(` followed by induction variables followed by `)`
-  SmallVector<OpAsmParser::Argument> ivs;
-  Type loopVarType;
-  SmallVector<OpAsmParser::UnresolvedOperand> lower, upper, steps;
-  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
-      parser.parseColonType(loopVarType) ||
-      // Parse loop bounds.
-      parser.parseEqual() ||
-      parser.parseOperandList(lower, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(lower, loopVarType, result.operands) ||
-      parser.parseKeyword("to") ||
-      parser.parseOperandList(upper, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(upper, loopVarType, result.operands) ||
-      // Parse step values.
-      parser.parseKeyword("step") ||
-      parser.parseOperandList(steps, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(steps, loopVarType, result.operands))
-    return failure();
-
-  int numIVs = static_cast<int>(ivs.size());
-  SmallVector<int> segments{numIVs, numIVs, numIVs};
-  // TODO: Add parseClauses() when we support clauses
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getI32VectorAttr(segments));
-
-  // Now parse the body.
-  Region *body = result.addRegion();
-  for (auto &iv : ivs)
-    iv.type = loopVarType;
-  return parser.parseRegion(*body, ivs);
-}
-
-void SimdLoopOp::print(OpAsmPrinter &p) {
-  auto args = getRegion().front().getArguments();
-  p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
-    << ") to (" << upperBound() << ") ";
-  p << "step (" << step() << ") ";
-
-  p.printRegion(region(), /*printEntryBlockArgs=*/false);
-}
-
 //===----------------------------------------------------------------------===//
 // Verifier for Simd construct [2.9.3.1]
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 87409ee568483..1c7c11fccbf8e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -912,6 +912,11 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
   LogicalResult bodyGenStatus = success();
+
+  // TODO: The code generation for if clause is not supported yet.
+  if (loop.if_expr())
+    return failure();
+
   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
     // Make sure further conversions know about the induction variable.
     moduleTranslation.mapValue(

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 27ba2c88a7893..379b8f4bfe1cc 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -197,7 +197,7 @@ func.func @omp_simdloop(%lb : index, %ub : index, %step : i32) -> () {
   "omp.simdloop" (%lb, %ub, %step) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1]> : vector<3xi32>} :
+  }) {operand_segment_sizes = dense<[1,1,1,0]> : vector<4xi32>} :
     (index, index, i32) -> () 
 
   return

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 2682c977cca90..393378f70b7e9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -329,21 +329,29 @@ func.func @omp_wsloop_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32, %lb2
 
 // CHECK-LABEL: omp_simdloop
 func.func @omp_simdloop(%lb : index, %ub : index, %step : index) -> () {
-  // CHECK: omp.simdloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+  // CHECK: omp.simdloop for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
   "omp.simdloop" (%lb, %ub, %step) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1]> : vector<3xi32>} :
+  }) {operand_segment_sizes = dense<[1,1,1,0]> : vector<4xi32>} :
     (index, index, index) -> () 
 
   return
 }
 
-
 // CHECK-LABEL: omp_simdloop_pretty
 func.func @omp_simdloop_pretty(%lb : index, %ub : index, %step : index) -> () {
-  // CHECK: omp.simdloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
-  omp.simdloop (%iv) : index = (%lb) to (%ub) step (%step) {
+  // CHECK: omp.simdloop for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+  omp.simdloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    omp.yield
+  }
+  return
+}
+
+// CHECK-LABEL: omp_simdloop_pretty_if
+func.func @omp_simdloop_pretty_if(%lb : index, %ub : index, %step : index, %if_cond : i1) -> () {
+  // CHECK: omp.simdloop if(%{{.*}}) for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+  omp.simdloop if(%if_cond) for (%iv): index = (%lb) to (%ub) step (%step) {
     omp.yield
   }
   return
@@ -351,8 +359,8 @@ func.func @omp_simdloop_pretty(%lb : index, %ub : index, %step : index) -> () {
 
 // CHECK-LABEL: omp_simdloop_pretty_multiple
 func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> () {
-  // CHECK: omp.simdloop (%{{.*}}, %{{.*}}) : index = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
-  omp.simdloop (%iv1, %iv2) : index = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+  // CHECK: omp.simdloop for (%{{.*}}, %{{.*}}) : index = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
+  omp.simdloop for (%iv1, %iv2) : index = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
     omp.yield
   }
   return

diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index bdb2fad4e257e..40ac2e3e28ebe 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -697,7 +697,7 @@ llvm.func @simdloop_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr<f
       %4 = llvm.getelementptr %arg0[%iv] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
       llvm.store %3, %4 : !llvm.ptr<f32>
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1]> : vector<3xi32>} :
+  }) {operand_segment_sizes = dense<[1,1,1,0]> : vector<4xi32>} :
     (i64, i64, i64) -> () 
 
   llvm.return
@@ -709,7 +709,7 @@ llvm.func @simdloop_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr<f
 
 // CHECK-LABEL: @simdloop_simple_multiple
 llvm.func @simdloop_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>) {
-  omp.simdloop (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+  omp.simdloop for (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
     %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
     // The form of the emitted IR is controlled by OpenMPIRBuilder and
     // tested there. Just check that the right metadata is added.


        


More information about the Mlir-commits mailing list