[flang-commits] [flang] [fir] Support promoting `fir.do_loop` with results to `affine.for`. (PR #137790)

via flang-commits flang-commits at lists.llvm.org
Thu May 8 00:53:05 PDT 2025


https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/137790

>From 17dfda28fdf8eb8184283b686e6831a3c8b7a9ab Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 29 Apr 2025 19:16:48 +0800
Subject: [PATCH 1/2] [fir] Support promoting `fir.do_loop` with results to
 `affine.for`.

---
 .../Optimizer/Transforms/AffinePromotion.cpp  | 39 +++++++++--
 flang/test/Fir/affine-promotion.fir           | 65 +++++++++++++++++++
 2 files changed, 99 insertions(+), 5 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
index 43fccf52dc8ab..ef82e400bea14 100644
--- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -49,8 +49,9 @@ struct AffineIfAnalysis;
 /// second when doing rewrite.
 struct AffineFunctionAnalysis {
   explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
-    for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
-      loopAnalysisMap.try_emplace(op, op, *this);
+    funcOp->walk([&](fir::DoLoopOp doloop) { // Nested loops too, inner first.
+      loopAnalysisMap.try_emplace(doloop, doloop, *this);
+    });
   }
 
   AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
@@ -102,10 +103,26 @@ struct AffineLoopAnalysis {
     return true;
   }
 
+  /// Rejects loops that also yield their final induction value (result 0)
+  /// when that value has uses: promotion to `affine.for` drops that result,
+  /// so it must be dead for the rewrite to be valid.
+  bool analyzeResults(fir::DoLoopOp loopOperation) {
+    if (loopOperation.getFinalValue() &&
+        !loopOperation.getResult(0).use_empty()) {
+      LLVM_DEBUG(
+          llvm::dbgs()
+              << "AffineLoopAnalysis: cannot promote loop final value\n";);
+      return false;
+    }
+
+    return true;
+  }
+
   bool analyzeLoop(fir::DoLoopOp loopOperation,
                    AffineFunctionAnalysis &functionAnalysis) {
     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
     return analyzeMemoryAccess(loopOperation) &&
+           analyzeResults(loopOperation) &&
            analyzeBody(loopOperation, functionAnalysis);
   }

@@ -461,14 +475,41 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
     LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
         functionAnalysis.getChildLoopAnalysis(loop);
     auto &loopOps = loop.getBody()->getOperations();
+    auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator());
+    auto results = resultOp.getOperands();
+    auto loopResults = loop->getResults();
     auto loopAndIndex = createAffineFor(loop, rewriter);
     auto affineFor = loopAndIndex.first;
     auto inductionVar = loopAndIndex.second;
 
+    // A fir.do_loop that also yields its final induction value returns it as
+    // its first result and first fir.result operand. affine.for has no such
+    // result, so drop it; the analysis only admits these loops when that
+    // result is unused.
+    if (loop.getFinalValue()) {
+      results = results.drop_front();
+      loopResults = loopResults.drop_front();
+    }
+
     rewriter.startOpModification(affineFor.getOperation());
-    affineFor.getBody()->getOperations().splice(
-        std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
-        std::prev(loopOps.end()));
+    // affine.for only gets an implicit affine.yield when built without
+    // iter_args. With iter_args its body is either empty or already holds the
+    // affine.apply computing the actual index (generic-bounds case), and the
+    // moved ops must follow that apply. So insert before an existing yield
+    // and append otherwise.
+    mlir::Block *affineBody = affineFor.getBody();
+    mlir::Block::iterator insertPt = affineBody->end();
+    if (!affineBody->empty() &&
+        mlir::isa<affine::AffineYieldOp>(affineBody->back()))
+      insertPt = std::prev(affineBody->end());
+    affineBody->getOperations().splice(insertPt, loopOps, loopOps.begin(),
+                                       std::prev(loopOps.end()));
+    rewriter.replaceAllUsesWith(loop.getRegionIterArgs(),
+                                affineFor.getRegionIterArgs());
+    if (!results.empty()) {
+      rewriter.setInsertionPointToEnd(affineBody);
+      rewriter.create<affine::AffineYieldOp>(resultOp->getLoc(), results);
+    }
     rewriter.finalizeOpModification(affineFor.getOperation());
 
     rewriter.startOpModification(loop.getOperation());
@@ -479,7 +507,8 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
 
     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
                affineFor.dump(););
-    rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
+    rewriter.replaceAllUsesWith(loopResults, affineFor->getResults());
+    rewriter.eraseOp(loop); // A final-value result, if any, is unused.
     return success();
   }

@@ -503,7 +532,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
         ValueRange(op.getUpperBound()),
         mlir::AffineMap::get(0, 1,
                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
-        step);
+        step, op.getIterOperands()); // Forward fir.do_loop iter_args.
     return std::make_pair(affineFor, affineFor.getInductionVar());
   }
 
@@ -528,7 +557,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
         genericUpperBound.getResult(),
         mlir::AffineMap::get(0, 1,
                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
-        1);
+        1, op.getIterOperands()); // Forward fir.do_loop iter_args.
     rewriter.setInsertionPointToStart(affineFor.getBody());
     auto actualIndex = rewriter.create<affine::AffineApplyOp>(
         op.getLoc(), actualIndexMap,
diff --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir
index aae35c6ef5659..f50f851a89eae 100644
--- a/flang/test/Fir/affine-promotion.fir
+++ b/flang/test/Fir/affine-promotion.fir
@@ -131,3 +131,68 @@ func.func @loop_with_if(%a: !arr_d1, %v: f32) {
 // CHECK:   }
 // CHECK:   return
 // CHECK: }
+
+func.func @loop_with_result(%arg0: !fir.ref<!fir.array<100xf32>>, %arg1: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %c100 = arith.constant 100 : index
+  %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %1 = fir.shape %c100, %c100 : (index, index) -> !fir.shape<2>
+  %2 = fir.alloca i32
+  %3:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %cst) -> (index, f32) {
+    %6 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+    %7 = fir.load %6 : !fir.ref<f32>
+    %8 = arith.addf %arg3, %7 fastmath<contract> : f32
+    %9 = arith.addi %arg2, %c1 overflow<nsw> : index
+    fir.result %9, %8 : index, f32
+  }
+  %4:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %3#1) -> (index, f32) {
+    %6 = fir.array_coor %arg1(%1) %c1, %arg2 : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+    %7 = fir.convert %6 : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
+    %8 = fir.do_loop %arg4 = %c1 to %c100 step %c1 iter_args(%arg5 = %arg3) -> (f32) {
+      %10 = fir.array_coor %7(%0) %arg4 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+      %11 = fir.load %10 : !fir.ref<f32>
+      %12 = arith.addf %arg5, %11 fastmath<contract> : f32
+      fir.result %12 : f32
+    }
+    %9 = arith.addi %arg2, %c1 overflow<nsw> : index
+    fir.result %9, %8 : index, f32
+  }
+  %5 = fir.convert %4#0 : (index) -> i32
+  fir.store %5 to %2 : !fir.ref<i32>
+  return %4#1 : f32
+}
+
+// CHECK-LABEL:   func.func @loop_with_result(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xf32>>,
+// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]], %[[VAL_2]] : (index, index) -> !fir.shape<2>
+// CHECK:           %[[VAL_5:.*]] = fir.alloca i32
+// CHECK:           %[[VAL_6:.*]] = fir.convert %[[ARG0]] : (!fir.ref<!fir.array<100xf32>>) -> memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = affine.for %[[VAL_8:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_9:.*]] = %[[VAL_1]]) -> (f32) {
+// CHECK:             %[[VAL_10:.*]] = affine.apply #{{.*}}(%[[VAL_8]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]]
+// CHECK:             %[[VAL_11:.*]] = affine.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
+// CHECK:             %[[VAL_12:.*]] = arith.addf %[[VAL_9]], %[[VAL_11]] fastmath<contract> : f32
+// CHECK:             affine.yield %[[VAL_12]] : f32
+// CHECK:           }
+// CHECK:           %[[VAL_13:.*]]:2 = fir.do_loop %[[VAL_14:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_0]] iter_args(%[[VAL_15:.*]] = %[[VAL_7]]) -> (index, f32) {
+// CHECK:             %[[VAL_16:.*]] = fir.array_coor %[[ARG1]](%[[VAL_4]]) %[[VAL_0]], %[[VAL_14]] : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+// CHECK:             %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
+// CHECK:             %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (!fir.ref<!fir.array<100xf32>>) -> memref<?xf32>
+// CHECK:             %[[VAL_19:.*]] = affine.for %[[VAL_20:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_21:.*]] = %[[VAL_15]]) -> (f32) {
+// CHECK:               %[[VAL_22:.*]] = affine.apply #{{.*}}(%[[VAL_20]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]]
+// CHECK:               %[[VAL_23:.*]] = affine.load %[[VAL_18]]{{\[}}%[[VAL_22]]] : memref<?xf32>
+// CHECK:               %[[VAL_24:.*]] = arith.addf %[[VAL_21]], %[[VAL_23]] fastmath<contract> : f32
+// CHECK:               affine.yield %[[VAL_24]] : f32
+// CHECK:             }
+// CHECK:             %[[VAL_25:.*]] = arith.addi %[[VAL_14]], %[[VAL_0]] overflow<nsw> : index
+// CHECK:             fir.result %[[VAL_25]], %[[VAL_19]] : index, f32
+// CHECK:           }
+// CHECK:           %[[VAL_26:.*]] = fir.convert %[[VAL_13]]#0 : (index) -> i32
+// CHECK:           fir.store %[[VAL_26]] to %[[VAL_5]] : !fir.ref<i32>
+// CHECK:           return %[[VAL_13]]#1 : f32
+// CHECK:         }

>From 98746e859c3bb9fdd72ecdd562cd3b404b1fc98b Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Thu, 8 May 2025 13:50:54 +0800
Subject: [PATCH 2/2] Add a test for a loop with multiple results.

---
 flang/test/Fir/affine-promotion.fir | 69 +++++++++++++++++++----------
 1 file changed, 45 insertions(+), 24 deletions(-)

diff --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir
index f50f851a89eae..46467ab4a292a 100644
--- a/flang/test/Fir/affine-promotion.fir
+++ b/flang/test/Fir/affine-promotion.fir
@@ -132,40 +132,51 @@ func.func @loop_with_if(%a: !arr_d1, %v: f32) {
 // CHECK:   return
 // CHECK: }
 
-func.func @loop_with_result(%arg0: !fir.ref<!fir.array<100xf32>>, %arg1: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
+func.func @loop_with_result(%arg0: !fir.ref<!fir.array<100xf32>>, %arg1: !fir.ref<!fir.array<100x100xf32>>, %arg2: !fir.ref<!fir.array<100xf32>>) -> f32 {
   %c1 = arith.constant 1 : index
   %cst = arith.constant 0.000000e+00 : f32
   %c100 = arith.constant 100 : index
   %0 = fir.shape %c100 : (index) -> !fir.shape<1>
   %1 = fir.shape %c100, %c100 : (index, index) -> !fir.shape<2>
   %2 = fir.alloca i32
-  %3:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %cst) -> (index, f32) {
-    %6 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
-    %7 = fir.load %6 : !fir.ref<f32>
-    %8 = arith.addf %arg3, %7 fastmath<contract> : f32
-    %9 = arith.addi %arg2, %c1 overflow<nsw> : index
-    fir.result %9, %8 : index, f32
+  %3:2 = fir.do_loop %arg3 = %c1 to %c100 step %c1 iter_args(%arg4 = %cst) -> (index, f32) {
+    %8 = fir.array_coor %arg0(%0) %arg3 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+    %9 = fir.load %8 : !fir.ref<f32>
+    %10 = arith.addf %arg4, %9 fastmath<contract> : f32
+    %11 = arith.addi %arg3, %c1 overflow<nsw> : index
+    fir.result %11, %10 : index, f32
   }
-  %4:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %3#1) -> (index, f32) {
-    %6 = fir.array_coor %arg1(%1) %c1, %arg2 : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
-    %7 = fir.convert %6 : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
-    %8 = fir.do_loop %arg4 = %c1 to %c100 step %c1 iter_args(%arg5 = %arg3) -> (f32) {
-      %10 = fir.array_coor %7(%0) %arg4 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
-      %11 = fir.load %10 : !fir.ref<f32>
-      %12 = arith.addf %arg5, %11 fastmath<contract> : f32
-      fir.result %12 : f32
+  %4:2 = fir.do_loop %arg3 = %c1 to %c100 step %c1 iter_args(%arg4 = %3#1) -> (index, f32) {
+    %8 = fir.array_coor %arg1(%1) %c1, %arg3 : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+    %9 = fir.convert %8 : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
+    %10 = fir.do_loop %arg5 = %c1 to %c100 step %c1 iter_args(%arg6 = %arg4) -> (f32) {
+      %12 = fir.array_coor %9(%0) %arg5 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+      %13 = fir.load %12 : !fir.ref<f32>
+      %14 = arith.addf %arg6, %13 fastmath<contract> : f32
+      fir.result %14 : f32
     }
-    %9 = arith.addi %arg2, %c1 overflow<nsw> : index
-    fir.result %9, %8 : index, f32
+    %11 = arith.addi %arg3, %c1 overflow<nsw> : index
+    fir.result %11, %10 : index, f32
   }
-  %5 = fir.convert %4#0 : (index) -> i32
-  fir.store %5 to %2 : !fir.ref<i32>
-  return %4#1 : f32
+  %5:2 = fir.do_loop %arg3 = %c1 to %c100 step %c1 iter_args(%arg4 = %4#1, %arg5 = %cst) -> (f32, f32) {
+    %8 = fir.array_coor %arg0(%0) %arg3 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+    %9 = fir.load %8 : !fir.ref<f32>
+    %10 = arith.addf %arg4, %9 fastmath<contract> : f32
+    %11 = fir.array_coor %arg2(%0) %arg3 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+    %12 = fir.load %11 : !fir.ref<f32>
+    %13 = arith.addf %arg5, %12 fastmath<contract> : f32
+    fir.result %10, %13 : f32, f32
+  }
+  %6 = arith.addf %5#0, %5#1 fastmath<contract> : f32
+  %7 = fir.convert %4#0 : (index) -> i32
+  fir.store %7 to %2 : !fir.ref<i32>
+  return %6 : f32
 }
 
 // CHECK-LABEL:   func.func @loop_with_result(
 // CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xf32>>,
-// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
+// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<!fir.array<100x100xf32>>,
+// CHECK-SAME:      %[[ARG2:.*]]: !fir.ref<!fir.array<100xf32>>) -> f32 {
 // CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_2:.*]] = arith.constant 100 : index
@@ -192,7 +203,17 @@ func.func @loop_with_result(%arg0: !fir.ref<!fir.array<100xf32>>, %arg1: !fir.re
 // CHECK:             %[[VAL_25:.*]] = arith.addi %[[VAL_14]], %[[VAL_0]] overflow<nsw> : index
 // CHECK:             fir.result %[[VAL_25]], %[[VAL_19]] : index, f32
 // CHECK:           }
-// CHECK:           %[[VAL_26:.*]] = fir.convert %[[VAL_13]]#0 : (index) -> i32
-// CHECK:           fir.store %[[VAL_26]] to %[[VAL_5]] : !fir.ref<i32>
-// CHECK:           return %[[VAL_13]]#1 : f32
+// CHECK:           %[[VAL_26:.*]] = fir.convert %[[ARG2]] : (!fir.ref<!fir.array<100xf32>>) -> memref<?xf32>
+// CHECK:           %[[VAL_27:.*]]:2 = affine.for %[[VAL_28:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_29:.*]] = %[[VAL_13]]#1, %[[VAL_31:.*]] = %[[VAL_1]]) -> (f32, f32) {
+// CHECK:             %[[VAL_32:.*]] = affine.apply #{{.*}}(%[[VAL_28]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]]
+// CHECK:             %[[VAL_33:.*]] = affine.load %[[VAL_6]]{{\[}}%[[VAL_32]]] : memref<?xf32>
+// CHECK:             %[[VAL_34:.*]] = arith.addf %[[VAL_29]], %[[VAL_33]] fastmath<contract> : f32
+// CHECK:             %[[VAL_35:.*]] = affine.load %[[VAL_26]]{{\[}}%[[VAL_32]]] : memref<?xf32>
+// CHECK:             %[[VAL_36:.*]] = arith.addf %[[VAL_31]], %[[VAL_35]] fastmath<contract> : f32
+// CHECK:             affine.yield %[[VAL_34]], %[[VAL_36]] : f32, f32
+// CHECK:           }
+// CHECK:           %[[VAL_37:.*]] = arith.addf %[[VAL_27]]#0, %[[VAL_27]]#1 fastmath<contract> : f32
+// CHECK:           %[[VAL_39:.*]] = fir.convert %[[VAL_13]]#0 : (index) -> i32
+// CHECK:           fir.store %[[VAL_39]] to %[[VAL_5]] : !fir.ref<i32>
+// CHECK:           return %[[VAL_37]] : f32
 // CHECK:         }



More information about the flang-commits mailing list