[Mlir-commits] [mlir] [mlir][scf] Fix loop iteration calculation for negative step in LoopPipelining (PR #110035)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 25 13:00:41 PDT 2024


https://github.com/sjw36 created https://github.com/llvm/llvm-project/pull/110035

    This fixes the loop iteration count calculation when the step is
    negative: the delta added before the ceiling division must be
    `step+1` rather than `step-1`.

>From 67286897d41283dd82de3617312afb9c63dc2bfa Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Wed, 25 Sep 2024 19:56:08 +0000
Subject: [PATCH] [mlir][scf] Fix loop iteration calculation for negative step
 in LoopPipelining

    This fixes the loop iteration count calculation when the step is
    negative: the delta added before the ceiling division must be
    `step+1` rather than `step-1`.
---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 23 +++++----
 mlir/test/Dialect/SCF/loop-pipelining.mlir    | 47 ++++++++++++-------
 2 files changed, 44 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 3d6da066875f99..83c9cf69ba0364 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -648,15 +648,22 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
   // bounds_range = ub - lb
   // total_iterations = (bounds_range + step - 1) / step
   Type t = lb.getType();
-  Value minus1 =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
-  Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
-  Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
-  Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
-
   Value zero =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+  Value one =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
+  Value minusOne =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  Value stepLessZero = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt, step, zero);
+  Value stepDecr =
+      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+
+  Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
+  Value rangeDecr =
+      rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
+  Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
 
   SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
@@ -665,7 +672,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
     Value iterI = rewriter.create<arith::AddIOp>(
-        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
         minusI);
     // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4747aad977a492..af49d2afc049ba 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -766,8 +766,11 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
-//        CHECK:   %[[C0:.*]] = arith.constant 0 : index
-//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//    CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//    CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//    CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
+//        CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
+//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
 //        CHECK:       %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
 //        CHECK:       %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
@@ -775,15 +778,17 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:       %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
 //        CHECK:       scf.yield %[[ADDF_24]], %[[LOAD_27]]
 //        CHECK:   }
-//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
-//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
-//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
-//        CHECK:   %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
-//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
+//        CHECK:   %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
+//        CHECK:   %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
+//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
+//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
+//        CHECK:   %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
+//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
 //        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
 //        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
 //        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
-//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
 //        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
 //        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
 //        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
@@ -834,32 +839,38 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL:   func.func @dynamic_loop_result
-//       CHECK:     %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
+//       CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
+//       CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
 //       CHECK:       %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
 //       CHECK:       %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
 //       CHECK:       %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
 //       CHECK:       scf.yield %[[MULF_14]], %[[LOAD_16]]
 //       CHECK:     }
-//       CHECK:     %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
-//       CHECK:     %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
-//       CHECK:     %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
-//       CHECK:     %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
-//       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
-//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:     %[[CMPI_4:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
+//       CHECK:     %[[SELECT_5:.*]] = arith.select %[[CMPI_4]], %[[C1]], %[[CM1]]
+//       CHECK:     %[[SUBI_6:.*]] = arith.subi %[[UB]], %[[LB]]
+//       CHECK:     %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
+//       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
+//       CHECK:     %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
+//       CHECK:     %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
+//       CHECK:     %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
+//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_11]]
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
 //       CHECK:       scf.yield %[[ADDF_13]]
 //       CHECK:     } else {
 //       CHECK:       scf.yield %{{.*}}
 //       CHECK:     }
-//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_11]]
 //       CHECK:       %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
 //       CHECK:       scf.yield %[[MULF_13]]
 //       CHECK:     } else {
 //       CHECK:       scf.yield %{{.*}}
 //       CHECK:     }
-//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
+//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
 //       CHECK:     memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
 func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf0 = arith.constant 1.0 : f32



More information about the Mlir-commits mailing list