[flang-commits] [flang] 88cdd99 - [flang] Add reduction semantics to fir.do_loop (#93934)

via flang-commits flang-commits at lists.llvm.org
Thu Jun 6 11:16:44 PDT 2024


Author: khaki3
Date: 2024-06-06T11:16:40-07:00
New Revision: 88cdd9905597ace5b1ac7d080df5326d3399b3f8

URL: https://github.com/llvm/llvm-project/commit/88cdd9905597ace5b1ac7d080df5326d3399b3f8
DIFF: https://github.com/llvm/llvm-project/commit/88cdd9905597ace5b1ac7d080df5326d3399b3f8.diff

LOG: [flang] Add reduction semantics to fir.do_loop (#93934)

Derived from #92480. This PR introduces reduction semantics into loops
for DO CONCURRENT REDUCE. The `fir.do_loop` operation now implicitly
carries the `operandSegmentSizes` attribute (it is not shown in the
printed form) and takes variable-length reduction operands, with their
operations given as `fir.reduce_attr`. For the sake of compatibility,
`fir.do_loop`'s builder has additional, optional arguments at the end.
The `iter_args` operand should be placed in front of the declaration of
result types, so the new operand for reduction variables (`reduce`) is
put before `iter_args`, in the middle of the arguments.

Added: 
    flang/test/Fir/loop03.fir

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRAttr.td
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/Dialect/FIRAttr.cpp
    flang/lib/Optimizer/Dialect/FIROps.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 0c34b640a5c9c..aedb6769186e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
   let cppNamespace = "fir";
 }
 
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
+    "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+    [
+      I32BitEnumAttrCaseBit<"Add", 0, "add">,
+      I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+      I32BitEnumAttrCaseBit<"AND", 2, "and">,
+      I32BitEnumAttrCaseBit<"OR", 3, "or">,
+      I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+      I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+      I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+      I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+      I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+      I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+      I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
+    ]> {
+  let separator = ", ";
+  let cppNamespace = "::fir";
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> {
+  let mnemonic = "reduce_attr";
+
+  let parameters = (ins
+    "ReduceOperationEnum":$reduce_operation
+  );
+
+  let assemblyFormat = "`<` $reduce_operation `>`";
+}
+
 // mlir::SideEffects::Resource for modelling operations which add debugging information
 def DebuggingResource : Resource<"::fir::DebuggingResource">;
 

diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 37fbd1f9692a4..e7da3af5485cc 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2125,8 +2125,8 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
   let hasVerifier = 1;
 }
 
-def fir_DoLoopOp : region_Op<"do_loop",
-    [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getYieldedValuesMutable"]>]> {
   let summary = "generalized loop operation";
   let description = [{
@@ -2156,9 +2156,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
     Index:$lowerBound,
     Index:$upperBound,
     Index:$step,
+    Variadic<AnyType>:$reduceOperands,
     Variadic<AnyType>:$initArgs,
     OptionalAttr<UnitAttr>:$unordered,
-    OptionalAttr<UnitAttr>:$finalValue
+    OptionalAttr<UnitAttr>:$finalValue,
+    OptionalAttr<ArrayAttr>:$reduceAttrs
   );
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -2169,6 +2171,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
       "mlir::Value":$step, CArg<"bool", "false">:$unordered,
       CArg<"bool", "false">:$finalCountValue,
       CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+      CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+      CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
       CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
   ];
 
@@ -2181,11 +2185,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
       return getBody()->getArguments().drop_front();
     }
     mlir::Operation::operand_range getIterOperands() {
-      return getOperands().drop_front(getNumControlOperands());
+      return getOperands()
+          .drop_front(getNumControlOperands() + getNumReduceOperands());
     }
     llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
-      return
-          getOperation()->getOpOperands().drop_front(getNumControlOperands());
+      return getOperation()->getOpOperands()
+          .drop_front(getNumControlOperands() + getNumReduceOperands());
     }
 
     void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2200,11 +2205,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
     unsigned getNumControlOperands() { return 3; }
     /// Does the operation hold operands for loop-carried values
     bool hasIterOperands() {
-      return (*this)->getNumOperands() > getNumControlOperands();
+      return getNumIterOperands() > 0;
+    }
+    /// Does the operation hold operands for reduction variables
+    bool hasReduceOperands() {
+      return getNumReduceOperands() > 0;
+    }
+    /// Get Number of variadic operands
+    unsigned getNumOperands(unsigned idx) {
+      auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+        getOperandSegmentSizeAttr());
+      return static_cast<unsigned>(segments[idx]);
+    }
+    // Get Number of reduction operands
+    unsigned getNumReduceOperands() {
+      return getNumOperands(3);
     }
     /// Get Number of loop-carried values
     unsigned getNumIterOperands() {
-      return (*this)->getNumOperands() - getNumControlOperands();
+      return getNumOperands(4);
     }
 
     /// Get the body of the loop

diff  --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 2faba63dfba07..a0202a0159228 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
 
 void FIROpsDialect::registerAttributes() {
   addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
-                LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
-                UpperBoundAttr>();
+                LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
+                SubclassAttr, UpperBoundAttr>();
 }

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index b530a9dc1bcc4..75ca738211abe 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2456,9 +2456,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
                           mlir::OperationState &result, mlir::Value lb,
                           mlir::Value ub, mlir::Value step, bool unordered,
                           bool finalCountValue, mlir::ValueRange iterArgs,
+                          mlir::ValueRange reduceOperands,
+                          llvm::ArrayRef<mlir::Attribute> reduceAttrs,
                           llvm::ArrayRef<mlir::NamedAttribute> attributes) {
   result.addOperands({lb, ub, step});
+  result.addOperands(reduceOperands);
   result.addOperands(iterArgs);
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterArgs.size())}));
   if (finalCountValue) {
     result.addTypes(builder.getIndexType());
     result.addAttribute(getFinalValueAttrName(result.name),
@@ -2477,6 +2484,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
   if (unordered)
     result.addAttribute(getUnorderedAttrName(result.name),
                         builder.getUnitAttr());
+  if (!reduceAttrs.empty())
+    result.addAttribute(getReduceAttrsAttrName(result.name),
+                        builder.getArrayAttr(reduceAttrs));
   result.addAttributes(attributes);
 }
 
@@ -2502,24 +2512,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
   if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
     result.addAttribute("unordered", builder.getUnitAttr());
 
+  // Parse the reduction arguments.
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
+  llvm::SmallVector<mlir::Type> reduceArgTypes;
+  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
+    // Parse reduction attributes and variables.
+    llvm::SmallVector<ReduceAttr> attributes;
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::Paren, [&]() {
+              if (parser.parseAttribute(attributes.emplace_back()) ||
+                  parser.parseArrow() ||
+                  parser.parseOperand(reduceOperands.emplace_back()) ||
+                  parser.parseColonType(reduceArgTypes.emplace_back()))
+                return mlir::failure();
+              return mlir::success();
+            })))
+      return mlir::failure();
+    // Resolve input operands.
+    for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
+      if (parser.resolveOperand(std::get<0>(operand_type),
+                                std::get<1>(operand_type), result.operands))
+        return mlir::failure();
+    llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                                 attributes.end());
+    result.addAttribute(getReduceAttrsAttrName(result.name),
+                        builder.getArrayAttr(arrayAttr));
+  }
+
   // Parse the optional initial iteration arguments.
   llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
-  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
   llvm::SmallVector<mlir::Type> argTypes;
   bool prependCount = false;
   regionArgs.push_back(inductionVariable);
 
   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
     // Parse assignment list and results type list.
-    if (parser.parseAssignmentList(regionArgs, operands) ||
+    if (parser.parseAssignmentList(regionArgs, iterOperands) ||
         parser.parseArrowTypeList(result.types))
       return mlir::failure();
-    if (result.types.size() == operands.size() + 1)
+    if (result.types.size() == iterOperands.size() + 1)
       prependCount = true;
     // Resolve input operands.
     llvm::ArrayRef<mlir::Type> resTypes = result.types;
-    for (auto operand_type :
-         llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
+    for (auto operand_type : llvm::zip(
+             iterOperands, prependCount ? resTypes.drop_front() : resTypes))
       if (parser.resolveOperand(std::get<0>(operand_type),
                                 std::get<1>(operand_type), result.operands))
         return mlir::failure();
@@ -2530,6 +2567,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
     prependCount = true;
   }
 
+  // Set the operandSegmentSizes attribute
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterOperands.size())}));
+
   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
     return mlir::failure();
 
@@ -2606,6 +2649,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
 
     i++;
   }
+  auto reduceAttrs = getReduceAttrsAttr();
+  if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
+    return emitOpError(
+        "mismatch in number of reduction variables and reduction attributes");
   return mlir::success();
 }
 
@@ -2615,6 +2662,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     << getUpperBound() << " step " << getStep();
   if (getUnordered())
     p << " unordered";
+  if (hasReduceOperands()) {
+    p << " reduce(";
+    auto attrs = getReduceAttrsAttr();
+    auto operands = getReduceOperands();
+    llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
+      p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
+        << std::get<1>(it).getType();
+    });
+    p << ')';
+    printBlockTerminators = true;
+  }
   if (hasIterOperands()) {
     p << " iter_args(";
     auto regionArgs = getRegionIterArgs();
@@ -2628,8 +2686,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     p << " -> " << getResultTypes();
     printBlockTerminators = true;
   }
-  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
-                                     {"unordered", "finalValue"});
+  p.printOptionalAttrDictWithKeyword(
+      (*this)->getAttrs(),
+      {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                 printBlockTerminators);

diff  --git a/flang/test/Fir/loop03.fir b/flang/test/Fir/loop03.fir
new file mode 100644
index 0000000000000..b88dcaf8639be
--- /dev/null
+++ b/flang/test/Fir/loop03.fir
@@ -0,0 +1,17 @@
+// Test the reduction semantics of fir.do_loop
+// RUN: fir-opt %s | FileCheck %s
+
+func.func @reduction() {
+  %bound = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  %sum = fir.alloca i32
+// CHECK: %[[VAL_0:.*]] = fir.alloca i32
+// CHECK: fir.do_loop %[[VAL_1:.*]] = %[[VAL_2:.*]] to %[[VAL_3:.*]] step %[[VAL_4:.*]] unordered reduce(#fir.reduce_attr<add> -> %[[VAL_0]] : !fir.ref<i32>) {
+  fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
+    %index = fir.convert %iv : (index) -> i32
+    %1 = fir.load %sum : !fir.ref<i32>
+    %2 = arith.addi %index, %1 : i32
+    fir.store %2 to %sum : !fir.ref<i32>
+  }
+  return
+}


        


More information about the flang-commits mailing list