[flang-commits] [flang] [mlir] [MLIR][OpenMP] Create `LoopRelatedClause` (PR #99506)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Mon Jul 29 03:01:56 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/99506

>From 9edb733b1838cd77f3ebac133b27d18a2ded8427 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 16 Jul 2024 15:48:00 +0100
Subject: [PATCH] [MLIR][OpenMP] Create `LoopRelatedClause`

This patch introduces a new OpenMP clause definition that does not correspond to
any clause in the spec.

Its main purpose is to define the `loop_inclusive` argument of `omp.loop_nest`
(previously "inclusive", renamed according to the parent of this PR in the
stack) as part of a clause definition. This is needed so that a follow-up
tablegen backend, which will automatically generate clause and operation operand
structures directly from `OpenMP_Op` and `OpenMP_Clause` definitions, can
properly generate the `LoopNestOperands` structure.

`collapse` clause arguments are also moved into this new definition, as they
represent information on the loop nests being collapsed rather than the
`collapse` clause itself.
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    | 27 ++++----
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |  2 +-
 .../lib/Lower/OpenMP/DataSharingProcessor.cpp |  4 +-
 .../Dialect/OpenMP/OpenMPClauseOperands.h     |  8 +--
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 63 ++++++++++---------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 10 ++-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 18 +++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 19 +++---
 mlir/test/Dialect/OpenMP/ops.mlir             |  2 +-
 9 files changed, 72 insertions(+), 81 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index facf95e17707e..310c0b0b5fb63 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -181,20 +181,19 @@ static void addUseDeviceClause(
 
 static void convertLoopBounds(lower::AbstractConverter &converter,
                               mlir::Location loc,
-                              mlir::omp::CollapseClauseOps &result,
+                              mlir::omp::LoopRelatedOps &result,
                               std::size_t loopVarTypeSize) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   // The types of lower bound, upper bound, and step are converted into the
   // type of the loop variable if necessary.
   mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
-  for (unsigned it = 0; it < (unsigned)result.collapseLowerBounds.size();
-       it++) {
-    result.collapseLowerBounds[it] = firOpBuilder.createConvert(
-        loc, loopVarType, result.collapseLowerBounds[it]);
-    result.collapseUpperBounds[it] = firOpBuilder.createConvert(
-        loc, loopVarType, result.collapseUpperBounds[it]);
-    result.collapseSteps[it] =
-        firOpBuilder.createConvert(loc, loopVarType, result.collapseSteps[it]);
+  for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) {
+    result.loopLowerBounds[it] = firOpBuilder.createConvert(
+        loc, loopVarType, result.loopLowerBounds[it]);
+    result.loopUpperBounds[it] = firOpBuilder.createConvert(
+        loc, loopVarType, result.loopUpperBounds[it]);
+    result.loopSteps[it] =
+        firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]);
   }
 }
 
@@ -204,7 +203,7 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
 
 bool ClauseProcessor::processCollapse(
     mlir::Location currentLocation, lower::pft::Evaluation &eval,
-    mlir::omp::CollapseClauseOps &result,
+    mlir::omp::LoopRelatedOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
   bool found = false;
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -233,15 +232,15 @@ bool ClauseProcessor::processCollapse(
         std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
     assert(bounds && "Expected bounds for worksharing do loop");
     lower::StatementContext stmtCtx;
-    result.collapseLowerBounds.push_back(fir::getBase(
+    result.loopLowerBounds.push_back(fir::getBase(
         converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
-    result.collapseUpperBounds.push_back(fir::getBase(
+    result.loopUpperBounds.push_back(fir::getBase(
         converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
     if (bounds->step) {
-      result.collapseSteps.push_back(fir::getBase(
+      result.loopSteps.push_back(fir::getBase(
           converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
     } else { // If `step` is not present, assume it as `1`.
-      result.collapseSteps.push_back(firOpBuilder.createIntegerConstant(
+      result.loopSteps.push_back(firOpBuilder.createIntegerConstant(
           currentLocation, firOpBuilder.getIntegerType(32), 1));
     }
     iv.push_back(bounds->name.thing.symbol);
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 4340c3278cebc..2c4b3857fda9f 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -55,7 +55,7 @@ class ClauseProcessor {
   // 'Unique' clauses: They can appear at most once in the clause list.
   bool
   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
-                  mlir::omp::CollapseClauseOps &result,
+                  mlir::omp::LoopRelatedOps &result,
                   llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
   bool processDefault() const;
   bool processDevice(lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 75c4f14f38a32..78f83ede570a6 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -274,8 +274,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       llvm::SmallVector<mlir::Value> vs;
       vs.reserve(loopOp.getIVs().size());
       for (auto [iv, ub, step] :
-           llvm::zip_equal(loopOp.getIVs(), loopOp.getCollapseUpperBounds(),
-                           loopOp.getCollapseSteps())) {
+           llvm::zip_equal(loopOp.getIVs(), loopOp.getLoopUpperBounds(),
+                           loopOp.getLoopSteps())) {
         // v = iv + step
         // cmp = step < 0 ? v < ub : v > ub
         mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 4730d544e8739..a680a85db0f25 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -43,11 +43,6 @@ struct CancelDirectiveNameClauseOps {
   ClauseCancellationConstructTypeAttr cancelDirective;
 };
 
-struct CollapseClauseOps {
-  llvm::SmallVector<Value> collapseLowerBounds, collapseUpperBounds,
-      collapseSteps;
-};
-
 struct CopyprivateClauseOps {
   llvm::SmallVector<Value> copyprivateVars;
   llvm::SmallVector<Attribute> copyprivateSyms;
@@ -125,6 +120,7 @@ struct LinearClauseOps {
 };
 
 struct LoopRelatedOps {
+  llvm::SmallVector<Value> loopLowerBounds, loopUpperBounds, loopSteps;
   UnitAttr loopInclusive;
 };
 
@@ -261,7 +257,7 @@ using DistributeOperands =
     detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
                     PrivateClauseOps>;
 
-using LoopNestOperands = detail::Clauses<CollapseClauseOps, LoopRelatedOps>;
+using LoopNestOperands = detail::Clauses<LoopRelatedOps>;
 
 using MaskedOperands = detail::Clauses<FilterClauseOps>;
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b6be8be63b328..e703c323edbc8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -107,37 +107,6 @@ class OpenMP_CancelDirectiveNameClauseSkip<
 
 def OpenMP_CancelDirectiveNameClause : OpenMP_CancelDirectiveNameClauseSkip<>;
 
-//===----------------------------------------------------------------------===//
-// V5.2: [4.4.3] `collapse` clause
-//===----------------------------------------------------------------------===//
-
-class OpenMP_CollapseClauseSkip<
-    bit traits = false, bit arguments = false, bit assemblyFormat = false,
-    bit description = false, bit extraClassDeclaration = false
-  > : OpenMP_Clause</*isRequired=*/false, traits, arguments, assemblyFormat,
-                    description, extraClassDeclaration> {
-  let traits = [
-    AllTypesMatch<
-      ["collapse_lower_bounds", "collapse_upper_bounds", "collapse_steps"]>
-  ];
-
-  let arguments = (ins
-    Variadic<IntLikeType>:$collapse_lower_bounds,
-    Variadic<IntLikeType>:$collapse_upper_bounds,
-    Variadic<IntLikeType>:$collapse_steps
-  );
-
-  let extraClassDeclaration = [{
-    /// Returns the number of loops in the loop nest.
-    unsigned getNumLoops() { return getCollapseLowerBounds().size(); }
-  }];
-
-  // Description and formatting integrated in the `omp.loop_nest` operation,
-  // which is the only one currently accepting this clause.
-}
-
-def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>;
-
 //===----------------------------------------------------------------------===//
 // V5.2: [5.7.2] `copyprivate` clause
 //===----------------------------------------------------------------------===//
@@ -564,6 +533,38 @@ class OpenMP_LinearClauseSkip<
 
 def OpenMP_LinearClause : OpenMP_LinearClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold loop related information.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_LoopRelatedClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause</*isRequired=*/false, traits, arguments, assemblyFormat,
+                    description, extraClassDeclaration> {
+  let traits = [
+    AllTypesMatch<
+      ["loop_lower_bounds", "loop_upper_bounds", "loop_steps"]>
+  ];
+  
+  let arguments = (ins
+    Variadic<IntLikeType>:$loop_lower_bounds,
+    Variadic<IntLikeType>:$loop_upper_bounds,
+    Variadic<IntLikeType>:$loop_steps,
+    UnitAttr:$loop_inclusive
+  );
+
+  let extraClassDeclaration = [{
+    /// Returns the number of loops in the loop nest.
+    unsigned getNumLoops() { return getLoopLowerBounds().size(); }
+  }];
+
+  // Description and formatting integrated in the `omp.loop_nest` operation,
+  // which is the only one currently accepting this clause.
+}
+
+def OpenMP_LoopRelatedClause : OpenMP_LoopRelatedClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [5.8.3] `map` clause
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a42e32587e570..00dadc6e6538d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -297,7 +297,7 @@ def SingleOp : OpenMP_Op<"single", traits = [
 def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
     RecursiveMemoryEffects, SameVariadicOperandSize
   ], clauses = [
-    OpenMP_CollapseClause
+    OpenMP_LoopRelatedClause
   ], singleRegion = true> {
   let summary = "rectangular loop nest";
   let description = [{
@@ -306,14 +306,14 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
     lower and upper bounds, as well as a step variable, must be defined.
 
     The lower and upper bounds specify a half-open range: the range includes the
-    lower bound but does not include the upper bound. If the `inclusive`
+    lower bound but does not include the upper bound. If the `loop_inclusive`
     attribute is specified then the upper bound is also included.
 
     The body region can contain any number of blocks. The region is terminated
     by an `omp.yield` instruction without operands. The induction variables,
     represented as entry block arguments to the loop nest operation's single
-    region, match the types of the `collapse_lower_bounds`,
-    `collapse_upper_bounds` and `collapse_steps` arguments.
+    region, match the types of the `loop_lower_bounds`, `loop_upper_bounds` and
+    `loop_steps` arguments.
 
     ```mlir
     omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
@@ -335,8 +335,6 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
     non-perfectly nested loops.
   }];
 
-  let arguments = !con(clausesArgs, (ins UnitAttr:$inclusive));
-
   let builders = [
     OpBuilder<(ins CArg<"const LoopNestOperands &">:$clauses)>
   ];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 18a1c2a1bd587..84de63b7806dc 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2047,7 +2047,7 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Parse "inclusive" flag.
   if (succeeded(parser.parseOptionalKeyword("inclusive")))
-    result.addAttribute("inclusive",
+    result.addAttribute("loop_inclusive",
                         UnitAttr::get(parser.getBuilder().getContext()));
 
   // Parse step values.
@@ -2075,28 +2075,28 @@ void LoopNestOp::print(OpAsmPrinter &p) {
   Region &region = getRegion();
   auto args = region.getArguments();
   p << " (" << args << ") : " << args[0].getType() << " = ("
-    << getCollapseLowerBounds() << ") to (" << getCollapseUpperBounds() << ") ";
-  if (getInclusive())
+    << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
+  if (getLoopInclusive())
     p << "inclusive ";
-  p << "step (" << getCollapseSteps() << ") ";
+  p << "step (" << getLoopSteps() << ") ";
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
                        const LoopNestOperands &clauses) {
-  LoopNestOp::build(builder, state, clauses.collapseLowerBounds,
-                    clauses.collapseUpperBounds, clauses.collapseSteps,
+  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
+                    clauses.loopUpperBounds, clauses.loopSteps,
                     clauses.loopInclusive);
 }
 
 LogicalResult LoopNestOp::verify() {
-  if (getCollapseLowerBounds().empty())
+  if (getLoopLowerBounds().empty())
     return emitOpError() << "must represent at least one loop";
 
-  if (getCollapseLowerBounds().size() != getIVs().size())
+  if (getLoopLowerBounds().size() != getIVs().size())
     return emitOpError() << "number of range arguments and IVs do not match";
 
-  for (auto [lb, iv] : llvm::zip_equal(getCollapseLowerBounds(), getIVs())) {
+  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
     if (lb.getType() != iv.getType())
       return emitOpError()
              << "range argument type does not match corresponding IV type";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b7d1792852b08..fe3f258156690 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1109,8 +1109,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
       wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
 
   // Find the loop configuration.
-  llvm::Value *step =
-      moduleTranslation.lookupValue(loopOp.getCollapseSteps()[0]);
+  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
   llvm::Type *ivType = step->getType();
   llvm::Value *chunk = nullptr;
   if (wsloopOp.getScheduleChunk()) {
@@ -1179,11 +1178,10 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
     llvm::Value *lowerBound =
-        moduleTranslation.lookupValue(loopOp.getCollapseLowerBounds()[i]);
+        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
     llvm::Value *upperBound =
-        moduleTranslation.lookupValue(loopOp.getCollapseUpperBounds()[i]);
-    llvm::Value *step =
-        moduleTranslation.lookupValue(loopOp.getCollapseSteps()[i]);
+        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
+    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
 
     // Make sure loop trip count are emitted in the preheader of the outermost
     // loop at the latest so that they are all available for the new collapsed
@@ -1196,7 +1194,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     }
     loopInfos.push_back(ompBuilder->createCanonicalLoop(
         loc, bodyGen, lowerBound, upperBound, step,
-        /*IsSigned=*/true, loopOp.getInclusive(), computeIP));
+        /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP));
 
     if (failed(bodyGenStatus))
       return failure();
@@ -1644,11 +1642,10 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
     llvm::Value *lowerBound =
-        moduleTranslation.lookupValue(loopOp.getCollapseLowerBounds()[i]);
+        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
     llvm::Value *upperBound =
-        moduleTranslation.lookupValue(loopOp.getCollapseUpperBounds()[i]);
-    llvm::Value *step =
-        moduleTranslation.lookupValue(loopOp.getCollapseSteps()[i]);
+        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
+    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
 
     // Make sure loop trip count are emitted in the preheader of the outermost
     // loop at the latest so that they are all available for the new collapsed
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 9db98883113b4..f639af9b86722 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -184,7 +184,7 @@ func.func @omp_loop_nest(%lb : index, %ub : index, %step : index) -> () {
     "omp.loop_nest" (%lb, %ub, %step) ({
     ^bb0(%iv: index):
       omp.yield
-    }) {inclusive} : (index, index, index) -> ()
+    }) {loop_inclusive} : (index, index, index) -> ()
     omp.terminator
   }
 



More information about the flang-commits mailing list