[flang-commits] [flang] [flang][fir] Add FIR structured control flow ops to SCF dialect pass. (PR #140374)

via flang-commits flang-commits at lists.llvm.org
Sat May 17 04:47:06 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: MingYan (NexMing)

<details>
<summary>Changes</summary>

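Convert FIR structured control flow ops to the SCF dialect. This initial patch adds the `fir-to-scf` pass with a single pattern that converts `fir.do_loop` to `scf.for`.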

---
Full diff: https://github.com/llvm/llvm-project/pull/140374.diff


6 Files Affected:

- (modified) flang/include/flang/Optimizer/Support/InitFIR.h (+2) 
- (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+1) 
- (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+11) 
- (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1) 
- (added) flang/lib/Optimizer/Transforms/FIRToSCF.cpp (+103) 
- (added) flang/test/Fir/FirToSCF/do-loop.fir (+147) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 1868fbb201970..fa7c430ed631c 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -30,6 +30,7 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/Passes.h"
+#include <mlir/Dialect/SCF/Transforms/Passes.h>
 
 namespace fir::support {
 
@@ -103,6 +104,7 @@ inline void registerMLIRPassesForFortranTools() {
   mlir::registerPrintOpStatsPass();
   mlir::registerInlinerPass();
   mlir::registerSCCPPass();
+  mlir::registerSCFPasses();
   mlir::affine::registerAffineScalarReplacementPass();
   mlir::registerSymbolDCEPass();
   mlir::registerLocationSnapshotPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 6dbabd523f88a..dc8a5b9141ad2 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -72,6 +72,7 @@ std::unique_ptr<mlir::Pass>
 createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
 std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
+std::unique_ptr<mlir::Pass> createFIRToSCFPass();
 std::unique_ptr<mlir::Pass>
 createAddDebugInfoPass(fir::AddDebugInfoOptions options = {});
 
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index c0d88a8e19f80..da3d9bc751927 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -76,6 +76,17 @@ def AffineDialectDemotion : Pass<"demote-affine", "::mlir::func::FuncOp"> {
   ];
 }
 
+def FIRToSCFPass : Pass<"fir-to-scf"> {
+  let summary = "Convert FIR structured control flow ops to SCF dialect.";
+  let description = [{
+    Convert FIR structured control flow ops to SCF dialect.
+  }];
+  let constructor = "::fir::createFIRToSCFPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::scf::SCFDialect"
+  ];
+}
+
 def AnnotateConstantOperands : Pass<"annotate-constant"> {
   let summary = "Annotate constant operands to all FIR operations";
   let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 170b6e2cca225..846d6c64dbd04 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_flang_library(FIRTransforms
   CUFComputeSharedMemoryOffsetsAndSize.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
+  FIRToSCF.cpp
   MemoryUtils.cpp
   MemoryAllocation.cpp
   StackArrays.cpp
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
new file mode 100644
index 0000000000000..02810f1bdba4e
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -0,0 +1,103 @@
+//===-- FIRToSCF.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace fir {
+#define GEN_PASS_DEF_FIRTOSCFPASS
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+
+namespace {
+class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
+  using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = doLoopOp.getLoc();
+    bool hasFinalValue = doLoopOp.getFinalValue().has_value();
+
+    // Get loop values from the DoLoopOp
+    auto low = doLoopOp.getLowerBound();
+    auto high = doLoopOp.getUpperBound();
+    assert(low && high && "must be a Value");
+    auto step = doLoopOp.getStep();
+    llvm::SmallVector<mlir::Value> iterArgs;
+    if (hasFinalValue)
+      iterArgs.push_back(low);
+    iterArgs.append(doLoopOp.getIterOperands().begin(),
+                    doLoopOp.getIterOperands().end());
+
+    // Calculate the trip count.
+    auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
+    auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
+    auto tripCount = rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
+    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
+    auto scfForOp =
+        rewriter.create<scf::ForOp>(loc, zero, tripCount, one, iterArgs);
+
+    auto &loopOps = doLoopOp.getBody()->getOperations();
+    auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
+    auto results = resultOp.getOperands();
+    Block *loweredBody = scfForOp.getBody();
+
+    loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
+                                        loopOps.begin(),
+                                        std::prev(loopOps.end()));
+
+    rewriter.setInsertionPointToStart(loweredBody);
+    Value iv =
+        rewriter.create<arith::MulIOp>(loc, scfForOp.getInductionVar(), step);
+    iv = rewriter.create<arith::AddIOp>(loc, low, iv);
+
+    if (!results.empty()) {
+      rewriter.setInsertionPointToEnd(loweredBody);
+      rewriter.create<scf::YieldOp>(resultOp->getLoc(), results);
+    }
+    doLoopOp.getInductionVar().replaceAllUsesWith(iv);
+    rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
+                                hasFinalValue
+                                    ? scfForOp.getRegionIterArgs().drop_front()
+                                    : scfForOp.getRegionIterArgs());
+
+    // Copy the loop annotation from the do loop to the new scf.for op.
+    if (auto ann = doLoopOp.getLoopAnnotation())
+      scfForOp->setAttr("loop_annotation", *ann);
+
+    rewriter.replaceOp(doLoopOp, scfForOp);
+    return success();
+  }
+};
+
+void FIRToSCFPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  patterns.add<DoLoopConversion>(patterns.getContext());
+  ConversionTarget target(getContext());
+  target.addIllegalOp<fir::DoLoopOp>();
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() {
+  return std::make_unique<FIRToSCFPass>();
+}
diff --git a/flang/test/Fir/FirToSCF/do-loop.fir b/flang/test/Fir/FirToSCF/do-loop.fir
new file mode 100644
index 0000000000000..c3c24ccc1db71
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/do-loop.fir
@@ -0,0 +1,147 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// CHECK-LABEL:   func.func @simple_loop(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_4:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
+// CHECK:             %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_0]] : index
+// CHECK:             %[[VAL_11:.*]] = arith.addi %[[VAL_0]], %[[VAL_10]] : index
+// CHECK:             %[[VAL_12:.*]] = fir.array_coor %[[ARG0]](%[[VAL_2]]) %[[VAL_11]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_12]] : !fir.ref<i32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @simple_loop(%arg0: !fir.ref<!fir.array<100xi32>>) {
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %c1_i32 = arith.constant 1 : i32
+  fir.do_loop %arg1 = %c1 to %c100 step %c1 {
+    %1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    fir.store %c1_i32 to %1 : !fir.ref<i32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @loop_with_negtive_step(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_5:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_2]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_2]] : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] {
+// CHECK:             %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_0]], %[[VAL_11]] : index
+// CHECK:             %[[VAL_13:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_12]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[VAL_4]] to %[[VAL_13]] : !fir.ref<i32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @loop_with_negtive_step(%arg0: !fir.ref<!fir.array<100xi32>>) {
+  %c100 = arith.constant 100 : index
+  %c1 = arith.constant 1 : index
+  %c-1 = arith.constant -1 : index
+  %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %c1_i32 = arith.constant 1 : i32
+  fir.do_loop %arg1 = %c100 to %c1 step %c-1 {
+    %1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    fir.store %c1_i32 to %1 : !fir.ref<i32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @loop_with_results(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
+// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_1]]) -> (i32) {
+// CHECK:             %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_0]] : index
+// CHECK:             %[[VAL_13:.*]] = arith.addi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK:             %[[VAL_14:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_13]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
+// CHECK:             %[[VAL_16:.*]] = arith.addi %[[VAL_11]], %[[VAL_15]] : i32
+// CHECK:             scf.yield %[[VAL_16]] : i32
+// CHECK:           }
+// CHECK:           fir.store %[[VAL_9]] to %[[ARG1]] : !fir.ref<i32>
+// CHECK:           return
+// CHECK:         }
+func.func @loop_with_results(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
+  %c1 = arith.constant 1 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c100 = arith.constant 100 : index
+  %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %1 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (i32) {
+    %2 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    %3 = fir.load %2 : !fir.ref<i32>
+    %4 = arith.addi %arg3, %3 : i32
+    fir.result %4 : i32
+  }
+  fir.store %1 to %arg1 : !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL:   func.func @loop_with_final_value(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
+// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_3:.*]] = fir.alloca index
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_1]]) -> (index, i32) {
+// CHECK:             %[[VAL_14:.*]] = arith.muli %[[VAL_11]], %[[VAL_0]] : index
+// CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_0]], %[[VAL_14]] : index
+// CHECK:             %[[VAL_16:.*]] = fir.array_coor %[[ARG0]](%[[VAL_4]]) %[[VAL_15]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_0]] overflow<nsw> : index
+// CHECK:             %[[VAL_19:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] overflow<nsw> : i32
+// CHECK:             scf.yield %[[VAL_18]], %[[VAL_19]] : index, i32
+// CHECK:           }
+// CHECK:           fir.store %[[VAL_20:.*]]#0 to %[[VAL_3]] : !fir.ref<index>
+// CHECK:           fir.store %[[VAL_20]]#1 to %[[ARG1]] : !fir.ref<i32>
+// CHECK:           return
+// CHECK:         }
+func.func @loop_with_final_value(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
+  %c1 = arith.constant 1 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c100 = arith.constant 100 : index
+  %0 = fir.alloca index
+  %1 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %2:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (index, i32) {
+    %3 = fir.array_coor %arg0(%1) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    %4 = fir.load %3 : !fir.ref<i32>
+    %5 = arith.addi %arg2, %c1 overflow<nsw> : index
+    %6 = arith.addi %arg3, %4 overflow<nsw> : i32
+    fir.result %5, %6 : index, i32
+  }
+  fir.store %2#0 to %0 : !fir.ref<index>
+  fir.store %2#1 to %arg1 : !fir.ref<i32>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/140374


More information about the flang-commits mailing list