[llvm-branch-commits] [flang] [flang][fir] Add locality specifiers modeling to `fir.do_concurrent.loop` (PR #138506)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon May 5 04:10:25 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Kareem Ergawy (ergawy)
<details>
<summary>Changes</summary>
Extends the `fir.do_concurrent.loop` op to model locality specifiers. This follows the same pattern used in OpenMP, where an op of type `fir.local` (the counterpart of OpenMP's `omp.private`) is referenced from the `do concurrent` locality specifier. This PR adds the MLIR op changes as well as the corresponding printing and parsing logic.
---
Full diff: https://github.com/llvm/llvm-project/pull/138506.diff
5 Files Affected:
- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+32-1)
- (modified) flang/lib/Lower/Bridge.cpp (+1-1)
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+94-18)
- (modified) flang/test/Fir/do_concurrent.fir (+63-1)
- (modified) flang/test/Fir/invalid.fir (+5-5)
``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index aea57d2e8dd71..e1d9f877855c4 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3647,6 +3647,13 @@ def fir_DoConcurrentOp : fir_Op<"do_concurrent",
let hasVerifier = 1;
}
+def fir_LocalSpecifier {
+ dag arguments = (ins
+ Variadic<AnyType>:$local_vars,
+ OptionalAttr<SymbolRefArrayAttr>:$local_syms
+ );
+}
+
def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getLoopInductionVars"]>,
@@ -3700,7 +3707,7 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
LLVM.
}];
- let arguments = (ins
+ defvar opArgs = (ins
Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound,
Variadic<Index>:$step,
@@ -3709,16 +3716,40 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
);
+ let arguments = !con(opArgs, fir_LocalSpecifier.arguments);
+
let regions = (region SizedRegion<1>:$region);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let extraClassDeclaration = [{
+ unsigned getNumInductionVars() { return getLowerBound().size(); }
+
+ unsigned getNumLocalOperands() { return getLocalVars().size(); }
+
+ mlir::Block::BlockArgListType getInductionVars() {
+ return getBody()->getArguments().slice(0, getNumInductionVars());
+ }
+
+ mlir::Block::BlockArgListType getRegionLocalArgs() {
+ return getBody()->getArguments().slice(getNumInductionVars(),
+ getNumLocalOperands());
+ }
+
+ /// Number of operands controlling the loop
+ unsigned getNumControlOperands() { return getLowerBound().size() * 3; }
+
// Get Number of reduction operands
unsigned getNumReduceOperands() {
return getReduceOperands().size();
}
+
+ mlir::Operation::operand_range getLocalOperands() {
+ return getOperands()
+ .slice(getNumControlOperands() + getNumReduceOperands(),
+ getNumLocalOperands());
+ }
}];
}
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8da05255d5f41..0a61f61ab8f75 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2460,7 +2460,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
nestReduceAttrs.empty()
? nullptr
: mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs),
- nullptr);
+ nullptr, /*local_vars=*/std::nullopt, /*local_syms=*/nullptr);
llvm::SmallVector<mlir::Type> loopBlockArgTypes(
incrementLoopNestInfo.size(), builder->getIndexType());
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 65ec730e134c2..c95655d7dcef6 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -5033,21 +5033,25 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
// Parse an opening `(` followed by induction variables followed by `)`
- llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs;
- if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren))
+ llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs;
+
+ if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();
+ llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(),
+ builder.getIndexType());
+
// Parse loop bounds.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower;
if (parser.parseEqual() ||
- parser.parseOperandList(lower, ivs.size(),
+ parser.parseOperandList(lower, regionArgs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(lower, builder.getIndexType(), result.operands))
return mlir::failure();
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper;
if (parser.parseKeyword("to") ||
- parser.parseOperandList(upper, ivs.size(),
+ parser.parseOperandList(upper, regionArgs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(upper, builder.getIndexType(), result.operands))
return mlir::failure();
@@ -5055,7 +5059,7 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
// Parse step values.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps;
if (parser.parseKeyword("step") ||
- parser.parseOperandList(steps, ivs.size(),
+ parser.parseOperandList(steps, regionArgs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
return mlir::failure();
@@ -5086,12 +5090,55 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
builder.getArrayAttr(arrayAttr));
}
- // Now parse the body.
- mlir::Region *body = result.addRegion();
- for (auto &iv : ivs)
- iv.type = builder.getIndexType();
- if (parser.parseRegion(*body, ivs))
- return mlir::failure();
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands;
+ if (succeeded(parser.parseOptionalKeyword("local"))) {
+ std::size_t oldArgTypesSize = argTypes.size();
+ if (failed(parser.parseLParen()))
+ return mlir::failure();
+
+ llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec;
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (failed(parser.parseAttribute(localSymbolVec.emplace_back())))
+ return mlir::failure();
+
+ if (parser.parseOperand(localOperands.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseArgument(regionArgs.emplace_back()))
+ return mlir::failure();
+
+ return mlir::success();
+ })))
+ return mlir::failure();
+
+ if (failed(parser.parseColon()))
+ return mlir::failure();
+
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (failed(parser.parseType(argTypes.emplace_back())))
+ return mlir::failure();
+
+ return mlir::success();
+ })))
+ return mlir::failure();
+
+ if (regionArgs.size() != argTypes.size())
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of local arg and types");
+
+ if (failed(parser.parseRParen()))
+ return mlir::failure();
+
+ for (auto operandType : llvm::zip_equal(
+ localOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
+ if (parser.resolveOperand(std::get<0>(operandType),
+ std::get<1>(operandType), result.operands))
+ return mlir::failure();
+
+ llvm::SmallVector<mlir::Attribute> symbolAttrs(localSymbolVec.begin(),
+ localSymbolVec.end());
+ result.addAttribute(getLocalSymsAttrName(result.name),
+ builder.getArrayAttr(symbolAttrs));
+ }
// Set `operandSegmentSizes` attribute.
result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
@@ -5099,7 +5146,16 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
{static_cast<int32_t>(lower.size()),
static_cast<int32_t>(upper.size()),
static_cast<int32_t>(steps.size()),
- static_cast<int32_t>(reduceOperands.size())}));
+ static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(localOperands.size())}));
+
+ // Now parse the body.
+ for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes))
+ arg.type = type;
+
+ mlir::Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return mlir::failure();
// Parse attributes.
if (parser.parseOptionalAttrDict(result.attributes))
@@ -5109,8 +5165,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
}
void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
- p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
- << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
+ p << " (" << getBody()->getArguments().slice(0, getNumInductionVars())
+ << ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step ("
+ << getStep() << ")";
if (!getReduceOperands().empty()) {
p << " reduce(";
@@ -5123,12 +5180,27 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
p << ')';
}
+ if (!getLocalVars().empty()) {
+ p << " local(";
+ llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(),
+ getRegionLocalArgs()),
+ p, [&](auto it) {
+ p << std::get<0>(it) << " " << std::get<1>(it)
+ << " -> " << std::get<2>(it);
+ });
+ p << " : ";
+ llvm::interleaveComma(getLocalVars(), p,
+ [&](auto it) { p << it.getType(); });
+ p << ")";
+ }
+
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(
(*this)->getAttrs(),
/*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
- DoConcurrentLoopOp::getReduceAttrsAttrName()});
+ DoConcurrentLoopOp::getReduceAttrsAttrName(),
+ DoConcurrentLoopOp::getLocalSymsAttrName()});
}
llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() {
@@ -5139,6 +5211,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
mlir::Operation::operand_range lbValues = getLowerBound();
mlir::Operation::operand_range ubValues = getUpperBound();
mlir::Operation::operand_range stepValues = getStep();
+ mlir::Operation::operand_range localVars = getLocalVars();
if (lbValues.empty())
return emitOpError(
@@ -5152,11 +5225,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
// Check that the body defines the same number of block arguments as the
// number of tuple elements in step.
mlir::Block *body = getBody();
- if (body->getNumArguments() != stepValues.size())
+ unsigned numIndVarArgs = body->getNumArguments() - localVars.size();
+
+ if (numIndVarArgs != stepValues.size())
return emitOpError() << "expects the same number of induction variables: "
<< body->getNumArguments()
<< " as bound and step values: " << stepValues.size();
- for (auto arg : body->getArguments())
+ for (auto arg : body->getArguments().slice(0, numIndVarArgs))
if (!arg.getType().isIndex())
return emitOpError(
"expects arguments for the induction variable to be of index type");
@@ -5171,7 +5246,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
std::optional<llvm::SmallVector<mlir::Value>>
fir::DoConcurrentLoopOp::getLoopInductionVars() {
- return llvm::SmallVector<mlir::Value>{getBody()->getArguments()};
+ return llvm::SmallVector<mlir::Value>{
+ getBody()->getArguments().slice(0, getLowerBound().size())};
}
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/do_concurrent.fir b/flang/test/Fir/do_concurrent.fir
index 4e55777402428..cfb9a7abac15b 100644
--- a/flang/test/Fir/do_concurrent.fir
+++ b/flang/test/Fir/do_concurrent.fir
@@ -91,7 +91,6 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
// CHECK: }
// CHECK: }
-
fir.local {type = local} @local_privatizer : i32
// CHECK: fir.local {type = local} @[[LOCAL_PRIV_SYM:local_privatizer]] : i32
@@ -109,3 +108,66 @@ fir.local {type = local_init} @local_init_privatizer : i32 copy {
// CHECK: fir.store %[[ORIG_VAL_LD]] to %[[LOCAL_VAL]] : !fir.ref<i32>
// CHECK: fir.yield(%[[LOCAL_VAL]] : !fir.ref<i32>)
// CHECK: }
+
+func.func @_QPdo_concurrent() {
+ %3 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFdo_concurrentElocal_init_var"}
+ %4:2 = hlfir.declare %3 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %5 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFdo_concurrentElocal_var"}
+ %6:2 = hlfir.declare %5 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 1 : index
+ fir.do_concurrent {
+ %9 = fir.alloca i32 {bindc_name = "i"}
+ %10:2 = hlfir.declare %9 {uniq_name = "_QFdo_concurrentEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@local_privatizer %6#0 -> %arg1, @local_init_privatizer %4#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
+ %11 = fir.convert %arg0 : (index) -> i32
+ fir.store %11 to %10#0 : !fir.ref<i32>
+ %13:2 = hlfir.declare %arg1 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %15:2 = hlfir.declare %arg2 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %17 = fir.load %10#0 : !fir.ref<i32>
+ %c5_i32 = arith.constant 5 : i32
+ %18 = arith.cmpi slt, %17, %c5_i32 : i32
+ fir.if %18 {
+ %c42_i32 = arith.constant 42 : i32
+ hlfir.assign %c42_i32 to %13#0 : i32, !fir.ref<i32>
+ } else {
+ %c84_i32 = arith.constant 84 : i32
+ hlfir.assign %c84_i32 to %15#0 : i32, !fir.ref<i32>
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @_QPdo_concurrent() {
+// CHECK: %[[LOC_INIT_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_init_var", {{.*}}}
+// CHECK: %[[LOC_INIT_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ALLOC]]
+
+// CHECK: %[[LOC_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_var", {{.*}}}
+// CHECK: %[[LOC_DECL:.*]]:2 = hlfir.declare %[[LOC_ALLOC]]
+
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C10:.*]] = arith.constant 1 : index
+
+// CHECK: fir.do_concurrent {
+// CHECK: %[[DC_I_ALLOC:.*]] = fir.alloca i32 {bindc_name = "i"}
+// CHECK: %[[DC_I_DECL:.*]]:2 = hlfir.declare %[[DC_I_ALLOC]]
+
+// CHECK: fir.do_concurrent.loop (%[[IV:.*]]) = (%[[C1]]) to (%[[C10]]) step (%[[C1]]) local(@[[LOCAL_PRIV_SYM]] %[[LOC_DECL]]#0 -> %[[LOC_ARG:.*]], @[[LOCAL_INIT_PRIV_SYM]] %[[LOC_INIT_DECL]]#0 -> %[[LOC_INIT_ARG:.*]] : !fir.ref<i32>, !fir.ref<i32>) {
+// CHECK: %[[IV_CVT:.*]] = fir.convert %[[IV]] : (index) -> i32
+// CHECK: fir.store %[[IV_CVT]] to %[[DC_I_DECL]]#0 : !fir.ref<i32>
+
+// CHECK: %[[LOC_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_ARG]]
+// CHECK: %[[LOC_INIT_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ARG]]
+
+// CHECK: fir.if %{{.*}} {
+// CHECK: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK: hlfir.assign %[[C42]] to %[[LOC_PRIV_DECL]]#0 : i32, !fir.ref<i32>
+// CHECK: } else {
+// CHECK: %[[C84:.*]] = arith.constant 84 : i32
+// CHECK: hlfir.assign %[[C84]] to %[[LOC_INIT_PRIV_DECL]]#0 : i32, !fir.ref<i32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return
+// CHECK: }
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index f9f5e267dd9bc..3cd3ab439b0e9 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1198,7 +1198,7 @@ func.func @dc_0d() {
func.func @dc_invalid_parent(%arg0: index, %arg1: index) {
// expected-error at +1 {{'fir.do_concurrent.loop' op expects parent op 'fir.do_concurrent'}}
- "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({
+ "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({
^bb0(%arg2: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index) -> ()
@@ -1210,7 +1210,7 @@ func.func @dc_invalid_parent(%arg0: index, %arg1: index) {
func.func @dc_invalid_control(%arg0: index, %arg1: index) {
// expected-error at +2 {{'fir.do_concurrent.loop' op different number of tuple elements for lowerBound, upperBound or step}}
fir.do_concurrent {
- "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({
+ "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({
^bb0(%arg2: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index) -> ()
@@ -1223,7 +1223,7 @@ func.func @dc_invalid_control(%arg0: index, %arg1: index) {
func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) {
// expected-error at +2 {{'fir.do_concurrent.loop' op expects the same number of induction variables: 2 as bound and step values: 1}}
fir.do_concurrent {
- "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
+ "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({
^bb0(%arg3: index, %arg4: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index, index) -> ()
@@ -1236,7 +1236,7 @@ func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) {
func.func @dc_invalid_ind_var_type(%arg0: index, %arg1: index) {
// expected-error at +2 {{'fir.do_concurrent.loop' op expects arguments for the induction variable to be of index type}}
fir.do_concurrent {
- "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
+ "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({
^bb0(%arg3: i32):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index, index) -> ()
@@ -1250,7 +1250,7 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) {
%sum = fir.alloca i32
// expected-error at +2 {{'fir.do_concurrent.loop' op mismatch in number of reduction variables and reduction attributes}}
fir.do_concurrent {
- "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>}> ({
+ "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>}> ({
^bb0(%arg3: index):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index, index, !fir.ref<i32>) -> ()
``````````
</details>
https://github.com/llvm/llvm-project/pull/138506
More information about the llvm-branch-commits mailing list