[Mlir-commits] [mlir] fa089b0 - [SCF] Fixed epilogue predicates in loop pipelining (#108964)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 23 22:06:22 PDT 2024


Author: SJW
Date: 2024-09-23T22:06:19-07:00
New Revision: fa089b014b41db4ef90378c7eae35306402cfcb3

URL: https://github.com/llvm/llvm-project/commit/fa089b014b41db4ef90378c7eae35306402cfcb3
DIFF: https://github.com/llvm/llvm-project/commit/fa089b014b41db4ef90378c7eae35306402cfcb3.diff

LOG: [SCF] Fixed epilogue predicates in loop pipelining (#108964)

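The computed loop iteration is zero-based, so the epilogue predicate only
needs to check that it is not less than zero (iterI >= 0) instead of comparing
it against the lower bound. This fixes the case where the lower bound is not
zero.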

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    mlir/test/Dialect/SCF/loop-pipelining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..3d6da066875f99 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
   Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
   Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
 
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
   SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
     // iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
 
     if (dynamicLoop) {
-      // pred = iterI >= lb
+      // pred = iterI >= 0
       predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, lb);
+          loc, arith::CmpIPredicate::sge, iterI, zero);
     }
   }
 

diff  --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..4747aad977a492 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -766,6 +766,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
+//        CHECK:   %[[C0:.*]] = arith.constant 0 : index
 //        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
 //        CHECK:       %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
@@ -781,12 +782,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
 //        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
 //        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
 //        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
 //        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
 //        CHECK:   scf.if %[[CMPI_17]] {
 //        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
 //        CHECK:   } else {


        


More information about the Mlir-commits mailing list