[Mlir-commits] [mlir] [MLIR][SCF] Add a pattern to remove dead cycles in scf.for ops (PR #67350)

Thomas Raoux llvmlistbot at llvm.org
Mon Sep 25 10:10:29 PDT 2023


https://github.com/ThomasRaoux created https://github.com/llvm/llvm-project/pull/67350

Dead code elimination cannot remove dead cycles in loop operations. This pattern allows removing dead cycles in scf.for loops. The algorithm first assumes that all the arguments associated with unused results are dead, then recursively propagates liveness to identify the arguments that are actually dead.

>From 6c57df59414c9bfc21bd1189ff2a7986043646db Mon Sep 17 00:00:00 2001
From: Thomas Raoux <thomas.raoux at openai.com>
Date: Mon, 25 Sep 2023 10:02:36 -0700
Subject: [PATCH] [MLIR][SCF] Add a pattern to remove dead cycles in scf.for
 ops

Dead code elimination cannot remove dead cycles in loop operations.
This pattern allows removing dead cycles in scf.for loops.
The algorithm first assumes that all the arguments associated with
unused results are dead, then recursively propagates liveness to
identify the arguments that are actually dead.
---
 .../mlir/Dialect/SCF/Transforms/Patterns.h    |   6 +
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |   1 +
 .../Transforms/ForOpDeadCycleElimination.cpp  | 127 +++++++++++++++
 .../test/Dialect/SCF/for-loop-dead-cycle.mlir | 144 ++++++++++++++++++
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |  19 +++
 5 files changed, 297 insertions(+)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/ForOpDeadCycleElimination.cpp
 create mode 100644 mlir/test/Dialect/SCF/for-loop-dead-cycle.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986d..19a9aff15aad4ac 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,12 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Applies an expensive dead code elimination to scf.for op loop arguments.
+/// This allows removing dead cycles in loop arguments.
+/// The pattern will first assume that all the arguments are dead and
+/// recursively propagate liveness to values in the loop until a fixed point.
+void populateForOpDeadCycleEliminationPatterns(RewritePatternSet &patterns);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa4f..9bc266660db0652 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ForOpDeadCycleElimination.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForOpDeadCycleElimination.cpp b/mlir/lib/Dialect/SCF/Transforms/ForOpDeadCycleElimination.cpp
new file mode 100644
index 000000000000000..8e4d3b96af93229
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForOpDeadCycleElimination.cpp
@@ -0,0 +1,127 @@
+//==-- ForOpDeadCycleElimination.cpp - dead code elimination for scf.for ---==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+namespace {
+/// Detects dead arguments of an scf.for op by assuming all values are dead and
+/// then propagating liveness until a fixed point is reached.
+struct ForOpDeadArgElimination : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const final {
+    Block &block = *forOp.getBody();
+    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+    // Assume that nothing is live at the beginning and mark values as live
+    // based on uses.
+    DenseSet<Value> aliveValues;
+    SmallVector<Value> queue;
+    // Helper that marks a value as live and queues it for propagation the first
+    // time it is seen as live. Values not nested under forOp are skipped.
+    auto markLive = [&](Value val) {
+      if (!forOp->isAncestor(val.getParentRegion()->getParentOp()))
+        return;
+      if (aliveValues.insert(val).second)
+        queue.push_back(val);
+    };
+    // Mark all yield operands as live if the associated forOp result has any
+    // use.
+    for (auto result : llvm::enumerate(forOp.getResults())) {
+      if (!result.value().use_empty())
+        markLive(yieldOp.getOperand(result.index()));
+    }
+    if (aliveValues.size() == forOp.getNumResults())
+      return failure();
+    // Operations with side-effects are always live. Mark all their operands as
+    // live except for scf.for and scf.if that have special handling.
+    block.walk([&](Operation *op) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+        if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp()))
+          return;
+      }
+      if (!isa<scf::ForOp, scf::IfOp>(op) && !wouldOpBeTriviallyDead(op)) {
+        for (Value operand : op->getOperands())
+          markLive(operand);
+      }
+    });
+    // Propagate live property until reaching a fixed point.
+    while (!queue.empty()) {
+      Value value = queue.pop_back_val();
+      if (auto nestedFor = value.getDefiningOp<scf::ForOp>()) {
+        auto result = value.cast<OpResult>();
+        OpOperand &forOperand = nestedFor.getOpOperandForResult(result);
+        markLive(forOperand.get());
+        auto nestedYieldOp =
+            cast<scf::YieldOp>(nestedFor.getBody()->getTerminator());
+        Value nestedYieldOperand =
+            nestedYieldOp.getOperand(result.getResultNumber());
+        markLive(nestedYieldOperand);
+        continue;
+      }
+      if (auto nestedIf = value.getDefiningOp<scf::IfOp>()) {
+        auto result = value.cast<OpResult>();
+        for (scf::YieldOp nestedYieldOp :
+             {nestedIf.thenYield(), nestedIf.elseYield()}) {
+          Value nestedYieldOperand =
+              nestedYieldOp.getOperand(result.getResultNumber());
+          markLive(nestedYieldOperand);
+        }
+        continue;
+      }
+      if (Operation *def = value.getDefiningOp()) {
+        for (Value operand : def->getOperands())
+          markLive(operand);
+        continue;
+      }
+      // If a block argument is live then the associated yield operand and
+      // forOp operand are live.
+      auto arg = value.cast<BlockArgument>();
+      if (auto forOwner = dyn_cast<scf::ForOp>(arg.getOwner()->getParentOp())) {
+        if (arg.getArgNumber() < forOwner.getNumInductionVars())
+          continue;
+        unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars();
+        Value yieldOperand =
+            forOwner.getBody()->getTerminator()->getOperand(iterIdx);
+        markLive(yieldOperand);
+        markLive(forOwner.getInitArgs()[iterIdx]);
+      }
+    }
+    SmallVector<unsigned> deadArg;
+    for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) {
+      if (aliveValues.contains(yieldOperand.value()))
+        continue;
+      if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1))
+        continue;
+      deadArg.push_back(yieldOperand.index());
+    }
+    if (deadArg.empty())
+      return failure();
+    rewriter.updateRootInPlace(forOp, [&]() {
+      // For simplicity we just change the dead yield operand to use the
+      // associated argument and leave the operations and argument removal to
+      // dead code elimination.
+      for (unsigned deadArgIdx : deadArg) {
+        BlockArgument arg = block.getArgument(deadArgIdx + 1);
+        yieldOp.setOperand(deadArgIdx, arg);
+      }
+    });
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::scf::populateForOpDeadCycleEliminationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ForOpDeadArgElimination>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/SCF/for-loop-dead-cycle.mlir b/mlir/test/Dialect/SCF/for-loop-dead-cycle.mlir
new file mode 100644
index 000000000000000..0d03b77ca45ddf2
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-dead-cycle.mlir
@@ -0,0 +1,144 @@
+// RUN: mlir-opt %s -test-scf-for-op-dead-cycles -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @dead_arg(
+//  CHECK-SAME:   %[[A0:.*]]: f32, %[[A1:.*]]: f32)
+//       CHECK:   scf.for {{.*}} iter_args(%[[ARG1:.*]] = %[[A1]])
+//  CHECK-NEXT:     %[[S:.+]] = arith.addf %[[ARG1]], %[[ARG1]] : f32
+//  CHECK-NEXT:     scf.yield %[[S]]
+func.func @dead_arg(%a0 : f32, %a1 : f32) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1
+  iter_args(%arg0 = %a0, %arg1 = %a1) -> (f32, f32) {
+    %s0 = arith.addf %arg0, %arg0 : f32
+    %s1 = arith.addf %arg1, %arg1 : f32
+    scf.yield %s0, %s1 : f32, f32
+  }
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_arg_negative(
+//  CHECK-SAME:   %[[A0:.*]]: f32, %[[A1:.*]]: f32)
+//       CHECK:   scf.for {{.*}} iter_args(%[[ARG0:.*]] = %[[A0]], %[[ARG1:.*]] = %[[A1]])
+//  CHECK-NEXT:     %[[S0:.+]] = arith.addf %[[ARG0]], %[[ARG0]] : f32
+//  CHECK-NEXT:     %[[S1:.+]] = arith.addf %[[ARG1]], %[[ARG1]] : f32
+//  CHECK-NEXT:     scf.yield %[[S1]], %[[S0]]
+func.func @dead_arg_negative(%a0 : f32, %a1 : f32) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1
+  iter_args(%arg0 = %a0, %arg1 = %a1) -> (f32, f32) {
+    %s0 = arith.addf %arg0, %arg0 : f32
+    %s1 = arith.addf %arg1, %arg1 : f32
+    scf.yield %s1, %s0 : f32, f32
+  }
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_arg_side_effect(
+//  CHECK-SAME:   %[[A:.*]]: f32
+//       CHECK:   scf.for {{.*}} iter_args(%[[ARG0:.*]] = %[[A]])
+//  CHECK-NEXT:     %[[S0:.+]] = arith.addf %[[ARG0]], %[[ARG0]] : f32
+//  CHECK-NEXT:     memref.store
+//  CHECK-NEXT:     scf.yield %[[S0]]
+func.func @dead_arg_side_effect(%a : f32, %A : memref<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1
+  iter_args(%arg0 = %a, %arg1 = %a) -> (f32, f32) {
+    %s0 = arith.addf %arg0, %arg0 : f32
+    %s1 = arith.addf %arg1, %arg1 : f32
+    memref.store %s0, %A[]: memref<f32>
+    scf.yield %s0, %s1 : f32, f32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_arg_recurse(
+//  CHECK-SAME:   %[[A0:.*]]: f32, %[[A1:.*]]: f32, %[[A2:.*]]: f32, %[[A3:.*]]: f32)
+//       CHECK:   scf.for {{.*}} iter_args(%[[ARG1:.*]] = %[[A1]], %[[ARG3:.*]] = %[[A3]])
+//  CHECK-NEXT:     %[[S0:.+]] = arith.addf %[[ARG1]], %[[ARG1]] : f32
+//  CHECK-NEXT:     %[[S1:.+]] = arith.addf %[[ARG1]], %[[ARG3]] : f32
+//  CHECK-NEXT:     scf.yield %[[S0]], %[[S1]]
+func.func @dead_arg_recurse(%a0 : f32, %a1 : f32, %a2 : f32, %a3 : f32) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0:4 = scf.for %i = %c0 to %c10 step %c1
+  iter_args(%arg0 = %a0, %arg1 = %a1, %arg2 = %a2, %arg3 = %a3) -> (f32, f32, f32, f32) {
+    %s0 = arith.addf %arg0, %arg3 : f32
+    %s1 = arith.addf %arg1, %arg1 : f32
+    %s2 = arith.addf %arg2, %arg0 : f32
+    %s3 = arith.addf %arg1, %arg3 : f32
+    scf.yield %s0, %s1, %s2, %s3 : f32, f32, f32, f32
+  }
+  return %0#3 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_arg_nested(
+//  CHECK-SAME:   %[[A0:.*]]: f32, %[[A1:.*]]: f32)
+//       CHECK:   scf.for {{.*}} iter_args(%[[ARG1:.*]] = %[[A1]])
+//       CHECK:     %[[R:.+]] = scf.for {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG1]])
+//  CHECK-NEXT:       %[[S:.+]] = arith.addf %[[ARG4]], %[[ARG4]] : f32
+//  CHECK-NEXT:       scf.yield %[[S]]
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:     scf.yield %[[R]]
+func.func @dead_arg_nested(%a0 : f32, %a1 : f32) -> f32{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0:2 = scf.for %i = %c0 to %c10 step %c1
+  iter_args(%arg0 = %a0, %arg1 = %a1) -> (f32, f32) {
+    %1:2 = scf.for %i1 = %c0 to %c10 step %c1
+    iter_args(%arg4 = %arg0, %arg5 = %arg1) -> (f32, f32) {
+      %s1 = arith.addf %arg4, %arg4 : f32
+      %s2 = arith.addf %arg5, %arg5 : f32
+      scf.yield %s1, %s2 : f32, f32
+    }
+    scf.yield %1#0, %1#1 : f32, f32
+  }
+  return %0#1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_arg_nested_if(
+//  CHECK-SAME:   %[[A0:.*]]: f32, %[[A1:.*]]: f32, %{{.*}}: i1)
+//       CHECK:   scf.for {{.*}} iter_args(%[[ARG1:.*]] = %[[A1]])
+//       CHECK:     %[[R:.+]] = scf.if {{.*}} {
+//  CHECK-NEXT:       %[[S:.+]] = arith.addf %[[ARG1]], %[[ARG1]] : f32
+//  CHECK-NEXT:       scf.yield %[[S]]
+//  CHECK-NEXT:     } else {
+//  CHECK-NEXT:       scf.yield %{{.*}} : f32
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:     scf.yield %[[R]]
+func.func @dead_arg_nested_if(%a0 : f32, %a1 : f32, %c: i1) -> f32{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0:2 = scf.for %i = %c0 to %c10 step %c1
+  iter_args(%arg0 = %a0, %arg1 = %a1) -> (f32, f32) {
+    %1:2 = scf.if %c -> (f32, f32) {
+      %s1 = arith.addf %arg0, %arg0 : f32
+      %s2 = arith.addf %arg1, %arg1 : f32
+      scf.yield %s1, %s2 : f32, f32
+    } else {
+      %cst_0 = arith.constant 1.000000e+00 : f32
+      %cst_1 = arith.constant 2.000000e+00 : f32
+      scf.yield %cst_0, %cst_1 : f32, f32
+    }
+    scf.yield %1#0, %1#1 : f32, f32
+  }
+  return %0#1 : f32
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 455c9234b8c93de..5bc7f37ed0994ea 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -226,6 +226,24 @@ struct TestSCFPipeliningPass
     });
   }
 };
+
+struct TestSCFForOpDeadArgPass
+    : public PassWrapper<TestSCFForOpDeadArgPass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForOpDeadArgPass)
+
+  TestSCFForOpDeadArgPass() = default;
+  TestSCFForOpDeadArgPass(const TestSCFForOpDeadArgPass &) {}
+  StringRef getArgument() const final { return "test-scf-for-op-dead-cycles"; }
+  StringRef getDescription() const final {
+    return "test removing dead cycles in scf.forOp";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    scf::populateForOpDeadCycleEliminationPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -234,6 +252,7 @@ void registerTestSCFUtilsPass() {
   PassRegistration<TestSCFForUtilsPass>();
   PassRegistration<TestSCFIfUtilsPass>();
   PassRegistration<TestSCFPipeliningPass>();
+  PassRegistration<TestSCFForOpDeadArgPass>();
 }
 } // namespace test
 } // namespace mlir



More information about the Mlir-commits mailing list