[flang-commits] [flang] [flang][fir] Add FIR structured control flow ops to SCF dialect pass. (PR #140374)
via flang-commits
flang-commits at lists.llvm.org
Sat May 17 04:47:06 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: MingYan (NexMing)
<details>
<summary>Changes</summary>
Convert FIR structured control flow ops to SCF dialect.
---
Full diff: https://github.com/llvm/llvm-project/pull/140374.diff
6 Files Affected:
- (modified) flang/include/flang/Optimizer/Support/InitFIR.h (+2)
- (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+1)
- (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+11)
- (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1)
- (added) flang/lib/Optimizer/Transforms/FIRToSCF.cpp (+103)
- (added) flang/test/Fir/FirToSCF/do-loop.fir (+147)
``````````diff
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 1868fbb201970..fa7c430ed631c 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -30,6 +30,7 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/Passes.h"
+#include <mlir/Dialect/SCF/Transforms/Passes.h>
namespace fir::support {
@@ -103,6 +104,7 @@ inline void registerMLIRPassesForFortranTools() {
mlir::registerPrintOpStatsPass();
mlir::registerInlinerPass();
mlir::registerSCCPPass();
+ mlir::registerSCFPasses();
mlir::affine::registerAffineScalarReplacementPass();
mlir::registerSymbolDCEPass();
mlir::registerLocationSnapshotPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 6dbabd523f88a..dc8a5b9141ad2 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -72,6 +72,7 @@ std::unique_ptr<mlir::Pass>
createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
+std::unique_ptr<mlir::Pass> createFIRToSCFPass();
std::unique_ptr<mlir::Pass>
createAddDebugInfoPass(fir::AddDebugInfoOptions options = {});
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index c0d88a8e19f80..da3d9bc751927 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -76,6 +76,17 @@ def AffineDialectDemotion : Pass<"demote-affine", "::mlir::func::FuncOp"> {
];
}
+def FIRToSCFPass : Pass<"fir-to-scf"> {
+  let summary = "Convert FIR structured control flow ops to SCF dialect.";
+  let description = [{
+    Convert FIR structured control flow ops to SCF dialect.
+  }];
+  let constructor = "::fir::createFIRToSCFPass()";
+  // Besides scf ops, the conversion creates arith ops (subi/addi/divsi/muli
+  // and index constants) for the trip count and the loop index, so the arith
+  // dialect has to be declared as a dependency as well.
+  let dependentDialects = [
+    "fir::FIROpsDialect",
+    "mlir::scf::SCFDialect",
+    "mlir::arith::ArithDialect"
+  ];
+}
+
def AnnotateConstantOperands : Pass<"annotate-constant"> {
let summary = "Annotate constant operands to all FIR operations";
let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 170b6e2cca225..846d6c64dbd04 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_flang_library(FIRTransforms
CUFComputeSharedMemoryOffsetsAndSize.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
+ FIRToSCF.cpp
MemoryUtils.cpp
MemoryAllocation.cpp
StackArrays.cpp
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
new file mode 100644
index 0000000000000..02810f1bdba4e
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -0,0 +1,103 @@
+//===-- FIRToSCF.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace fir {
+#define GEN_PASS_DEF_FIRTOSCFPASS
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+
+namespace {
+/// Pass object for the FIR-to-SCF conversion. The generated base class
+/// supplies the pass boilerplate; the work happens in runOnOperation().
+struct FIRToSCFPass : fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+namespace {
+/// Rewrites a `fir.do_loop` into an `scf.for` that counts from 0 up to the
+/// trip count in steps of 1. The original loop index is rebuilt inside the new
+/// body as `low + iv * step`, so the same rewrite is used for positive and
+/// negative steps.
+///
+/// Kept in an anonymous namespace: the pattern is private to this file and
+/// must not get external linkage under such a generic name.
+struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
+  using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
+
+  /// Builds the replacement `scf.for`, moves the body of \p doLoopOp into it
+  /// and replaces \p doLoopOp with the new loop. Always returns success().
+  LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = doLoopOp.getLoc();
+    bool hasFinalValue = doLoopOp.getFinalValue().has_value();
+
+    // Get loop values from the DoLoopOp.
+    auto low = doLoopOp.getLowerBound();
+    auto high = doLoopOp.getUpperBound();
+    assert(low && high && "must be a Value");
+    auto step = doLoopOp.getStep();
+    // When the loop also returns its final index, that value is carried as
+    // the leading iter_arg and starts at the lower bound.
+    llvm::SmallVector<mlir::Value> iterArgs;
+    if (hasFinalValue)
+      iterArgs.push_back(low);
+    iterArgs.append(doLoopOp.getIterOperands().begin(),
+                    doLoopOp.getIterOperands().end());
+
+    // Calculate the trip count as (high - low + step) / step, i.e. with the
+    // upper bound counted as part of the iteration range.
+    auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
+    auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
+    auto tripCount = rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
+    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
+    auto scfForOp =
+        rewriter.create<scf::ForOp>(loc, zero, tripCount, one, iterArgs);
+
+    auto &loopOps = doLoopOp.getBody()->getOperations();
+    auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
+    auto results = resultOp.getOperands();
+    Block *loweredBody = scfForOp.getBody();
+
+    // Move every operation except the fir.result terminator to the front of
+    // the new loop body.
+    // NOTE(review): this splice edits the IR without going through `rewriter`;
+    // consider a rewriter-based block move in a follow-up.
+    loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
+                                        loopOps.begin(),
+                                        std::prev(loopOps.end()));
+
+    // Rebuild the original index value at the top of the new body.
+    rewriter.setInsertionPointToStart(loweredBody);
+    Value iv =
+        rewriter.create<arith::MulIOp>(loc, scfForOp.getInductionVar(), step);
+    iv = rewriter.create<arith::AddIOp>(loc, low, iv);
+
+    // Forward the operands of the old fir.result through an scf.yield.
+    if (!results.empty()) {
+      rewriter.setInsertionPointToEnd(loweredBody);
+      rewriter.create<scf::YieldOp>(resultOp->getLoc(), results);
+    }
+
+    // Redirect the old block arguments through the rewriter (not directly on
+    // the Value) so the driver is notified of every operand update. When a
+    // final value is present, the extra leading scf.for iter_arg has no
+    // counterpart among the fir.do_loop region iter_args and is skipped.
+    rewriter.replaceAllUsesWith(doLoopOp.getInductionVar(), iv);
+    rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
+                                hasFinalValue
+                                    ? scfForOp.getRegionIterArgs().drop_front()
+                                    : scfForOp.getRegionIterArgs());
+
+    // Copy loop annotations from the do loop to the new scf.for op.
+    if (auto ann = doLoopOp.getLoopAnnotation())
+      scfForOp->setAttr("loop_annotation", *ann);
+
+    rewriter.replaceOp(doLoopOp, scfForOp);
+    return success();
+  }
+};
+} // namespace
+
+/// Runs a partial conversion over the pass root in which fir.do_loop is the
+/// only illegal op and DoLoopConversion is the only pattern. The pass is
+/// marked as failed when the conversion does not succeed.
+void FIRToSCFPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+
+  ConversionTarget legality(*ctx);
+  legality.addIllegalOp<fir::DoLoopOp>();
+  // Every op without an explicit legality entry is accepted as is.
+  legality.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+  RewritePatternSet rewrites(ctx);
+  rewrites.add<DoLoopConversion>(ctx);
+
+  LogicalResult status =
+      applyPartialConversion(getOperation(), legality, std::move(rewrites));
+  if (failed(status))
+    signalPassFailure();
+}
+
+/// Factory for the FIR-to-SCF conversion pass.
+std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() {
+  std::unique_ptr<mlir::Pass> pass = std::make_unique<FIRToSCFPass>();
+  return pass;
+}
diff --git a/flang/test/Fir/FirToSCF/do-loop.fir b/flang/test/Fir/FirToSCF/do-loop.fir
new file mode 100644
index 0000000000000..c3c24ccc1db71
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/do-loop.fir
@@ -0,0 +1,147 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// Unit positive step, no iter_args: expects a zero-based, unit-step scf.for
+// whose body recomputes the original index with arith.muli / arith.addi.
+// CHECK-LABEL: func.func @simple_loop(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
+// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+// CHECK: %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_0]] : index
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_0]], %[[VAL_10]] : index
+// CHECK: %[[VAL_12:.*]] = fir.array_coor %[[ARG0]](%[[VAL_2]]) %[[VAL_11]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK: fir.store %[[VAL_3]] to %[[VAL_12]] : !fir.ref<i32>
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @simple_loop(%arg0: !fir.ref<!fir.array<100xi32>>) {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+ %c1_i32 = arith.constant 1 : i32
+ fir.do_loop %arg1 = %c1 to %c100 step %c1 {
+ %1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+ fir.store %c1_i32 to %1 : !fir.ref<i32>
+ }
+ return
+}
+
+// Step of -1 from 100 down to 1: expects the same trip-count arithmetic and
+// index reconstruction as the positive-step case, with the -1 constant as step.
+// CHECK-LABEL: func.func @loop_with_negtive_step(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant -1 : index
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_5:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
+// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_2]] : index
+// CHECK: %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_2]] : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] {
+// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_2]] : index
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_0]], %[[VAL_11]] : index
+// CHECK: %[[VAL_13:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_12]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK: fir.store %[[VAL_4]] to %[[VAL_13]] : !fir.ref<i32>
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @loop_with_negtive_step(%arg0: !fir.ref<!fir.array<100xi32>>) {
+ %c100 = arith.constant 100 : index
+ %c1 = arith.constant 1 : index
+ %c-1 = arith.constant -1 : index
+ %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+ %c1_i32 = arith.constant 1 : i32
+ fir.do_loop %arg1 = %c100 to %c1 step %c-1 {
+ %1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+ fir.store %c1_i32 to %1 : !fir.ref<i32>
+ }
+ return
+}
+
+// Loop carrying one i32 iter_arg: expects the iter_arg on the scf.for, the
+// fir.result turned into scf.yield, and the loop result feeding the fir.store.
+// CHECK-LABEL: func.func @loop_with_results(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
+// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<i32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
+// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+// CHECK: %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_1]]) -> (i32) {
+// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_0]] : index
+// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_13]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_11]], %[[VAL_15]] : i32
+// CHECK: scf.yield %[[VAL_16]] : i32
+// CHECK: }
+// CHECK: fir.store %[[VAL_9]] to %[[ARG1]] : !fir.ref<i32>
+// CHECK: return
+// CHECK: }
+func.func @loop_with_results(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
+ %c1 = arith.constant 1 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c100 = arith.constant 100 : index
+ %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+ %1 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (i32) {
+ %2 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+ %3 = fir.load %2 : !fir.ref<i32>
+ %4 = arith.addi %arg3, %3 : i32
+ fir.result %4 : i32
+ }
+ fir.store %1 to %arg1 : !fir.ref<i32>
+ return
+}
+
+// Loop that also returns its final index: expects an extra leading index
+// iter_arg seeded with the lower bound, and both stores after the loop to
+// consume the results of the scf.for itself (the captured loop result is
+// reused in the store patterns instead of being re-captured).
+// CHECK-LABEL: func.func @loop_with_final_value(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
+// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<i32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_3:.*]] = fir.alloca index
+// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_5:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
+// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_0]] : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_1]]) -> (index, i32) {
+// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_11]], %[[VAL_0]] : index
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_0]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = fir.array_coor %[[ARG0]](%[[VAL_4]]) %[[VAL_15]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_0]] overflow<nsw> : index
+// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] overflow<nsw> : i32
+// CHECK: scf.yield %[[VAL_18]], %[[VAL_19]] : index, i32
+// CHECK: }
+// CHECK: fir.store %[[VAL_10]]#0 to %[[VAL_3]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_10]]#1 to %[[ARG1]] : !fir.ref<i32>
+// CHECK: return
+// CHECK: }
+func.func @loop_with_final_value(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
+ %c1 = arith.constant 1 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c100 = arith.constant 100 : index
+ %0 = fir.alloca index
+ %1 = fir.shape %c100 : (index) -> !fir.shape<1>
+ %2:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (index, i32) {
+ %3 = fir.array_coor %arg0(%1) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+ %4 = fir.load %3 : !fir.ref<i32>
+ %5 = arith.addi %arg2, %c1 overflow<nsw> : index
+ %6 = arith.addi %arg3, %4 overflow<nsw> : i32
+ fir.result %5, %6 : index, i32
+ }
+ fir.store %2#0 to %0 : !fir.ref<index>
+ fir.store %2#1 to %arg1 : !fir.ref<i32>
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/140374
More information about the flang-commits
mailing list