[flang-commits] [flang] [flang] Add reduction semantics to fir.do_loop (PR #93934)

via flang-commits flang-commits at lists.llvm.org
Fri May 31 01:01:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (khaki3)

<details>
<summary>Changes</summary>

Derived from #<!-- -->92480. This PR introduces `fir.reduce` and adds reduction support to `fir.do_loop`. The `fir.reduce` operation conveys reduction semantics in a similar way to `acc.reduction`: it marks a reference to a reduction variable while preserving the variable's original name. The `fir.do_loop` operation now implicitly carries the `operandSegmentSizes` attribute (it is not shown in the printed form) and takes a variable number of reduction operands, whose reduction operations are given as `fir.reduce_attr` attributes. For the sake of compatibility, the new arguments of `fir.do_loop`'s builder are optional and come after the existing `iterArgs` argument. Because the `iter_args` clause must stay directly before the result-type declaration, the new `reduce` clause is placed in the middle of the argument list, before `iter_args`.


---
Full diff: https://github.com/llvm/llvm-project/pull/93934.diff


4 Files Affected:

- (modified) flang/include/flang/Optimizer/Dialect/FIRAttr.td (+30) 
- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+56-8) 
- (modified) flang/lib/Optimizer/Dialect/FIRAttr.cpp (+2-2) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+66-7) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 0c34b640a5c9c..aedb6769186e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
   let cppNamespace = "fir";
 }
 
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
+    "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+    [
+      I32BitEnumAttrCaseBit<"Add", 0, "add">,
+      I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+      I32BitEnumAttrCaseBit<"AND", 2, "and">,
+      I32BitEnumAttrCaseBit<"OR", 3, "or">,
+      I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+      I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+      I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+      I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+      I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+      I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+      I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
+    ]> {
+  let separator = ", ";
+  let cppNamespace = "::fir";
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> {
+  let mnemonic = "reduce_attr";
+
+  let parameters = (ins
+    "ReduceOperationEnum":$reduce_operation
+  );
+
+  let assemblyFormat = "`<` $reduce_operation `>`";
+}
+
 // mlir::SideEffects::Resource for modelling operations which add debugging information
 def DebuggingResource : Resource<"::fir::DebuggingResource">;
 
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 3afc97475db11..d79f2da916d05 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2107,8 +2107,37 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
   let hasVerifier = 1;
 }
 
-def fir_DoLoopOp : region_Op<"do_loop",
-    [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> {
+  let summary = "Represent reduction semantics for the reduce clause";
+
+  let description = [{
+    Given the address of a variable, creates reduction information for the
+    reduce clause.
+
+    ```
+      %17 = fir.reduce %8 {name = "sum"} : (!fir.ref<f32>) -> !fir.ref<f32>
+      fir.do_loop ... unordered reduce(#fir.reduce_attr<add> -> %17 : !fir.ref<f32>) ...
+    ```
+
+    This operation is typically used for DO CONCURRENT REDUCE clause. The memref
+    operand may have a unique name while the `name` attribute preserves the
+    original name of a reduction variable.
+  }];
+
+  let arguments = (ins
+     AnyRefOrBoxLike:$memref,
+     Builtin_StringAttr:$name
+  );
+
+  let results = (outs AnyRefOrBox);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getYieldedValuesMutable"]>]> {
   let summary = "generalized loop operation";
   let description = [{
@@ -2138,9 +2167,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
     Index:$lowerBound,
     Index:$upperBound,
     Index:$step,
+    Variadic<AnyType>:$reduceOperands,
     Variadic<AnyType>:$initArgs,
     OptionalAttr<UnitAttr>:$unordered,
-    OptionalAttr<UnitAttr>:$finalValue
+    OptionalAttr<UnitAttr>:$finalValue,
+    OptionalAttr<ArrayAttr>:$reduceAttrs
   );
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -2151,6 +2182,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
       "mlir::Value":$step, CArg<"bool", "false">:$unordered,
       CArg<"bool", "false">:$finalCountValue,
       CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+      CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+      CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
       CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
   ];
 
@@ -2163,11 +2196,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
       return getBody()->getArguments().drop_front();
     }
     mlir::Operation::operand_range getIterOperands() {
-      return getOperands().drop_front(getNumControlOperands());
+      return getOperands()
+          .drop_front(getNumControlOperands() + getNumReduceOperands());
     }
     llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
-      return
-          getOperation()->getOpOperands().drop_front(getNumControlOperands());
+      return getOperation()->getOpOperands()
+          .drop_front(getNumControlOperands() + getNumReduceOperands());
     }
 
     void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2182,11 +2216,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
     unsigned getNumControlOperands() { return 3; }
     /// Does the operation hold operands for loop-carried values
     bool hasIterOperands() {
-      return (*this)->getNumOperands() > getNumControlOperands();
+      return getNumIterOperands() > 0;
+    }
+    /// Does the operation hold operands for reduction variables
+    bool hasReduceOperands() {
+      return getNumReduceOperands() > 0;
+    }
+    /// Get Number of variadic operands
+    unsigned getNumOperands(unsigned idx) {
+      auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+        getOperandSegmentSizeAttr());
+      return static_cast<unsigned>(segments[idx]);
+    }
+    // Get Number of reduction operands
+    unsigned getNumReduceOperands() {
+      return getNumOperands(3);
     }
     /// Get Number of loop-carried values
     unsigned getNumIterOperands() {
-      return (*this)->getNumOperands() - getNumControlOperands();
+      return getNumOperands(4);
     }
 
     /// Get the body of the loop
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 2faba63dfba07..a0202a0159228 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
 
 void FIROpsDialect::registerAttributes() {
   addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
-                LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
-                UpperBoundAttr>();
+                LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
+                SubclassAttr, UpperBoundAttr>();
 }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index b541b7cdc7a5b..807459c8ec3c7 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2079,9 +2079,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
                           mlir::OperationState &result, mlir::Value lb,
                           mlir::Value ub, mlir::Value step, bool unordered,
                           bool finalCountValue, mlir::ValueRange iterArgs,
+                          mlir::ValueRange reduceOperands,
+                          llvm::ArrayRef<mlir::Attribute> reduceAttrs,
                           llvm::ArrayRef<mlir::NamedAttribute> attributes) {
   result.addOperands({lb, ub, step});
+  result.addOperands(reduceOperands);
   result.addOperands(iterArgs);
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterArgs.size())}));
   if (finalCountValue) {
     result.addTypes(builder.getIndexType());
     result.addAttribute(getFinalValueAttrName(result.name),
@@ -2100,6 +2107,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
   if (unordered)
     result.addAttribute(getUnorderedAttrName(result.name),
                         builder.getUnitAttr());
+  if (!reduceAttrs.empty())
+    result.addAttribute(getReduceAttrsAttrName(result.name),
+                        builder.getArrayAttr(reduceAttrs));
   result.addAttributes(attributes);
 }
 
@@ -2125,24 +2135,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
   if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
     result.addAttribute("unordered", builder.getUnitAttr());
 
+  // Parse the reduction arguments.
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
+  llvm::SmallVector<mlir::Type> reduceArgTypes;
+  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
+    // Parse reduction attributes and variables.
+    llvm::SmallVector<ReduceAttr> attributes;
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::Paren, [&]() {
+              if (parser.parseAttribute(attributes.emplace_back()) ||
+                  parser.parseArrow() ||
+                  parser.parseOperand(reduceOperands.emplace_back()) ||
+                  parser.parseColonType(reduceArgTypes.emplace_back()))
+                return mlir::failure();
+              return mlir::success();
+            })))
+      return mlir::failure();
+    // Resolve input operands.
+    for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
+      if (parser.resolveOperand(std::get<0>(operand_type),
+                                std::get<1>(operand_type), result.operands))
+        return mlir::failure();
+    llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                                 attributes.end());
+    result.addAttribute(getReduceAttrsAttrName(result.name),
+                        builder.getArrayAttr(arrayAttr));
+  }
+
   // Parse the optional initial iteration arguments.
   llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
-  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
   llvm::SmallVector<mlir::Type> argTypes;
   bool prependCount = false;
   regionArgs.push_back(inductionVariable);
 
   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
     // Parse assignment list and results type list.
-    if (parser.parseAssignmentList(regionArgs, operands) ||
+    if (parser.parseAssignmentList(regionArgs, iterOperands) ||
         parser.parseArrowTypeList(result.types))
       return mlir::failure();
-    if (result.types.size() == operands.size() + 1)
+    if (result.types.size() == iterOperands.size() + 1)
       prependCount = true;
     // Resolve input operands.
     llvm::ArrayRef<mlir::Type> resTypes = result.types;
-    for (auto operand_type :
-         llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
+    for (auto operand_type : llvm::zip(
+             iterOperands, prependCount ? resTypes.drop_front() : resTypes))
       if (parser.resolveOperand(std::get<0>(operand_type),
                                 std::get<1>(operand_type), result.operands))
         return mlir::failure();
@@ -2153,6 +2190,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
     prependCount = true;
   }
 
+  // Set the operandSegmentSizes attribute
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterOperands.size())}));
+
   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
     return mlir::failure();
 
@@ -2229,6 +2272,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
 
     i++;
   }
+  auto reduceAttrs = getReduceAttrsAttr();
+  if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
+    return emitOpError(
+        "mismatch in number of reduction variables and reduction attributes");
   return mlir::success();
 }
 
@@ -2238,6 +2285,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     << getUpperBound() << " step " << getStep();
   if (getUnordered())
     p << " unordered";
+  if (hasReduceOperands()) {
+    p << " reduce(";
+    auto attrs = getReduceAttrsAttr();
+    auto operands = getReduceOperands();
+    llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
+      p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
+        << std::get<1>(it).getType();
+    });
+    p << ')';
+    printBlockTerminators = true;
+  }
   if (hasIterOperands()) {
     p << " iter_args(";
     auto regionArgs = getRegionIterArgs();
@@ -2251,8 +2309,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     p << " -> " << getResultTypes();
     printBlockTerminators = true;
   }
-  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
-                                     {"unordered", "finalValue"});
+  p.printOptionalAttrDictWithKeyword(
+      (*this)->getAttrs(),
+      {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                 printBlockTerminators);

``````````

</details>


https://github.com/llvm/llvm-project/pull/93934


More information about the flang-commits mailing list