[Mlir-commits] [mlir] Prevent invalid IR from being passed outside of RemoveDeadValues (PR #121079)

Renat Idrisov llvmlistbot at llvm.org
Tue Dec 24 18:39:55 PST 2024


https://github.com/parsifal-47 created https://github.com/llvm/llvm-project/pull/121079

This is a follow-up for https://github.com/llvm/llvm-project/pull/119110 and a fix for https://github.com/llvm/llvm-project/issues/118450

RemoveDeadValues used to delete Values while analyzing the IR at the same time; because of that, `isMemoryEffectFree` was handed invalid IR containing a half-deleted linalg.generic operation. This PR separates the analysis and cleanup phases to prevent such situations.

Thank you!

>From 2db7609b851d249e16a96cc40cdc7f710575b0bc Mon Sep 17 00:00:00 2001
From: Renat Idrisov <parsifal-47 at users.noreply.github.com>
Date: Wed, 25 Dec 2024 02:30:22 +0000
Subject: [PATCH] Prevent invalid IR from being passed outside of
 RemoveDeadValues

---
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 206 +++++++++++++------
 mlir/test/Transforms/remove-dead-values.mlir |  26 +++
 2 files changed, 172 insertions(+), 60 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 3429008b2f241a..5d4ec66d6905a4 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -72,15 +72,54 @@ using namespace mlir::dataflow;
 
 namespace {
 
+// The structures below are filled with the operations and arguments to erase.
+// This separates the analysis phase from the tree modification phase;
+// otherwise the analysis would operate on a half-deleted tree, which is incorrect.
+
+struct CleanupFunction {
+  FunctionOpInterface funcOp;
+  BitVector nonLiveArgs;
+  BitVector nonLiveRets;
+};
+
+struct CleanupOperands {
+  Operation *op;
+  BitVector nonLiveOperands;
+};
+
+struct CleanupResults {
+  Operation *op;
+  BitVector nonLiveResults;
+};
+
+struct CleanupBlockArgs {
+  Block *b;
+  BitVector nonLiveArgs;
+};
+
+struct CleanupSuccessorOperands {
+  BranchOpInterface branch;
+  unsigned index;
+  BitVector nonLiveOperands;
+};
+
+struct CleanupList {
+  SmallVector<Operation *> operations;
+  SmallVector<Value> values;
+  SmallVector<CleanupFunction> functions;
+  SmallVector<CleanupOperands> operands;
+  SmallVector<CleanupResults> results;
+  SmallVector<CleanupBlockArgs> blocks;
+  SmallVector<CleanupSuccessorOperands> successorOperands;
+};
+
 // Some helper functions...
 
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
-static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
+static bool hasLive(ValueRange values, const DenseSet<Value> &deletionSet, RunLivenessAnalysis &la) {
   for (Value value : values) {
-    // If there is a null value, it implies that it was dropped during the
-    // execution of this pass, implying that it was non-live.
-    if (!value)
+    if (deletionSet.contains(value))
       continue;
 
     const Liveness *liveness = la.getLiveness(value);
@@ -92,11 +131,11 @@ static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
 
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
-static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
+static BitVector markLives(ValueRange values, const DenseSet<Value> &deletionSet, RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
   for (auto [index, value] : llvm::enumerate(values)) {
-    if (!value) {
+    if (deletionSet.contains(value)) {
       lives.reset(index);
       continue;
     }
@@ -115,6 +154,18 @@ static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
   return lives;
 }
 
+// DeletionSet is used to track the Values that are scheduled for removal
+void updateDeletionSet(DenseSet<Value> &deletionSet, ValueRange range, const BitVector &nonLive) {
+  for (auto [index, result] : llvm::enumerate(range)) {
+    if (!nonLive[index]) continue;
+    deletionSet.insert(result);
+  }
+}
+
+void updateDeletionSet(DenseSet<Value> &deletionSet, Operation *op, const BitVector &nonLive) {
+  updateDeletionSet(deletionSet, op->getResults(), nonLive);
+}
+
 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
 /// is 1.
 static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
@@ -174,43 +225,43 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 /// It is assumed that `op` is simple. Here, a simple op is one which isn't a
 /// function-like op, a call-like op, a region branch op, a branch op, a region
 /// branch terminator op, or return-like.
-static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
+static void cleanSimpleOp(CleanupList &cl, DenseSet<Value> &deletionSet, Operation *op, RunLivenessAnalysis &la) {
+  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), deletionSet, la))
     return;
 
-  op->dropAllUses();
-  op->erase();
+  cl.operations.push_back(op);
+  updateDeletionSet(deletionSet, op, BitVector(op->getNumResults(), true));
 }
 
 /// Clean a function-like op `funcOp`, given the liveness information in `la`
 /// and the IR in `module`. Here, cleaning means:
 ///   (1) Dropping the uses of its unnecessary (non-live) arguments,
-///   (2) Erasing these arguments,
-///   (3) Erasing their corresponding operands from its callers,
+///   (2) Erasing their corresponding operands from its callers,
+///   (3) Erasing these arguments,
 ///   (4) Erasing its unnecessary terminator operands (return values that are
 ///   non-live across all callers),
 ///   (5) Dropping the uses of these return values from its callers, AND
 ///   (6) Erasing these return values
 /// iff it is not public or external.
-static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
+static void cleanFuncOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+                        FunctionOpInterface funcOp, Operation *module,
                         RunLivenessAnalysis &la) {
   if (funcOp.isPublic() || funcOp.isExternal())
     return;
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
-  BitVector nonLiveArgs = markLives(arguments, la);
+  BitVector nonLiveArgs = markLives(arguments, deletionSet, la);
   nonLiveArgs = nonLiveArgs.flip();
 
   // Do (1).
   for (auto [index, arg] : llvm::enumerate(arguments))
-    if (arg && nonLiveArgs[index])
-      arg.dropAllUses();
+    if (arg && nonLiveArgs[index]) {
+      cl.values.push_back(arg);
+      deletionSet.insert(arg);
+    }
 
   // Do (2).
-  funcOp.eraseArguments(nonLiveArgs);
-
-  // Do (3).
   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
@@ -222,7 +273,7 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
         operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
     for (int index : nonLiveArgs.set_bits())
       nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
-    callOp->eraseOperands(nonLiveCallOperands);
+    cl.operands.push_back({callOp, nonLiveCallOperands});
   }
 
   // Get the list of unnecessary terminator operands (return values that are
@@ -253,26 +304,27 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), la);
+    BitVector liveCallRets = markLives(callOp->getResults(), deletionSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
-  // Do (4).
+  // Do (3).
   // Note that in the absence of control flow ops forcing the control to go from
   // the entry (first) block to the other blocks, the control never reaches any
   // block other than the entry block, because every block has a terminator.
   for (Block &block : funcOp.getBlocks()) {
     Operation *returnOp = block.getTerminator();
     if (returnOp && returnOp->getNumOperands() == numReturns)
-      returnOp->eraseOperands(nonLiveRets);
+      cl.operands.push_back({returnOp, nonLiveRets});
   }
-  funcOp.eraseResults(nonLiveRets);
+  cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
 
   // Do (5) and (6).
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    dropUsesAndEraseResults(callOp, nonLiveRets);
+    cl.results.push_back({callOp, nonLiveRets});
+    updateDeletionSet(deletionSet, callOp, nonLiveRets);
   }
 }
 
@@ -297,18 +349,19 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
 /// It is important to note that values in this op flow from operands and
 /// terminator operands (successor operands) to arguments and results (successor
 /// inputs).
-static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
+static void cleanRegionBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+                                RegionBranchOpInterface regionBranchOp,
                                 RunLivenessAnalysis &la) {
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), la);
+    liveResults = markLives(regionBranchOp->getResults(), deletionSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
   auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
     for (Region &region : regionBranchOp->getRegions()) {
       SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, la);
+      BitVector regionLiveArgs = markLives(arguments, deletionSet, la);
       liveArgs[&region] = regionLiveArgs;
     }
   };
@@ -497,9 +550,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // It could never be live because of this op but its liveness could have been
   // attributed to something else.
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
-      !hasLive(regionBranchOp->getResults(), la)) {
-    regionBranchOp->dropAllUses();
-    regionBranchOp->erase();
+      !hasLive(regionBranchOp->getResults(), deletionSet, la)) {
+    cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
 
@@ -538,29 +590,27 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                    terminatorOperandsToKeep);
 
   // Do (1).
-  regionBranchOp->eraseOperands(operandsToKeep.flip());
+  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
 
   // Do (2.a) and (2.b).
   for (Region &region : regionBranchOp->getRegions()) {
     assert(!region.empty() && "expected a non-empty region in an op "
                               "implementing `RegionBranchOpInterface`");
-    for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) {
-      if (argsToKeep[&region][index])
-        continue;
-      if (arg)
-        arg.dropAllUses();
-    }
-    region.front().eraseArguments(argsToKeep[&region].flip());
+    BitVector argsToRemove = argsToKeep[&region].flip();
+    cl.blocks.push_back({&region.front(), argsToRemove});
+    updateDeletionSet(deletionSet, region.front().getArguments(), argsToRemove);
   }
 
   // Do (2.c).
   for (Region &region : regionBranchOp->getRegions()) {
     Operation *terminator = region.front().getTerminator();
-    terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip());
+    cl.operands.push_back({terminator, terminatorOperandsToKeep[terminator].flip()});
   }
 
   // Do (3) and (4).
-  dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
+  BitVector resultsToRemove = resultsToKeep.flip();
+  updateDeletionSet(deletionSet, regionBranchOp.getOperation(), resultsToRemove);
+  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
 }
 
 // 1. Iterate over each successor block of the given BranchOpInterface
@@ -572,7 +622,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 //    c. Mark each operand as live or dead based on the analysis.
 // 3. Remove dead operands from the branch operation and arguments accordingly
 
-static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
+static void cleanBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+                          BranchOpInterface branchOp, RunLivenessAnalysis &la) {
   unsigned numSuccessors = branchOp->getNumSuccessors();
 
   // Do (1)
@@ -588,22 +639,53 @@ static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
       operandValues.push_back(successorOperands[operandIdx]);
     }
 
-    BitVector successorLiveOperands = markLives(operandValues, la);
-
     // Do (3)
-    for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
-      if (!successorLiveOperands[argIdx]) {
-        if (successorBlock->getNumArguments() < successorOperands.size()) {
-          // if block was cleaned through a different code path
-          // we only need to remove operands from the invokation
-          successorOperands.erase(argIdx);
-          continue;
-        }
+    BitVector successorNonLive = markLives(operandValues, deletionSet, la).flip();
+    updateDeletionSet(deletionSet, successorBlock->getArguments(), successorNonLive);
+    cl.blocks.push_back({successorBlock, successorNonLive});
+    cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
+  }
+}
+
+void cleanup(CleanupList &cl) {
+  for (auto &op: cl.operations) {
+    op->dropAllUses();
+    op->erase();
+  }
+
+  for (auto &v: cl.values) {
+    v.dropAllUses();
+  }
+
+  for (auto &f: cl.functions) {
+    f.funcOp.eraseArguments(f.nonLiveArgs);
+    f.funcOp.eraseResults(f.nonLiveRets);
+  }
+
+  for (auto &o: cl.operands) {
+    o.op->eraseOperands(o.nonLiveOperands);  }
+
+  for (auto &r: cl.results) {
+    dropUsesAndEraseResults(r.op, r.nonLiveResults);
+  }
 
-        successorBlock->getArgument(argIdx).dropAllUses();
-        successorOperands.erase(argIdx);
-        successorBlock->eraseArgument(argIdx);
-      }
+  for (auto &b: cl.blocks) {
+    // Blocks that are reached via multiple code paths are processed only once.
+    if (b.b->getNumArguments() != b.nonLiveArgs.size()) continue;
+    for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
+      if (!b.nonLiveArgs[i]) continue;
+      b.b->getArgument(i).dropAllUses();
+      b.b->eraseArgument(i);
+    }
+  }
+  for (auto &op: cl.successorOperands) {
+    SuccessorOperands successorOperands =
+            op.branch.getSuccessorOperands(op.index);
+    // Successor operand lists reached via multiple code paths are processed only once.
+    if (successorOperands.size() != op.nonLiveOperands.size()) continue;
+    for (int i = successorOperands.size() - 1; i >= 0; --i) {
+      if (!op.nonLiveOperands[i]) continue;
+      successorOperands.erase(i);
     }
   }
 }
@@ -616,14 +698,16 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
 void RemoveDeadValues::runOnOperation() {
   auto &la = getAnalysis<RunLivenessAnalysis>();
   Operation *module = getOperation();
+  DenseSet<Value> deletionSet;
+  CleanupList cl;
 
   module->walk([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-      cleanFuncOp(funcOp, module, la);
+      cleanFuncOp(cl, deletionSet, funcOp, module, la);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      cleanRegionBranchOp(regionBranchOp, la);
+      cleanRegionBranchOp(cl, deletionSet, regionBranchOp, la);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-      cleanBranchOp(branchOp, la);
+      cleanBranchOp(cl, deletionSet, branchOp, la);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
@@ -631,9 +715,11 @@ void RemoveDeadValues::runOnOperation() {
       // Nothing to do because this op is associated with a function op and gets
       // cleaned when the latter is cleaned.
     } else {
-      cleanSimpleOp(op, la);
+      cleanSimpleOp(cl, deletionSet, op, la);
     }
   });
+
+  cleanup(cl);
 }
 
 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9273ac01e7ccec..fe7bcbc7c490b6 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -73,6 +73,32 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
 
 // -----
 
+// Checking that the arguments of linalg.generic are properly handled
+// All code below is removed as a result of the pass
+//
+#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+  func.func @main() {
+    %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
+    %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
+    // CHECK-NOT: arith.constant
+    %0 = tensor.empty() : tensor<1x25x13xi32>
+    // CHECK-NOT: tensor
+    %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) {
+    // CHECK-NOT: linalg.generic
+    ^bb0(%in: i32, %in_15: i32, %out: i32):
+      %29 = arith.xori %in, %in_15 : i32
+      // CHECK-NOT: arith.xori
+      linalg.yield %29 : i32
+      // CHECK-NOT: linalg.yield
+    } -> tensor<1x25x13xi32>
+    return
+  }
+}
+
+// -----
+
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
 // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {



More information about the Mlir-commits mailing list