[Mlir-commits] [mlir] [MLIR][SCF] Add checks to verify that the pipeliner schedule is correct. (PR #77083)

Thomas Raoux llvmlistbot at llvm.org
Fri Jan 5 03:46:47 PST 2024


https://github.com/ThomasRaoux created https://github.com/llvm/llvm-project/pull/77083

Add a check to validate that the schedule passed to the pipeliner transformation is valid and won't cause the pipeliner to break SSA.

This checks that each operation in the loop is scheduled after the operations that produce its operands.

>From 1bf6958294156f90b3b07bca54da1977e1d1620f Mon Sep 17 00:00:00 2001
From: Thomas Raoux <thomas.raoux at openai.com>
Date: Thu, 4 Jan 2024 03:00:49 -0800
Subject: [PATCH] [MLIR][SCF] Add checks to verify that the pipeliner schedule
 is correct.

Add a check to validate that the schedule passed to the pipeliner transformation
is valid and won't cause the pipeliner to break SSA.

This checks that each operation in the loop is scheduled after the
operations that produce its operands.
---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 42 +++++++++++++++++++
 mlir/test/Dialect/SCF/loop-pipelining.mlir    | 34 ++++++++++++++-
 2 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7d45b484f76575..4de5a495c9290a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -67,6 +67,10 @@ struct LoopPipelinerInternal {
   /// the Value.
   std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
 
+  /// Return true if the schedule is possible and return false otherwise. A
+  /// schedule is correct if all definitions are scheduled before uses.
+  bool verifySchedule();
+
 public:
   /// Initalize the information for the given `op`, return true if it
   /// satisfies the pre-condition to apply pipelining.
@@ -156,6 +160,11 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     }
   }
 
+  if (!verifySchedule()) {
+    LDBG("--invalid schedule: " << op << " -> BAIL");
+    return false;
+  }
+
   // Currently, we do not support assigning stages to ops in nested regions. The
   // block of all operations assigned a stage should be the single `scf.for`
   // body block.
@@ -330,6 +339,39 @@ LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
   return {def, distance};
 }
 
+/// Compute unrolled cycles of each op and verify that each op is scheduled
+/// after its operands (modulo the distance between producer and consumer).
+bool LoopPipelinerInternal::verifySchedule() {
+  int64_t numCylesPerIter = opOrder.size();
+  // Pre-compute the unrolled cycle of each op.
+  DenseMap<Operation *, int64_t> unrolledCyles;
+  for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
+    Operation *def = opOrder[cycle];
+    auto it = stages.find(def);
+    assert(it != stages.end());
+    int64_t stage = it->second;
+    unrolledCyles[def] = cycle + stage * numCylesPerIter;
+  }
+  for (Operation *consumer : opOrder) {
+    int64_t consumerCycle = unrolledCyles[consumer];
+    for (Value operand : consumer->getOperands()) {
+      auto [producer, distance] = getDefiningOpAndDistance(operand);
+      if (!producer)
+        continue;
+      auto it = unrolledCyles.find(producer);
+      // Skip producer coming from outside the loop.
+      if (it == unrolledCyles.end())
+        continue;
+      int64_t producerCycle = it->second;
+      if (consumerCycle < producerCycle - numCylesPerIter * distance) {
+        consumer->emitError("operation scheduled before its operands.");
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 scf::ForOp LoopPipelinerInternal::createKernelLoop(
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
         &crossStageValues,
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 33290d2db31d66..694c0c321e6b36 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-scf-pipelining -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-scf-pipelining -split-input-file -verify-diagnostics | FileCheck %s
 // RUN: mlir-opt %s -test-scf-pipelining=annotate -split-input-file | FileCheck %s --check-prefix ANNOTATE
 // RUN: mlir-opt %s -test-scf-pipelining=no-epilogue-peeling -split-input-file | FileCheck %s --check-prefix NOEPILOGUE
 
@@ -814,3 +814,35 @@ func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
   return %r : f32
 }
 
+// -----
+
+func.func @invalid_schedule(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %c0 to %c4 step %c1 {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
+    // expected-error@+1 {{operation scheduled before its operands.}}
+    memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+  }  { __test_pipelining_loop__ }
+  return
+}
+
+// -----
+
+func.func @invalid_schedule2(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+    // expected-error@+1 {{operation scheduled before its operands.}}
+    %A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+    %idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : index
+    memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    scf.yield %idx1 : index
+  }  { __test_pipelining_loop__ }
+  return
+}



More information about the Mlir-commits mailing list