[Mlir-commits] [mlir] [SCF] Fixed epilogue predicates in loop pipelining (PR #108964)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 19 13:18:09 PDT 2024


https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/108964

>From 93b0115446a143451f7abc281f7ad4c6d411a485 Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 19 Sep 2024 20:17:33 +0000
Subject: [PATCH] [SCF] Fixed epilogue predicates in loop pipelining

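The computed epilogue iteration index (iterI) is zero-based, so the predicate
must check that it is greater than or equal to zero (iterI >= 0) instead of
comparing it against the loop's lower bound.
This fixes the case where the lower bound is not zero.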
---
 mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 7 +++++--
 mlir/test/Dialect/SCF/loop-pipelining.mlir         | 5 +++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..3d6da066875f99 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
   Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
   Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
 
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
   SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
     // iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
 
     if (dynamicLoop) {
-      // pred = iterI >= lb
+      // pred = iterI >= 0
       predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, lb);
+          loc, arith::CmpIPredicate::sge, iterI, zero);
     }
   }
 
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..4747aad977a492 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -766,6 +766,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
+//        CHECK:   %[[C0:.*]] = arith.constant 0 : index
 //        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
 //        CHECK:       %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
@@ -781,12 +782,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
 //        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
 //        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
 //        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
 //        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
 //        CHECK:   scf.if %[[CMPI_17]] {
 //        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
 //        CHECK:   } else {



More information about the Mlir-commits mailing list