[flang-commits] [flang] 8d24b73 - [flang][LoopVersioning] support reboxed operands

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed Aug 23 03:11:34 PDT 2023


Author: Tom Eccles
Date: 2023-08-23T09:53:05Z
New Revision: 8d24b7322ee55eb780fc8115bfa8af07b6ee66b7

URL: https://github.com/llvm/llvm-project/commit/8d24b7322ee55eb780fc8115bfa8af07b6ee66b7
DIFF: https://github.com/llvm/llvm-project/commit/8d24b7322ee55eb780fc8115bfa8af07b6ee66b7.diff

LOG: [flang][LoopVersioning] support reboxed operands

Since https://reviews.llvm.org/D158119, many boxes lowered via HLFIR are
reboxed with better lower bounds information after they are declared.

For the loop versioning pass to support FIR lowered via HLFIR, it needs
to look through fir.rebox operations to determine that the variable was
a function argument.

I decided to modify the existing unwrapping of fir.declare so that
the declared/reboxed value is used in the versioned loop instead of the
function argument. This makes the improved lower-bounds information
easier to access. In doing this, I changed ArgInfo to store
ArgInfo::arg by value instead of by pointer, because mlir::Value has
value-type semantics.

Differential Revision: https://reviews.llvm.org/D158408

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/LoopVersioning.cpp
    flang/test/Transforms/loop-versioning.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index c191df21d5abc1..56dcadf6cbab39 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -104,6 +104,21 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
   return val;
 }
 
+/// If a value comes from a fir.rebox, follow the rebox to the original source
+/// of the value; otherwise return the value.
+static mlir::Value unwrapReboxOp(mlir::Value val) {
+  // don't support reboxes of reboxes
+  if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>())
+    val = rebox.getBox();
+  return val;
+}
+
+/// normalize a value (removing fir.declare and fir.rebox) so that we can
+/// more conveniently spot values which came from function arguments
+static mlir::Value normaliseVal(mlir::Value val) {
+  return unwrapFirDeclare(unwrapReboxOp(val));
+}
+
 void LoopVersioningPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
   mlir::func::FuncOp func = getOperation();
@@ -112,7 +127,7 @@ void LoopVersioningPass::runOnOperation() {
   /// A structure to hold an argument, the size of the argument and dimension
   /// information.
   struct ArgInfo {
-    mlir::Value *arg;
+    mlir::Value arg;
     size_t size;
     unsigned rank;
     fir::BoxDimsOp dims[CFI_MAX_RANK];
@@ -138,7 +153,7 @@ void LoopVersioningPass::runOnOperation() {
         else if (auto cty = elementType.dyn_cast<fir::ComplexType>())
           typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8;
         if (typeSize)
-          argsOfInterest.push_back({&arg, typeSize, rank, {}});
+          argsOfInterest.push_back({arg, typeSize, rank, {}});
         else
           LLVM_DEBUG(llvm::dbgs() << "Type not supported\n");
       }
@@ -166,7 +181,9 @@ void LoopVersioningPass::runOnOperation() {
         return;
       mlir::Value operand = op->getOperand(0);
       for (auto a : argsOfInterest) {
-        if (*a.arg == unwrapFirDeclare(operand)) {
+        if (a.arg == normaliseVal(operand)) {
+          // use the reboxed value, not the block arg when re-creating the loop:
+          a.arg = operand;
           // Only add if it's not already in the list.
           if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
                 return it.arg == a.arg;
@@ -211,7 +228,7 @@ void LoopVersioningPass::runOnOperation() {
       for (unsigned i = 0; i < ndims; i++) {
         mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
         arg.dims[i] = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
-                                                     *arg.arg, dimIdx);
+                                                     arg.arg, dimIdx);
       }
       // We only care about lowest order dimension, here.
       mlir::Value elemSize =
@@ -238,11 +255,11 @@ void LoopVersioningPass::runOnOperation() {
     for (auto &arg : op.argsAndDims) {
       fir::SequenceType::Shape newShape;
       newShape.push_back(fir::SequenceType::getUnknownExtent());
-      auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg->getType());
+      auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg.getType());
       mlir::Type arrTy = fir::SequenceType::get(newShape, elementType);
       mlir::Type boxArrTy = fir::BoxType::get(arrTy);
       mlir::Type refArrTy = builder.getRefType(arrTy);
-      auto carg = builder.create<fir::ConvertOp>(loc, boxArrTy, *arg.arg);
+      auto carg = builder.create<fir::ConvertOp>(loc, boxArrTy, arg.arg);
       auto caddr = builder.create<fir::BoxAddrOp>(loc, refArrTy, carg);
       auto insPt = builder.saveInsertionPoint();
       // Use caddr instead of arg.
@@ -254,8 +271,7 @@ void LoopVersioningPass::runOnOperation() {
         // arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x)
         // where stride is the distance between elements in the dimensions
         // 0, 1 and 2 or x, y and z.
-        if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg &&
-            coop->getOperands().size() >= 2) {
+        if (coop->getOperand(0) == arg.arg && coop->getOperands().size() >= 2) {
           builder.setInsertionPoint(coop);
           mlir::Value totalIndex;
           for (unsigned i = arg.rank - 1; i > 0; i--) {

diff  --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
index 8659273f117502..8e2fde0711f2cb 100644
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -14,6 +14,7 @@
 module {
   func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
     %decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
+    %rebox = fir.rebox %decl : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
     %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"}
     %1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"}
     %cst = arith.constant 0.000000e+00 : f64
@@ -31,7 +32,7 @@ module {
       %9 = fir.convert %8 : (i32) -> i64
       %c1_i64 = arith.constant 1 : i64
       %10 = arith.subi %9, %c1_i64 : i64
-      %11 = fir.coordinate_of %decl, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
+      %11 = fir.coordinate_of %rebox, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
       %12 = fir.load %11 : !fir.ref<f64>
       %13 = arith.addf %7, %12 fastmath<contract> : f64
       fir.store %13 to %1 : !fir.ref<f64>
@@ -49,12 +50,13 @@ module {
 // CHECK-LABEL: func.func @sum1d(
 // CHECK-SAME:                  %[[ARG0:.*]]: !fir.box<!fir.array<?xf64>> {{.*}})
 // CHECK: %[[DECL:.*]] = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
+// CHECK: %[[REBOX:.*]] = fir.rebox %[[DECL]]
 // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
-// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}}
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[REBOX]], %[[ZERO]] : {{.*}}
 // CHECK: %[[SIZE:.*]] = arith.constant 8 : index
 // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS]]#2, %[[SIZE]]
 // CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}}
-// CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]]
+// CHECK: %[[NEWARR:.*]] = fir.convert %[[REBOX]]
 // CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref<!fir.array<?xf64>>
 // CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}}
 // CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %{{.*}} : (!fir.ref<!fir.array<?xf64>>, index) -> !fir.ref<f64>
@@ -64,7 +66,7 @@ module {
 // CHECK  fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1
 // CHECK: } else {
 // CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}}
-// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
+// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[REBOX]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
 // CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref<f64>
 // CHECK: fir.result %{{.*}}, %{{.*}}
 // CHECK: }


        


More information about the flang-commits mailing list