[flang-commits] [flang] 0501102 - [flang][LoopVersioning] support fir.declare

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Fri Aug 18 02:53:52 PDT 2023


Author: Tom Eccles
Date: 2023-08-18T09:51:22Z
New Revision: 05011024fd0a27f54e3fd9c3fb7c42f6a86ab391

URL: https://github.com/llvm/llvm-project/commit/05011024fd0a27f54e3fd9c3fb7c42f6a86ab391
DIFF: https://github.com/llvm/llvm-project/commit/05011024fd0a27f54e3fd9c3fb7c42f6a86ab391.diff

LOG: [flang][LoopVersioning] support fir.declare

When FIR is lowered from HLFIR, a fir.declare operation sits between the
definition of each source variable (and of some temporary allocations) and
its uses. This pass must look through these declares so that it can still
transform loops when HLFIR is used. Otherwise it compares the declare's
result against the function arguments and mistakenly concludes that the
value is not a function argument.

More work is needed after this patch to fully support HLFIR, because the
generated code tends to use fir.array_coor instead of fir.coordinate_of.

Differential Revision: https://reviews.llvm.org/D157964

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/LoopVersioning.cpp
    flang/test/Transforms/loop-versioning.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index 511c80673f21dd..63786e377bb4a5 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -94,6 +94,16 @@ static fir::SequenceType getAsSequenceType(mlir::Value *v) {
   return argTy.dyn_cast<fir::SequenceType>();
 }
 
+/// If `val` is produced by a fir.declare, return that declare's memref operand
+/// (the value being declared); otherwise return `val` unchanged.
+static mlir::Value unwrapFirDeclare(mlir::Value val) {
+  // Only a single level is unwrapped: fir.declare marks source code variables,
+  // so a fir.declare of another fir.declare is not expected here.
+  if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>())
+    return declare.getMemref();
+  return val;
+}
+
 void LoopVersioningPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
   mlir::func::FuncOp func = getOperation();
@@ -154,9 +164,9 @@ void LoopVersioningPass::runOnOperation() {
       // to it later.
       if (op->getParentOfType<fir::DoLoopOp>() != loop)
         return;
-      const mlir::Value &operand = op->getOperand(0);
+      mlir::Value operand = op->getOperand(0);
       for (auto a : argsOfInterest) {
-        if (*a.arg == operand) {
+        if (*a.arg == unwrapFirDeclare(operand)) {
           // Only add if it's not already in the list.
           if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
                 return it.arg == a.arg;
@@ -244,7 +254,7 @@ void LoopVersioningPass::runOnOperation() {
         // arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x)
         // where stride is the distance between elements in the dimensions
         // 0, 1 and 2 or x, y and z.
-        if (coop->getOperand(0) == *arg.arg &&
+        if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg &&
             coop->getOperands().size() >= 2) {
           builder.setInsertionPoint(coop);
           mlir::Value totalIndex;

diff  --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
index 6fc8eb852c1cf3..8659273f117502 100644
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -13,6 +13,7 @@
 //  end subroutine sum1d
 module {
   func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
+    %decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
     %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"}
     %1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"}
     %cst = arith.constant 0.000000e+00 : f64
@@ -30,7 +31,7 @@ module {
       %9 = fir.convert %8 : (i32) -> i64
       %c1_i64 = arith.constant 1 : i64
       %10 = arith.subi %9, %c1_i64 : i64
-      %11 = fir.coordinate_of %arg0, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
+      %11 = fir.coordinate_of %decl, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
       %12 = fir.load %11 : !fir.ref<f64>
       %13 = arith.addf %7, %12 fastmath<contract> : f64
       fir.store %13 to %1 : !fir.ref<f64>
@@ -47,6 +48,7 @@ module {
 // Note this only checks the expected transformation, not the entire generated code:
 // CHECK-LABEL: func.func @sum1d(
 // CHECK-SAME:                  %[[ARG0:.*]]: !fir.box<!fir.array<?xf64>> {{.*}})
+// CHECK: %[[DECL:.*]] = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
 // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
 // CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}}
 // CHECK: %[[SIZE:.*]] = arith.constant 8 : index
@@ -62,7 +64,7 @@ module {
 // CHECK  fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1
 // CHECK: } else {
 // CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}}
-// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[ARG0]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
+// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
 // CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref<f64>
 // CHECK: fir.result %{{.*}}, %{{.*}}
 // CHECK: }


        


More information about the flang-commits mailing list