[flang-commits] [flang] 6328506 - [flang][fir] Add rewrite pattern to convert `fir.do_concurrent` to `fir.do_loop` (#132207)
via flang-commits
flang-commits at lists.llvm.org
Mon Mar 24 04:09:39 PDT 2025
Author: Kareem Ergawy
Date: 2025-03-24T12:09:32+01:00
New Revision: 63285065368f22894aea87a8d82880cc8b0e8267
URL: https://github.com/llvm/llvm-project/commit/63285065368f22894aea87a8d82880cc8b0e8267
DIFF: https://github.com/llvm/llvm-project/commit/63285065368f22894aea87a8d82880cc8b0e8267.diff
LOG: [flang][fir] Add rewrite pattern to convert `fir.do_concurrent` to `fir.do_loop` (#132207)
Rewrites `fir.do_concurrent` ops to a corresponding nest of `fir.do_loop
... unordered` ops.
Added:
flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index c8d8ab41552c2..753e4bd18dc6d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3478,7 +3478,8 @@ def fir_DoConcurrentOp : fir_Op<"do_concurrent",
}
def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
- [AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+ [AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getLoopInductionVars"]>,
Terminator, NoTerminator, SingleBlock, ParentOneOf<["DoConcurrentOp"]>]> {
let summary = "do concurrent loop";
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 1e8a7354da561..2d8017d0318d2 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4915,6 +4915,11 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
return mlir::success();
}
+std::optional<llvm::SmallVector<mlir::Value>>
+fir::DoConcurrentLoopOp::getLoopInductionVars() { // IVs are the body block's arguments.
+ return llvm::SmallVector<mlir::Value>{getBody()->getArguments()};
+}
+
//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
index e79d420c81c9c..b6baae501f87e 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
@@ -18,8 +18,10 @@
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <optional>
namespace fir {
#define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS
@@ -122,6 +124,57 @@ mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite(
return mlir::failure();
}
+class DoConcurrentConversion
+ : public mlir::OpRewritePattern<fir::DoConcurrentOp> {
+public:
+ using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::DoConcurrentOp doConcurentOp,
+ mlir::PatternRewriter &rewriter) const override {
+ assert(doConcurentOp.getRegion().hasOneBlock());
+ mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front();
+ auto loop =
+ mlir::cast<fir::DoConcurrentLoopOp>(wrapperBlock.getTerminator());
+ assert(loop.getRegion().hasOneBlock());
+ mlir::Block &loopBlock = loop.getRegion().getBlocks().front();
+
+ // Collect the ops preceding the loop (iteration variable(s) allocations) so
+ // that we can move them outside the `fir.do_concurrent` wrapper.
+ llvm::SmallVector<mlir::Operation *> opsToMove;
+ for (mlir::Operation &op : llvm::drop_end(wrapperBlock))
+ opsToMove.push_back(&op);
+
+ fir::FirOpBuilder firBuilder(
+ rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>());
+ auto *allocIt = firBuilder.getAllocaBlock();
+
+ for (mlir::Operation *op : llvm::reverse(opsToMove)) // Reverse order + insert-at-begin keeps the ops' relative order.
+ rewriter.moveOpBefore(op, allocIt, allocIt->begin());
+
+ rewriter.setInsertionPointAfter(doConcurentOp);
+ fir::DoLoopOp innermostUnorderdLoop;
+ mlir::SmallVector<mlir::Value> ivArgs;
+
+ for (auto [lb, ub, st, iv] :
+ llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(),
+ loop.getStep(), *loop.getLoopInductionVars())) {
+ innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>(
+ doConcurentOp.getLoc(), lb, ub, st,
+ /*unordered=*/true, /*finalCountValue=*/false,
+ /*iterArgs=*/std::nullopt, loop.getReduceOperands(),
+ loop.getReduceAttrsAttr());
+ ivArgs.push_back(innermostUnorderdLoop.getInductionVar());
+ rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody());
+ }
+
+ rewriter.inlineBlockBefore(
+ &loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs);
+ rewriter.eraseOp(doConcurentOp);
+ return mlir::success();
+ }
+};
+
void SimplifyFIROperationsPass::runOnOperation() {
mlir::ModuleOp module = getOperation();
mlir::MLIRContext &context = getContext();
@@ -142,4 +195,5 @@ void fir::populateSimplifyFIROperationsPatterns(
mlir::RewritePatternSet &patterns, bool preferInlineImplementation) {
patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>(
patterns.getContext(), preferInlineImplementation);
+ patterns.insert<DoConcurrentConversion>(patterns.getContext());
}
diff --git a/flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir b/flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir
new file mode 100644
index 0000000000000..d2ceafdda5b22
--- /dev/null
+++ b/flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir
@@ -0,0 +1,123 @@
+// Tests converting `fir.do_concurrent` ops to `fir.do_loop` ops.
+
+// RUN: fir-opt --split-input-file --simplify-fir-operations %s | FileCheck %s
+
+func.func @dc_1d(%i_lb: index, %i_ub: index, %i_st: index) {
+ fir.do_concurrent {
+ %i = fir.alloca i32
+ fir.do_concurrent.loop (%i_iv) = (%i_lb) to (%i_ub) step (%i_st) {
+ %0 = fir.convert %i_iv : (index) -> i32
+ fir.store %0 to %i : !fir.ref<i32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @dc_1d(
+// CHECK-SAME: %[[I_LB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[I_UB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[I_ST:[^[:space:]]+]]: index) {
+
+// CHECK: %[[I:.*]] = fir.alloca i32
+
+// CHECK: fir.do_loop %[[I_IV:.*]] = %[[I_LB]] to %[[I_UB]] step %[[I_ST]] unordered {
+// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
+// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
+// CHECK: }
+
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @dc_2d(%i_lb: index, %i_ub: index, %i_st: index,
+ %j_lb: index, %j_ub: index, %j_st: index) {
+ llvm.br ^bb1
+
+^bb1:
+ fir.do_concurrent {
+ %i = fir.alloca i32
+ %j = fir.alloca i32
+ fir.do_concurrent.loop
+ (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st) {
+ %0 = fir.convert %i_iv : (index) -> i32
+ fir.store %0 to %i : !fir.ref<i32>
+
+ %1 = fir.convert %j_iv : (index) -> i32
+ fir.store %1 to %j : !fir.ref<i32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @dc_2d(
+// CHECK-SAME: %[[I_LB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[I_UB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[I_ST:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[J_LB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[J_UB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[J_ST:[^[:space:]]+]]: index) {
+
+// CHECK: %[[I:.*]] = fir.alloca i32
+// CHECK: %[[J:.*]] = fir.alloca i32
+// CHECK: llvm.br ^bb1
+
+// CHECK: ^bb1:
+// CHECK: fir.do_loop %[[I_IV:.*]] = %[[I_LB]] to %[[I_UB]] step %[[I_ST]] unordered {
+// CHECK: fir.do_loop %[[J_IV:.*]] = %[[J_LB]] to %[[J_UB]] step %[[J_ST]] unordered {
+// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
+// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
+// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
+// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
+// CHECK: }
+// CHECK: }
+
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
+ %j_lb: index, %j_ub: index, %j_st: index) {
+ %sum = fir.alloca i32
+
+ fir.do_concurrent {
+ %i = fir.alloca i32
+ %j = fir.alloca i32
+ fir.do_concurrent.loop
+ (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
+ reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
+ %0 = fir.convert %i_iv : (index) -> i32
+ fir.store %0 to %i : !fir.ref<i32>
+
+ %1 = fir.convert %j_iv : (index) -> i32
+ fir.store %1 to %j : !fir.ref<i32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @dc_2d_reduction(
+// CHECK-SAME: %[[I_LB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[I_UB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[I_ST:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[J_LB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[J_UB:[^[:space:]]+]]: index,
+// CHECK-SAME: %[[J_ST:[^[:space:]]+]]: index) {
+
+// CHECK: %[[I:.*]] = fir.alloca i32
+// CHECK: %[[J:.*]] = fir.alloca i32
+// CHECK: %[[SUM:.*]] = fir.alloca i32
+
+// CHECK: fir.do_loop %[[I_IV:.*]] = %[[I_LB]] to %[[I_UB]] step %[[I_ST]] unordered reduce({{.*}}<add>] -> %[[SUM]] : !fir.ref<i32>) {
+// CHECK: fir.do_loop %[[J_IV:.*]] = %[[J_LB]] to %[[J_UB]] step %[[J_ST]] unordered reduce({{.*}}<add>] -> %[[SUM]] : !fir.ref<i32>) {
+// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
+// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
+// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
+// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
+// CHECK: fir.result
+// CHECK: }
+// CHECK: fir.result
+// CHECK: }
+// CHECK: return
+// CHECK: }
More information about the flang-commits
mailing list