[flang-commits] [flang] [flang] Support DO CONCURRENT REDUCE clause (PR #92480)
via flang-commits
flang-commits at lists.llvm.org
Thu May 16 18:28:44 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-flang-fir-hlfir
Author: None (khaki3)
<details>
<summary>Changes</summary>
This PR implements the Fortran 2023 REDUCE clause of DO CONCURRENT. It updates Parser, Optimizer, and Lower so that FIR uses a new operation tailored to reduction semantics.
- The change in Parser follows the style of the OpenMP parser in MLIR. The front end accepts both arbitrary operations and procedures in the REDUCE clause; Semantics later reports type errors.
- Optimizer introduces `fir.reduce`, which is similar to `acc.reduction`, to track references to reduction variables while keeping their original names. The `fir.do_loop` operation now has the `operandSegmentSizes` attribute and takes a variable-length list of reduction operands together with their operations.
- Lower extends `fir.do_loop` with reduction information only if it finds DO CONCURRENT loops with REDUCE.
Three tests are added: `Semantics/resolve123.f90` and `Semantics/resolve124.f90` check parsing, while `Lower/loops3.f90` validates FIR with reduction semantics. Many existing tests are updated to set the `operandSegmentSizes` attribute.
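
To illustrate why `fir.do_loop` needs a segment-size attribute once it carries two variadic operand groups, here is a minimal standalone sketch. It does not use MLIR; `LoopOperands`, `makeLoop`, and the string operand names are hypothetical stand-ins. The only parts taken from the patch are the segment layout `{1, 1, 1, numIter, numReduce}` and the idea of reading segments 3 and 4 to split the flat operand list.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model of a loop op's flat operand list (not the MLIR API).
struct LoopOperands {
  // Layout: lb, ub, step, iter operands..., reduce operands...
  std::vector<std::string> operands;
  // One entry per operand group: {1, 1, 1, numIter, numReduce}.
  std::vector<std::int32_t> segmentSizes;

  // Size of the idx-th operand group.
  unsigned numInSegment(std::size_t idx) const {
    return static_cast<unsigned>(segmentSizes.at(idx));
  }
  unsigned numIterOperands() const { return numInSegment(3); }
  unsigned numReduceOperands() const { return numInSegment(4); }

  // Copy `count` operands starting at position `begin`.
  std::vector<std::string> slice(std::size_t begin, std::size_t count) const {
    assert(begin + count <= operands.size() && "segment exceeds operand list");
    auto first = operands.begin() + static_cast<std::ptrdiff_t>(begin);
    return std::vector<std::string>(first,
                                    first + static_cast<std::ptrdiff_t>(count));
  }
  // Loop-carried values follow the three control operands.
  std::vector<std::string> iterOperands() const {
    return slice(3, numIterOperands());
  }
  // Reduction operands follow the loop-carried values.
  std::vector<std::string> reduceOperands() const {
    return slice(3 + numIterOperands(), numReduceOperands());
  }
};

// Record the operands and the segment sizes together, as a builder would.
LoopOperands makeLoop(const std::vector<std::string> &iterArgs,
                      const std::vector<std::string> &reduceArgs) {
  LoopOperands loop;
  loop.operands = {"%lb", "%ub", "%step"};
  loop.operands.insert(loop.operands.end(), iterArgs.begin(), iterArgs.end());
  loop.operands.insert(loop.operands.end(), reduceArgs.begin(),
                       reduceArgs.end());
  loop.segmentSizes = {1, 1, 1, static_cast<std::int32_t>(iterArgs.size()),
                       static_cast<std::int32_t>(reduceArgs.size())};
  return loop;
}
```

With a single variadic group, "everything after the three control operands" was enough to find the loop-carried values. With two groups the boundary is ambiguous without the recorded sizes, which is presumably why so many existing tests had to gain the attribute.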
---
Patch is 688.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92480.diff
158 Files Affected:
- (modified) flang/examples/FeatureList/FeatureList.cpp (+2)
- (modified) flang/include/flang/Optimizer/Dialect/FIRAttr.td (+30)
- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+57-8)
- (modified) flang/include/flang/Parser/dump-parse-tree.h (+2)
- (modified) flang/include/flang/Parser/parse-tree.h (+21-8)
- (modified) flang/include/flang/Semantics/symbol.h (+1)
- (modified) flang/lib/Lower/Bridge.cpp (+86-2)
- (modified) flang/lib/Optimizer/Dialect/FIRAttr.cpp (+4-4)
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+63-8)
- (modified) flang/lib/Parser/executable-parsers.cpp (+10)
- (modified) flang/lib/Parser/unparse.cpp (+4)
- (modified) flang/lib/Semantics/check-do-forall.cpp (+89)
- (modified) flang/lib/Semantics/resolve-names.cpp (+74-29)
- (modified) flang/test/Analysis/AliasAnalysis/alias-analysis-5.fir (+11-11)
- (modified) flang/test/Analysis/AliasAnalysis/alias-analysis-7.fir (+2-2)
- (modified) flang/test/Fir/affine-promotion.fir (+5-5)
- (modified) flang/test/Fir/array-copies-pointers.fir (+10-10)
- (modified) flang/test/Fir/array-modify.fir (+6-6)
- (modified) flang/test/Fir/array-value-copy-2.fir (+2-2)
- (modified) flang/test/Fir/array-value-copy-3.fir (+1-1)
- (modified) flang/test/Fir/array-value-copy-4.fir (+1-1)
- (modified) flang/test/Fir/array-value-copy-cam4.fir (+2-2)
- (modified) flang/test/Fir/array-value-copy.fir (+26-26)
- (modified) flang/test/Fir/arrayset.fir (+1-1)
- (modified) flang/test/Fir/arrexp.fir (+10-10)
- (modified) flang/test/Fir/char-conversion.fir (+1-1)
- (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+3-3)
- (modified) flang/test/Fir/embox-write.fir (+1-1)
- (modified) flang/test/Fir/fir-ops.fir (+5-5)
- (modified) flang/test/Fir/invalid.fir (+6-6)
- (modified) flang/test/Fir/loop01.fir (+6-6)
- (modified) flang/test/Fir/loop02.fir (+1-1)
- (modified) flang/test/Fir/loop10.fir (+2-2)
- (modified) flang/test/Fir/target-rewrite-boxchar.fir (+1-1)
- (modified) flang/test/Fir/target.fir (+1-1)
- (modified) flang/test/Fir/tbaa-codegen2.fir (+1-1)
- (modified) flang/test/HLFIR/all-elemental.fir (+1-1)
- (modified) flang/test/HLFIR/any-elemental.fir (+4-4)
- (modified) flang/test/HLFIR/assign-codegen.fir (+3-3)
- (modified) flang/test/HLFIR/associate-codegen.fir (+5-5)
- (modified) flang/test/HLFIR/bufferize-poly-expr.fir (+2-2)
- (modified) flang/test/HLFIR/bufferize01.fir (+1-1)
- (modified) flang/test/HLFIR/char_extremum-bufferization.fir (+6-6)
- (modified) flang/test/HLFIR/concat-bufferization.fir (+3-3)
- (modified) flang/test/HLFIR/convert-assign-inside-openacc-recipe.fir (+1-1)
- (modified) flang/test/HLFIR/count-elemental.fir (+8-8)
- (modified) flang/test/HLFIR/elemental-codegen-nested.fir (+2-2)
- (modified) flang/test/HLFIR/elemental-codegen.fir (+13-13)
- (modified) flang/test/HLFIR/elemental-shallow-copy.fir (+1-1)
- (modified) flang/test/HLFIR/extents-of-shape-of.f90 (+4-4)
- (modified) flang/test/HLFIR/inline-elemental.fir (+1-1)
- (modified) flang/test/HLFIR/maxloc-elemental.fir (+3-3)
- (modified) flang/test/HLFIR/minloc-elemental.fir (+7-7)
- (modified) flang/test/HLFIR/opt-array-slice-assign.fir (+9-9)
- (modified) flang/test/HLFIR/opt-bufferization-leslie3d.fir (+2-2)
- (modified) flang/test/HLFIR/opt-bufferization.fir (+12-12)
- (modified) flang/test/HLFIR/opt-scalar-assign.fir (+7-7)
- (modified) flang/test/HLFIR/opt-variable-assign.fir (+12-12)
- (modified) flang/test/HLFIR/order_assignments/forall-codegen-fuse-assignments.fir (+1-1)
- (modified) flang/test/HLFIR/order_assignments/forall-codegen-no-conflict.fir (+7-7)
- (modified) flang/test/HLFIR/order_assignments/inlined-stack-temp.fir (+7-7)
- (modified) flang/test/HLFIR/order_assignments/lhs-conflicts-codegen.fir (+2-2)
- (modified) flang/test/HLFIR/order_assignments/runtime-stack-temp.fir (+2-2)
- (modified) flang/test/HLFIR/order_assignments/user-defined-assignment-finalization.fir (+3-3)
- (modified) flang/test/HLFIR/order_assignments/user-defined-assignment.fir (+3-3)
- (modified) flang/test/HLFIR/order_assignments/vector-subscripts-codegen.fir (+5-5)
- (modified) flang/test/HLFIR/order_assignments/where-codegen-no-conflict.fir (+7-7)
- (modified) flang/test/Lower/HLFIR/array-ctor-as-inlined-temp.f90 (+4-4)
- (modified) flang/test/Lower/HLFIR/array-ctor-as-runtime-temp.f90 (+2-2)
- (modified) flang/test/Lower/HLFIR/calls-optional.f90 (+3-3)
- (modified) flang/test/Lower/HLFIR/elemental-call-vector-subscripts.f90 (+3-3)
- (modified) flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 (+8-8)
- (modified) flang/test/Lower/HLFIR/intrinsic-subroutines.f90 (+1-1)
- (modified) flang/test/Lower/Intrinsics/cmplx.f90 (+1-1)
- (modified) flang/test/Lower/Intrinsics/ieee_festatus.f90 (+5-5)
- (modified) flang/test/Lower/Intrinsics/ieee_flag.f90 (+20-20)
- (modified) flang/test/Lower/Intrinsics/index.f90 (+1-1)
- (modified) flang/test/Lower/Intrinsics/max.f90 (+2-2)
- (modified) flang/test/Lower/Intrinsics/mvbits.f90 (+1-1)
- (modified) flang/test/Lower/Intrinsics/scan.f90 (+2-2)
- (modified) flang/test/Lower/Intrinsics/transfer.f90 (+1-1)
- (modified) flang/test/Lower/Intrinsics/transpose_opt.f90 (+8-8)
- (modified) flang/test/Lower/Intrinsics/verify.f90 (+1-1)
- (modified) flang/test/Lower/OpenACC/acc-declare.f90 (+3-3)
- (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+15-15)
- (modified) flang/test/Lower/OpenMP/hlfir-seqloop-parallel.f90 (+4-4)
- (modified) flang/test/Lower/OpenMP/parallel-private-clause-fixes.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90 (+2-2)
- (modified) flang/test/Lower/OpenMP/parallel-reduction-array.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/parallel-reduction-array2.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/parallel-reduction3.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-multiple-clauses.f90 (+2-2)
- (modified) flang/test/Lower/OpenMP/wsloop-variable.f90 (+1-1)
- (modified) flang/test/Lower/allocatable-assignment.f90 (+42-42)
- (modified) flang/test/Lower/allocate-source-allocatables.f90 (+1-1)
- (modified) flang/test/Lower/allocate-source-pointers.f90 (+1-1)
- (modified) flang/test/Lower/array-constructor-index.f90 (+8-8)
- (modified) flang/test/Lower/array-elemental-calls.f90 (+2-2)
- (modified) flang/test/Lower/array-elemental-subroutines.f90 (+1-1)
- (modified) flang/test/Lower/array-expression-assumed-size.f90 (+4-4)
- (modified) flang/test/Lower/array-expression-subscript.f90 (+4-4)
- (modified) flang/test/Lower/array-expression.f90 (+40-40)
- (modified) flang/test/Lower/array-user-def-assignments.f90 (+21-21)
- (modified) flang/test/Lower/assignment.f90 (+1-1)
- (modified) flang/test/Lower/assumed-shape-caller.f90 (+1-1)
- (modified) flang/test/Lower/call-parenthesized-arg.f90 (+4-4)
- (modified) flang/test/Lower/character-assignment.f90 (+4-4)
- (modified) flang/test/Lower/character-concatenation.f90 (+1-1)
- (modified) flang/test/Lower/character-substrings.f90 (+10-10)
- (modified) flang/test/Lower/cray-pointer.f90 (+4-4)
- (modified) flang/test/Lower/derived-assignments.f90 (+1-1)
- (modified) flang/test/Lower/do_loop.f90 (+7-7)
- (modified) flang/test/Lower/do_loop_unstructured.f90 (+2-2)
- (modified) flang/test/Lower/dummy-argument-contiguous.f90 (+2-2)
- (modified) flang/test/Lower/entry-statement.f90 (+2-2)
- (modified) flang/test/Lower/forall/array-constructor.f90 (+7-7)
- (modified) flang/test/Lower/forall/array-pointer.f90 (+14-14)
- (modified) flang/test/Lower/forall/degenerate.f90 (+1-1)
- (modified) flang/test/Lower/forall/forall-2.f90 (+4-4)
- (modified) flang/test/Lower/forall/forall-allocatable-2.f90 (+1-1)
- (modified) flang/test/Lower/forall/forall-allocatable.f90 (+1-1)
- (modified) flang/test/Lower/forall/forall-array.f90 (+2-2)
- (modified) flang/test/Lower/forall/forall-construct-2.f90 (+4-4)
- (modified) flang/test/Lower/forall/forall-construct-3.f90 (+4-4)
- (modified) flang/test/Lower/forall/forall-construct.f90 (+2-2)
- (modified) flang/test/Lower/forall/forall-ranked.f90 (+2-2)
- (modified) flang/test/Lower/forall/forall-slice.f90 (+3-3)
- (modified) flang/test/Lower/forall/forall-stmt.f90 (+1-1)
- (modified) flang/test/Lower/forall/forall-where.f90 (+9-9)
- (modified) flang/test/Lower/forall/scalar-substring.f90 (+2-2)
- (modified) flang/test/Lower/forall/test9.f90 (+1-1)
- (modified) flang/test/Lower/infinite_loop.f90 (+2-2)
- (modified) flang/test/Lower/io-implied-do-fixes.f90 (+4-4)
- (modified) flang/test/Lower/loops.f90 (+8-8)
- (added) flang/test/Lower/loops3.f90 (+21)
- (modified) flang/test/Lower/mixed_loops.f90 (+1-1)
- (modified) flang/test/Lower/nested-where.f90 (+6-6)
- (modified) flang/test/Lower/optional-value-caller.f90 (+1-1)
- (modified) flang/test/Lower/pointer-references.f90 (+2-2)
- (modified) flang/test/Lower/polymorphic.f90 (+15-15)
- (modified) flang/test/Lower/select-type.f90 (+12-12)
- (modified) flang/test/Lower/statement-function.f90 (+1-1)
- (modified) flang/test/Lower/structure-constructors.f90 (+5-5)
- (modified) flang/test/Lower/submodule.f90 (+3-3)
- (modified) flang/test/Lower/transformational-intrinsics.f90 (+3-3)
- (modified) flang/test/Lower/where.f90 (+8-8)
- (added) flang/test/Semantics/resolve123.f90 (+79)
- (added) flang/test/Semantics/resolve124.f90 (+89)
- (modified) flang/test/Semantics/resolve55.f90 (+11-8)
- (modified) flang/test/Transforms/loop-versioning.fir (+48-48)
- (modified) flang/test/Transforms/omp-reduction-cfg-conversion.fir (+1-1)
- (modified) flang/test/Transforms/simplifyintrinsics.fir (+33-33)
- (modified) flang/test/Transforms/stack-arrays.fir (+2-2)
- (modified) flang/test/Transforms/tbaa2.fir (+6-6)
``````````diff
diff --git a/flang/examples/FeatureList/FeatureList.cpp b/flang/examples/FeatureList/FeatureList.cpp
index 3ca92da4f6467..28689b5d3c4b0 100644
--- a/flang/examples/FeatureList/FeatureList.cpp
+++ b/flang/examples/FeatureList/FeatureList.cpp
@@ -410,10 +410,12 @@ struct NodeVisitor {
READ_FEATURE(LetterSpec)
READ_FEATURE(LiteralConstant)
READ_FEATURE(IntLiteralConstant)
+ READ_FEATURE(ReduceOperation)
READ_FEATURE(LocalitySpec)
READ_FEATURE(LocalitySpec::DefaultNone)
READ_FEATURE(LocalitySpec::Local)
READ_FEATURE(LocalitySpec::LocalInit)
+ READ_FEATURE(LocalitySpec::Reduce)
READ_FEATURE(LocalitySpec::Shared)
READ_FEATURE(LockStmt)
READ_FEATURE(LockStmt::LockStat)
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index f8b3fb861cc62..5404d189a4863 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
let cppNamespace = "fir";
}
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
+ "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+ [
+ I32BitEnumAttrCaseBit<"Add", 0, "add">,
+ I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+ I32BitEnumAttrCaseBit<"AND", 2, "and">,
+ I32BitEnumAttrCaseBit<"OR", 3, "or">,
+ I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+ I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+ I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+ I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+ I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+ I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+ I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
+ ]> {
+ let separator = ", ";
+ let cppNamespace = "::fir";
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> {
+ let mnemonic = "reduce_attr";
+
+ let parameters = (ins
+ "ReduceOperationEnum":$reduce_operation
+ );
+
+ let assemblyFormat = "`<` $reduce_operation `>`";
+}
+
// mlir::SideEffects::Resource for modelling operations which add debugging information
def DebuggingResource : Resource<"::fir::DebuggingResource">;
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 64c5e360b28f7..b9fd4fed6f13d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2062,8 +2062,37 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}
-def fir_DoLoopOp : region_Op<"do_loop",
- [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> {
+ let summary = "Represent reduction semantics for the reduce clause";
+
+ let description = [{
+ Given the address of a variable, creates reduction information for the
+ reduce clause.
+
+ ```
+ %17 = fir.reduce %8 {name = "sum"} : (!fir.ref<f32>) -> !fir.ref<f32>
+ fir.do_loop ... unordered reduce(#fir.reduce_attr<add> -> %17 : !fir.ref<f32>) ...
+ ```
+
+ This operation is typically used for DO CONCURRENT REDUCE clause. The memref
+ operand may have a unique name while the `name` attribute preserves the
+ original name of a reduction variable.
+ }];
+
+ let arguments = (ins
+ AnyRefOrBoxLike:$memref,
+ Builtin_StringAttr:$name
+ );
+
+ let results = (outs AnyRefOrBox);
+
+ let assemblyFormat = [{
+ operands attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getYieldedValuesMutable"]>]> {
let summary = "generalized loop operation";
let description = [{
@@ -2095,7 +2124,9 @@ def fir_DoLoopOp : region_Op<"do_loop",
Index:$step,
Variadic<AnyType>:$initArgs,
OptionalAttr<UnitAttr>:$unordered,
- OptionalAttr<UnitAttr>:$finalValue
+ OptionalAttr<UnitAttr>:$finalValue,
+ Variadic<AnyType>:$reduceOperands,
+ OptionalAttr<ArrayAttr>:$reduceAttrs
);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -2106,6 +2137,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
"mlir::Value":$step, CArg<"bool", "false">:$unordered,
CArg<"bool", "false">:$finalCountValue,
CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+ CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+ CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
];
@@ -2118,11 +2151,13 @@ def fir_DoLoopOp : region_Op<"do_loop",
return getBody()->getArguments().drop_front();
}
mlir::Operation::operand_range getIterOperands() {
- return getOperands().drop_front(getNumControlOperands());
+ return getOperands().drop_front(getNumControlOperands())
+ .take_front(getNumIterOperands());
}
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
return
- getOperation()->getOpOperands().drop_front(getNumControlOperands());
+ getOperation()->getOpOperands().drop_front(getNumControlOperands())
+ .take_front(getNumIterOperands());
}
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2131,17 +2166,31 @@ def fir_DoLoopOp : region_Op<"do_loop",
/// Number of region arguments for loop-carried values
unsigned getNumRegionIterArgs() {
- return getBody()->getNumArguments() - 1;
+ return getBody()->getNumArguments() - (1 + getNumReduceOperands());
}
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
/// Does the operation hold operands for loop-carried values
bool hasIterOperands() {
- return (*this)->getNumOperands() > getNumControlOperands();
+ return getNumIterOperands() > 0;
+ }
+ /// Does the operation hold operands for reduction variables
+ bool hasReduceOperands() {
+ return getNumReduceOperands() > 0;
+ }
+ /// Get Number of variadic operands
+ unsigned getNumOperands(unsigned idx) {
+ auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+ getOperandSegmentSizeAttr());
+ return static_cast<unsigned>(segments[idx]);
}
/// Get Number of loop-carried values
unsigned getNumIterOperands() {
- return (*this)->getNumOperands() - getNumControlOperands();
+ return getNumOperands(3);
+ }
+ // Get Number of reduction operands
+ unsigned getNumReduceOperands() {
+ return getNumOperands(4);
}
/// Get the body of the loop
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 68ae50c312cde..15948bb073664 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -436,10 +436,12 @@ class ParseTreeDumper {
NODE(parser, LetterSpec)
NODE(parser, LiteralConstant)
NODE(parser, IntLiteralConstant)
+ NODE(parser, ReduceOperation)
NODE(parser, LocalitySpec)
NODE(LocalitySpec, DefaultNone)
NODE(LocalitySpec, Local)
NODE(LocalitySpec, LocalInit)
+ NODE(LocalitySpec, Reduce)
NODE(LocalitySpec, Shared)
NODE(parser, LockStmt)
NODE(LockStmt, LockStat)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 0a40aa8b8f616..68a4319a85047 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -1870,6 +1870,13 @@ struct ProcComponentRef {
WRAPPER_CLASS_BOILERPLATE(ProcComponentRef, Scalar<StructureComponent>);
};
+// R1522 procedure-designator ->
+// procedure-name | proc-component-ref | data-ref % binding-name
+struct ProcedureDesignator {
+ UNION_CLASS_BOILERPLATE(ProcedureDesignator);
+ std::variant<Name, ProcComponentRef> u;
+};
+
// R914 coindexed-named-object -> data-ref
struct CoindexedNamedObject {
BOILERPLATE(CoindexedNamedObject);
@@ -2236,16 +2243,29 @@ struct ConcurrentHeader {
t;
};
+// F'2023 R1131 reduce-operation ->
+// + | * | .AND. | .OR. | .EQV. | .NEQV. |
+// MAX | MIN | IAND | IOR | IEOR
+struct ReduceOperation {
+ UNION_CLASS_BOILERPLATE(ReduceOperation);
+ std::variant<DefinedOperator, ProcedureDesignator> u;
+};
+
// R1130 locality-spec ->
// LOCAL ( variable-name-list ) | LOCAL_INIT ( variable-name-list ) |
+// REDUCE ( reduce-operation : variable-name-list ) |
// SHARED ( variable-name-list ) | DEFAULT ( NONE )
struct LocalitySpec {
UNION_CLASS_BOILERPLATE(LocalitySpec);
WRAPPER_CLASS(Local, std::list<Name>);
WRAPPER_CLASS(LocalInit, std::list<Name>);
+ struct Reduce {
+ TUPLE_CLASS_BOILERPLATE(Reduce);
+ std::tuple<ReduceOperation, std::list<Name>> t;
+ };
WRAPPER_CLASS(Shared, std::list<Name>);
EMPTY_CLASS(DefaultNone);
- std::variant<Local, LocalInit, Shared, DefaultNone> u;
+ std::variant<Local, LocalInit, Reduce, Shared, DefaultNone> u;
};
// R1123 loop-control ->
@@ -3180,13 +3200,6 @@ WRAPPER_CLASS(ExternalStmt, std::list<Name>);
// R1519 intrinsic-stmt -> INTRINSIC [::] intrinsic-procedure-name-list
WRAPPER_CLASS(IntrinsicStmt, std::list<Name>);
-// R1522 procedure-designator ->
-// procedure-name | proc-component-ref | data-ref % binding-name
-struct ProcedureDesignator {
- UNION_CLASS_BOILERPLATE(ProcedureDesignator);
- std::variant<Name, ProcComponentRef> u;
-};
-
// R1525 alt-return-spec -> * label
WRAPPER_CLASS(AltReturnSpec, Label);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 50f7b68d80cb1..8ccf93c803845 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -714,6 +714,7 @@ class Symbol {
CrayPointer, CrayPointee,
LocalityLocal, // named in LOCAL locality-spec
LocalityLocalInit, // named in LOCAL_INIT locality-spec
+ LocalityReduce, // named in REDUCE locality-spec
LocalityShared, // named in SHARED locality-spec
InDataStmt, // initialized in a DATA statement, =>object, or /init/
InNamelist, // in a Namelist group
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index afbc1122de868..b9d4bcc0338fb 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -101,7 +101,7 @@ struct IncrementLoopInfo {
bool hasLocalitySpecs() const {
return !localSymList.empty() || !localInitSymList.empty() ||
- !sharedSymList.empty();
+ !reduceSymList.empty() || !sharedSymList.empty();
}
// Data members common to both structured and unstructured loops.
@@ -113,6 +113,8 @@ struct IncrementLoopInfo {
bool isUnordered; // do concurrent, forall
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+ llvm::SmallVector<std::pair<fir::ReduceOperationEnum,
+ const Fortran::semantics::Symbol *>> reduceSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -1696,6 +1698,62 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->create<fir::UnreachableOp>(loc);
}
+ fir::ReduceOperationEnum getReduceOperationEnum(
+ const Fortran::parser::ReduceOperation &rOpr) const {
+ fir::ReduceOperationEnum reduce_operation = fir::ReduceOperationEnum::Add;
+ using IntrinsicOperator =
+ Fortran::parser::DefinedOperator::IntrinsicOperator;
+ std::visit(
+ Fortran::common::visitors{
+ [&](const Fortran::parser::DefinedOperator &dOpr) {
+ const auto &intrinsicOp{std::get<IntrinsicOperator>(dOpr.u)};
+ switch (intrinsicOp) {
+ case IntrinsicOperator::Add:
+ reduce_operation = fir::ReduceOperationEnum::Add;
+ return;
+ case IntrinsicOperator::Multiply:
+ reduce_operation = fir::ReduceOperationEnum::Multiply;
+ return;
+ case IntrinsicOperator::AND:
+ reduce_operation = fir::ReduceOperationEnum::AND;
+ return;
+ case IntrinsicOperator::OR:
+ reduce_operation = fir::ReduceOperationEnum::OR;
+ return;
+ case IntrinsicOperator::EQV:
+ reduce_operation = fir::ReduceOperationEnum::EQV;
+ return;
+ case IntrinsicOperator::NEQV:
+ reduce_operation = fir::ReduceOperationEnum::NEQV;
+ return;
+ default:
+ return;
+ }
+ },
+ [&](const Fortran::parser::ProcedureDesignator &procD) {
+ const Fortran::parser::Name *name{
+ std::get_if<Fortran::parser::Name>(&procD.u)
+ };
+ if (name && name->symbol) {
+ const Fortran::parser::CharBlock
+ &realName{name->symbol->GetUltimate().name()};
+ if (realName == "max")
+ reduce_operation = fir::ReduceOperationEnum::MAX;
+ else if (realName == "min")
+ reduce_operation = fir::ReduceOperationEnum::MIN;
+ else if (realName == "iand")
+ reduce_operation = fir::ReduceOperationEnum::IAND;
+ else if (realName == "ior")
+ reduce_operation = fir::ReduceOperationEnum::IOR;
+ else if (realName == "ieor")
+ reduce_operation = fir::ReduceOperationEnum::EIOR;
+ }
+ }
+ },
+ rOpr.u);
+ return fir::ReduceOperationEnum(reduce_operation);
+ }
+
/// Collect DO CONCURRENT or FORALL loop control information.
IncrementLoopNestInfo getConcurrentControl(
const Fortran::parser::ConcurrentHeader &header,
@@ -1718,6 +1776,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
for (const Fortran::parser::Name &x : localInitList->v)
info.localInitSymList.push_back(x.symbol);
+ if (const auto *reduceList =
+ std::get_if<Fortran::parser::LocalitySpec::Reduce>(&x.u)) {
+ fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum(
+ std::get<Fortran::parser::ReduceOperation>(reduceList->t));
+ for (const Fortran::parser::Name &x :
+ std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
+ info.reduceSymList.push_back(std::make_pair(reduce_operation, x.symbol));
+ }
+ }
if (const auto *sharedList =
std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
for (const Fortran::parser::Name &x : sharedList->v)
@@ -1910,9 +1977,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::Type loopVarType = info.getLoopVariableType();
mlir::Value loopValue;
if (info.isUnordered) {
+ llvm::SmallVector<mlir::Value> reduceOperands;
+ llvm::SmallVector<mlir::Attribute> reduceAttrs;
+ // Create DO CONCURRENT reduce operations and attributes
+ for (const auto reduceSym : info.reduceSymList) {
+ const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
+ const Fortran::semantics::Symbol *sym = reduceSym.second;
+ fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
+ auto reduce_op = builder->create<fir::ReduceOp>(
+ loc, fir::ReferenceType::get(genType(*sym)), fir::getBase(exv),
+ builder->getStringAttr(sym->name().ToString()));
+ reduceOperands.push_back(reduce_op);
+ auto reduce_attr = fir::ReduceAttr::get(
+ builder->getContext(), reduce_operation);
+ reduceAttrs.push_back(reduce_attr);
+ }
// The loop variable value is explicitly updated.
info.doLoop = builder->create<fir::DoLoopOp>(
- loc, lowerValue, upperValue, stepValue, /*unordered=*/true);
+ loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
+ /*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
+ llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
builder->setInsertionPointToStart(info.doLoop.getBody());
loopValue = builder->createConvert(loc, loopVarType,
info.doLoop.getInductionVar());
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 9ea3a0568f691..2a688c144c069 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,8 +297,8 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
- LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
- UpperBoundAttr, CUDADataAttributeAttr, CUDAProcAttributeAttr,
- CUDALaunchBoundsAttr, CUDAClusterDimsAttr,
- CUDADataTransferKindAttr>();
+ LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
+ SubclassAttr, UpperBoundAttr, CUDADataAttributeAttr,
+ CUDAProcAttributeAttr, CUDALaunchBoundsAttr,
+ CUDAClusterDimsAttr, CUDADataTransferKindAttr>();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index edf7f7f4b1a96..a8c923804d140 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2065,9 +2065,15 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value lb,
mlir::Value ub, mlir::Value step, bool unordered,
bool finalCountValue, mlir::ValueRange iterArgs,
+ mlir::ValueRange reduceOperands,
+ llvm::ArrayRef<mlir::Attribute> reduceAttrs,
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
result.addOperands({lb, ub, step});
result.addOperands(iterArgs);
+ result.addOperands(reduceOperands);
+ result.addAttribute(getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(iterArgs.size()),
+ static_cast<int32_t>(reduceOperands.size())}));
if (finalCountValue) {
result.addTypes(builder.getIndexType());
result.addAttribute(getFinalValueAttrName(result.name),
@@ -2086,6 +2092,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
if (unordered)
result.addAttribute(getUnorderedAttrName(result.name),
builder.getUnitAttr());
+ if (!reduceAttrs.empty())
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(reduceAttrs));
result.addAttributes(attributes);
}
@@ -2113,26 +2122,55 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
// Parse the optional initial iteration arguments.
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
llvm::SmallVector<mlir::Type> argTypes;
bool prependCount = false;
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
// Parse assignment list and results type list.
- if (parser.parseAssignmentList(regionArgs, operands) ||
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
parser.parseArrowTypeList(result.types))
return mlir::failure();
- if (result.types.size() == operands.size() + 1)
+ if (result.types.size() == iterOperands.size() + 1)
prependCount = true;
// Resolve input operands.
llvm::ArrayRef<mlir::Type> resTypes = result.types;
for (auto operand_type :
- llvm::zip(o...
[truncated]
``````````
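
The operator-to-enum mapping in the `Bridge.cpp` hunk above can be sketched in isolation as a table lookup. This is an illustrative standalone model, not code from the patch: `ReduceKind` and `classifyReduceOperation` are hypothetical names, and the function assumes the caller passes a lower-cased spelling. Unlike the patch, which starts from `Add` and leaves it in place for anything it does not recognize, this sketch returns `std::nullopt` for unknown spellings so the caller has to handle that case explicitly.

```cpp
#include <optional>
#include <string_view>

// Hypothetical mirror of the eleven operations listed in F'2023 R1131.
enum class ReduceKind {
  Add, Multiply, And, Or, Eqv, Neqv, Max, Min, Iand, Ior, Ieor
};

// Map a lower-cased reduce-operation spelling to its kind.
// Returns std::nullopt when the spelling is not a supported operation.
std::optional<ReduceKind> classifyReduceOperation(std::string_view spelling) {
  struct Entry {
    std::string_view spelling;
    ReduceKind kind;
  };
  static constexpr Entry table[] = {
      {"+", ReduceKind::Add},       {"*", ReduceKind::Multiply},
      {".and.", ReduceKind::And},   {".or.", ReduceKind::Or},
      {".eqv.", ReduceKind::Eqv},   {".neqv.", ReduceKind::Neqv},
      {"max", ReduceKind::Max},     {"min", ReduceKind::Min},
      {"iand", ReduceKind::Iand},   {"ior", ReduceKind::Ior},
      {"ieor", ReduceKind::Ieor},
  };
  for (const Entry &entry : table)
    if (entry.spelling == spelling)
      return entry.kind;
  return std::nullopt;
}
```

Whether an explicit failure value is preferable to a default depends on how much validation Semantics has already done by the time lowering runs; the sketch only shows the alternative.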
</details>
https://github.com/llvm/llvm-project/pull/92480