[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 4 10:50:22 PDT 2024


https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/106436

>From 9be66a1a091d268b9d61bceb39d43cd2b66e63bd Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 29 Aug 2024 19:45:57 +0000
Subject: [PATCH 1/4] [MLIR][SCF] Add support for loop pipeline peeling for
 dynamic loops. * Allow speculative execution and predicate results per stage.

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 126 ++++++++++++------
 mlir/test/Dialect/SCF/loop-pipelining.mlir    |  43 +++++-
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |   4 +-
 3 files changed, 129 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..258e075e263259 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
       RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  void emitEpilogue(RewriterBase &rewriter,
-                    llvm::SmallVector<Value> &returnValues);
+  LogicalResult emitEpilogue(RewriterBase &rewriter,
+                             llvm::SmallVector<Value> &returnValues);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
   }
-  if (dynamicLoop && peelEpilogue) {
-    LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
-    return false;
-  }
   std::vector<std::pair<Operation *, unsigned>> schedule;
   options.getScheduleFn(forOp, schedule);
   if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
           });
       int predicateIdx = i - stages[op];
       if (predicates[predicateIdx]) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
         newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
         assert(newOp && "failed to predicate op.");
       }
-      rewriter.setInsertionPointAfter(newOp);
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,6 +557,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
     }
 
     if (predicates[useStage]) {
+      OpBuilder::InsertionGuard insertGuard(rewriter);
       newOp = predicateFn(rewriter, newOp, predicates[useStage]);
       if (!newOp)
         return failure();
@@ -568,7 +565,6 @@ LogicalResult LoopPipelinerInternal::createKernel(
       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
         mapping.map(std::get<0>(values), std::get<1>(values));
     }
-    rewriter.setInsertionPointAfter(newOp);
     if (annotateFn)
       annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
   }
@@ -640,70 +636,123 @@ LogicalResult LoopPipelinerInternal::createKernel(
   return success();
 }
 
-void LoopPipelinerInternal::emitEpilogue(
-    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+                                    llvm::SmallVector<Value> &returnValues) {
+  Location loc = forOp.getLoc();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
+
+  // bounds_range = ub - lb
+  // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
+  Type t = lb.getType();
+  Value minus1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+
+  Value const_0 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+  Value const_1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
+  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
+  Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
+                                                boundsRem, const_0);
+  Value selRem =
+      rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0);
+  Value boundsDiv = rewriter.create<arith::DivUIOp>(loc, boundsRange, step);
+  Value totalIterations =
+      rewriter.create<arith::AddIOp>(loc, boundsDiv, selRem);
+
+  SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
-    Location loc = forOp.getLoc();
-    Type t = lb.getType();
-    Value minusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-    // number of iterations = ((ub - 1) - lb) / step
-    Value totalNumIteration = rewriter.create<arith::DivUIOp>(
-        loc,
-        rewriter.create<arith::SubIOp>(
-            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
-        step);
-    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+    // iterI = total_iters - 1 - i
+    // May go negative when the loop runs fewer than i + 1 iterations.
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+    Value iterI = rewriter.create<arith::AddIOp>(
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        minusI);
+    // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
-        loc, lb,
-        rewriter.create<arith::MulIOp>(
-            loc, step,
-            rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+    if (dynamicLoop) {
+      // pred = iterI >= lb
+      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, iterI, lb);
+    }
   }
+
   // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
+    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
+      unsigned currentVersion = maxStage - stages[op] + i;
+      unsigned nextVersion = currentVersion + 1;
       Operation *newOp =
           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
             auto it = valueMapping.find(newOperand->get());
             if (it != valueMapping.end()) {
-              Value replacement = it->second[maxStage - stages[op] + i];
+              Value replacement = it->second[currentVersion];
               newOperand->set(replacement);
             }
           });
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
-      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        maxStage - stages[op] + i);
+      if (dynamicLoop) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
+        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+        if (!newOp)
+          return failure();
+      }
+
+      for (auto [opRes, newRes] :
+           llvm::zip(op->getResults(), newOp->getResults())) {
+        setValueMapping(opRes, newRes, currentVersion);
         // If the value is a loop carried dependency update the loop argument
         // mapping and keep track of the last version to replace the original
         // forOp uses.
         for (OpOperand &operand :
              forOp.getBody()->getTerminator()->getOpOperands()) {
-          if (operand.get() != op->getResult(destId))
+          if (operand.get() != opRes)
             continue;
-          unsigned version = maxStage - stages[op] + i + 1;
           // If the version is greater than maxStage it means it maps to the
           // original forOp returned value.
-          if (version > maxStage) {
-            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
-            continue;
-          }
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), version);
+          unsigned ri = operand.getOperandNumber();
+          returnValues[ri] = newRes;
+          Value mapVal = forOp.getRegionIterArgs()[ri];
+          returnMap[ri] = std::make_pair(mapVal, currentVersion);
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, newRes, nextVersion);
+        }
+      }
+    }
+    if (dynamicLoop) {
+      // Select return values from this stage (live outs) based on predication.
+      // If the stage is valid select the peeled value, else use previous stage
+      // value.
+      for (auto pair : llvm::enumerate(returnValues)) {
+        unsigned ri = pair.index();
+        auto [mapVal, currentVersion] = returnMap[ri];
+        if (mapVal) {
+          unsigned nextVersion = currentVersion + 1;
+          Value pred = predicates[currentVersion];
+          Value prevValue = valueMapping[mapVal][currentVersion];
+          auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+                                                        prevValue);
+          returnValues[ri] = selOp;
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, selOp, nextVersion);
         }
       }
     }
   }
+  return success();
 }
 
 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +809,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   if (options.peelEpilogue) {
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    pipeliner.emitEpilogue(rewriter, returnValues);
+    if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
+      return failure();
   }
   // 5. Erase the original loop and replace the uses with the epilogue output.
   if (forOp->getNumResults() > 0)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9687f80f5ddfc8..957dc5295c0583 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,46 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //      NOEPILOGUE:     memref.load %[[A]][%[[IV3]]] : memref<?xf32>
 //      NOEPILOGUE:   scf.yield %[[V2]], %[[L3]] : f32, f32
 
-// In case dynamic loop pipelining is off check that the transformation didn't
-// apply.
+// Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
-//   CHECK-NOT:   memref.load
-//       CHECK:   scf.for
+//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]] 
+//        CHECK:       %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}} 
+//        CHECK:       %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}} 
+//        CHECK:       %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]] 
+//        CHECK:       %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]] 
+//        CHECK:       scf.yield %[[ADDF_26]], %[[LOAD_29]] 
+//        CHECK:   }
+//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}} 
+//        CHECK:   %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}} 
+//        CHECK:   %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}} 
+//        CHECK:   %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}} 
+//        CHECK:   %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}} 
+//        CHECK:   %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]] 
+//        CHECK:   %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
+//        CHECK:   %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]] 
+//        CHECK:   %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]] 
+//        CHECK:   %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}} 
+//        CHECK:   %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
+//        CHECK:   %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1 
+//        CHECK:   %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]] 
+//        CHECK:   %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]] 
+//        CHECK:   %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}} 
+//        CHECK:   scf.if %[[CMPI_19]] {
+//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]] 
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
+//        CHECK:     %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}} 
+//        CHECK:     scf.yield %[[ADDF_26]] 
+//        CHECK:   } else {
+//        CHECK:     scf.yield %{{.*}} 
+//        CHECK:   }
+//        CHECK:   scf.if %[[CMPI_24]] {
+//        CHECK:     memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]] 
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   return
 func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf = arith.constant 1.0 : f32
   scf.for %i0 = %lb to %ub step %step {
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 8a92d840ad1302..3ff7f9966e93da 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
     RewritePatternSet patterns(&getContext());
     mlir::scf::PipeliningOption options;
     options.getScheduleFn = getSchedule;
+    options.supportDynamicLoops = true;
+    options.predicateFn = predicateOp;
     if (annotatePipeline)
       options.annotateFn = annotate;
     if (noEpiloguePeeling) {
-      options.supportDynamicLoops = true;
       options.peelEpilogue = false;
-      options.predicateFn = predicateOp;
     }
     scf::populateSCFLoopPipeliningPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

>From df8268d576f17625c87c8a3b7383ee4247eeabbd Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 29 Aug 2024 20:24:50 +0000
Subject: [PATCH 2/4] * annotate predicated op

---
 mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 258e075e263259..0615ffce072262 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -702,14 +702,14 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
               newOperand->set(replacement);
             }
           });
-      if (annotateFn)
-        annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
       if (dynamicLoop) {
         OpBuilder::InsertionGuard insertGuard(rewriter);
         newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
         if (!newOp)
           return failure();
       }
+      if (annotateFn)
+        annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
 
       for (auto [opRes, newRes] :
            llvm::zip(op->getResults(), newOp->getResults())) {

>From 5603dedcde6f6f5dab3777c8106435b35192ac3b Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Fri, 30 Aug 2024 02:58:50 +0000
Subject: [PATCH 3/4] * strength reduce

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 18 ++----
 mlir/test/Dialect/SCF/loop-pipelining.mlir    | 56 +++++++++----------
 2 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 0615ffce072262..a34542f0161aca 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -644,24 +644,14 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
   // removed by dead code if not used.
 
   // bounds_range = ub - lb
-  // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
+  // total_iterations = (bounds_range + step - 1) / step
   Type t = lb.getType();
   Value minus1 =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-
-  Value const_0 =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
-  Value const_1 =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
   Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
-  Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
-  Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
-                                                boundsRem, const_0);
-  Value selRem =
-      rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0);
-  Value boundsDiv = rewriter.create<arith::DivUIOp>(loc, boundsRange, step);
-  Value totalIterations =
-      rewriter.create<arith::AddIOp>(loc, boundsDiv, selRem);
+  Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+  Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
+  Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
 
   SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 957dc5295c0583..010c39f21afc30 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -767,40 +767,38 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
 //        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
-//        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]] 
-//        CHECK:       %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}} 
-//        CHECK:       %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}} 
-//        CHECK:       %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]] 
-//        CHECK:       %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]] 
-//        CHECK:       scf.yield %[[ADDF_26]], %[[LOAD_29]] 
+//        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
+//        CHECK:       %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
+//        CHECK:       %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
+//        CHECK:       %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]]
+//        CHECK:       %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
+//        CHECK:       scf.yield %[[ADDF_24]], %[[LOAD_27]]
 //        CHECK:   }
-//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}} 
-//        CHECK:   %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}} 
-//        CHECK:   %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}} 
-//        CHECK:   %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}} 
-//        CHECK:   %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}} 
-//        CHECK:   %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]] 
-//        CHECK:   %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
-//        CHECK:   %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]] 
-//        CHECK:   %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]] 
-//        CHECK:   %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}} 
-//        CHECK:   %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
-//        CHECK:   %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1 
-//        CHECK:   %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]] 
-//        CHECK:   %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]] 
-//        CHECK:   %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}} 
-//        CHECK:   scf.if %[[CMPI_19]] {
-//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]] 
+//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
+//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
+//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
+//        CHECK:   %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
+//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
+//        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
+//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
+//        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
+//        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+//        CHECK:   scf.if %[[CMPI_17]] {
+//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
 //        CHECK:   } else {
 //        CHECK:   }
-//        CHECK:   %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
-//        CHECK:     %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}} 
-//        CHECK:     scf.yield %[[ADDF_26]] 
+//        CHECK:   %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
+//        CHECK:     %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+//        CHECK:     scf.yield %[[ADDF_24]]
 //        CHECK:   } else {
-//        CHECK:     scf.yield %{{.*}} 
+//        CHECK:     scf.yield %{{.*}}
 //        CHECK:   }
-//        CHECK:   scf.if %[[CMPI_24]] {
-//        CHECK:     memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]] 
+//        CHECK:   scf.if %[[CMPI_22]] {
+//        CHECK:     memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
 //        CHECK:   } else {
 //        CHECK:   }
 //        CHECK:   return

>From 969e8bf3c9f99f39a0c46c57f4cf0b78c09c151b Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Wed, 4 Sep 2024 15:28:16 +0000
Subject: [PATCH 4/4] * added test with scf.for results

---
 mlir/test/Dialect/SCF/loop-pipelining.mlir | 62 ++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 010c39f21afc30..4a1406faabce1b 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -814,6 +814,68 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 
 // -----
 
+// NOEPILOGUE-LABEL:   func.func @dynamic_loop_result
+//       NOEPILOGUE:     %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//       NOEPILOGUE:       %[[SUBI_3:.*]] = arith.subi %{{.*}}, %{{.*}}
+//       NOEPILOGUE:       %[[CMPI_4:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_3]]
+//       NOEPILOGUE:       %[[ADDF_5:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
+//       NOEPILOGUE:       %[[MULF_6:.*]] = arith.mulf %[[ADDF_5]], %{{.*}}
+//       NOEPILOGUE:       %[[ADDI_7:.*]] = arith.addi %[[ARG5]], %{{.*}}
+//       NOEPILOGUE:       %[[IF_8:.*]] = scf.if %[[CMPI_4]]
+//       NOEPILOGUE:         %[[LOAD_9:.*]] = memref.load %{{.*}}[%[[ADDI_7]]]
+//       NOEPILOGUE:         scf.yield %[[LOAD_9]]
+//       NOEPILOGUE:       } else {
+//       NOEPILOGUE:         scf.yield %{{.*}}
+//       NOEPILOGUE:       }
+//       NOEPILOGUE:       scf.yield %[[MULF_6]], %[[IF_8]]
+//       NOEPILOGUE:     }
+//       NOEPILOGUE:     memref.store %{{.*}}#0, %{{.*}}[%{{.*}}]
+
+// Check for predicated epilogue for dynamic loop.
+// CHECK-LABEL:   func.func @dynamic_loop_result
+//       CHECK:     %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//       CHECK:       %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
+//       CHECK:       %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
+//       CHECK:       %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
+//       CHECK:       %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
+//       CHECK:       scf.yield %[[MULF_14]], %[[LOAD_16]]
+//       CHECK:     }
+//       CHECK:     %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
+//       CHECK:     %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
+//       CHECK:     %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
+//       CHECK:     %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
+//       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
+//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
+//       CHECK:       scf.yield %[[ADDF_13]]
+//       CHECK:     } else {
+//       CHECK:       scf.yield %{{.*}}
+//       CHECK:     }
+//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:       %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
+//       CHECK:       scf.yield %[[MULF_13]]
+//       CHECK:     } else {
+//       CHECK:       scf.yield %{{.*}}
+//       CHECK:     }
+//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
+//       CHECK:     memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
+func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
+  %cf0 = arith.constant 1.0 : f32
+  %cf1 = arith.constant 33.0 : f32
+  %cst = arith.constant 0 : index
+  %res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    %A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    scf.yield %A2_elem : f32
+  } { __test_pipelining_loop__ }
+  memref.store %res#0, %result[%cst] : memref<?xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: yield_constant_loop(
 //  CHECK-SAME:   %[[A:.*]]: memref<?xf32>) -> f32 {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index



More information about the Mlir-commits mailing list