[Mlir-commits] [mlir] MLIR, LLVM: Add an IR utility to perform proper slicing (PR #103053)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 13 05:59:59 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit introduces a backwards slicing utility that is both `RegionBranchOpInterface` and `BranchOpInterface` aware. This is a first step in replacing the currently broken slicing utilities that cannot be used for structured control flow.

Note that the utility is currently placed in the IR library because LLVM's inlining uses it and is part of the LLVMIR library. There was originally substantial pushback against moving the inliner interface out of said library.

For now, I didn't add a test pass, as I'm not sure if we usually have some of these for IR utilities.

---
Full diff: https://github.com/llvm/llvm-project/pull/103053.diff


4 Files Affected:

- (added) mlir/include/mlir/IR/SliceSupport.h (+123) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp (+14-71) 
- (modified) mlir/lib/IR/CMakeLists.txt (+1) 
- (added) mlir/lib/IR/SliceSupport.cpp (+130) 


``````````diff
diff --git a/mlir/include/mlir/IR/SliceSupport.h b/mlir/include/mlir/IR/SliceSupport.h
new file mode 100644
index 00000000000000..b840ca72d9d027
--- /dev/null
+++ b/mlir/include/mlir/IR/SliceSupport.h
@@ -0,0 +1,123 @@
+//===- SliceSupport.h - Helpers for performing IR slicing -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SLICESUPPORT_H
+#define MLIR_IR_SLICESUPPORT_H
+
+#include "mlir/IR/ValueRange.h"
+
+namespace mlir {
+
+/// A class to signal how to proceed with the walk of the backward slice:
+/// - Interrupt: Stops the walk.
+/// - Advance: Continues the walk to the control flow predecessor values.
+/// - AdvanceTo: Continues the walk to user-specified values.
+/// - Skip: Continues the walk, but skips the predecessors of the current value.
+class WalkContinuation {
+public:
+  enum class WalkAction {
+    /// Stops the walk.
+    Interrupt,
+    /// Continues the walk to the control flow predecessor values.
+    Advance,
+    /// Continues the walk to user-specified values.
+    AdvanceTo,
+    /// Continues the walk, but skips the predecessors of the current value.
+    Skip
+  };
+
+  WalkContinuation(WalkAction action, mlir::ValueRange nextValues)
+      : action(action), nextValues(nextValues) {}
+
+  /// Allows LogicalResult to interrupt the walk on failure; success advances.
+  explicit WalkContinuation(llvm::LogicalResult action)
+      : action(failed(action) ? WalkAction::Interrupt : WalkAction::Advance) {}
+
+  /// Allows diagnostics to interrupt the walk.
+  explicit WalkContinuation(mlir::Diagnostic &&)
+      : action(WalkAction::Interrupt) {}
+
+  /// Allows in-flight diagnostics to interrupt the walk.
+  explicit WalkContinuation(mlir::InFlightDiagnostic &&)
+      : action(WalkAction::Interrupt) {}
+
+  /// Creates a continuation that interrupts the walk.
+  static WalkContinuation interrupt() {
+    return WalkContinuation(WalkAction::Interrupt, {});
+  }
+
+  /// Creates a continuation that adds the user-specified `nextValues` to the
+  /// work list and advances the walk. Unlike advance, this function does not
+  /// add the control flow predecessor values to the work list.
+  static WalkContinuation advanceTo(mlir::ValueRange nextValues) {
+    return WalkContinuation(WalkAction::AdvanceTo, nextValues);
+  }
+
+  /// Creates a continuation that adds the control flow predecessor values to
+  /// the work list and advances the walk.
+  static WalkContinuation advance() {
+    return WalkContinuation(WalkAction::Advance, {});
+  }
+
+  /// Creates a continuation that advances the walk without adding any
+  /// predecessor values to the work list.
+  static WalkContinuation skip() {
+    return WalkContinuation(WalkAction::Skip, {});
+  }
+
+  /// Returns true if this continuation interrupts the walk.
+  bool wasInterrupted() const { return action == WalkAction::Interrupt; }
+
+  /// Returns true if this continuation skips the current value's predecessors.
+  bool wasSkipped() const { return action == WalkAction::Skip; }
+
+  /// Returns true if this continuation advances to user-specified values.
+  bool wasAdvancedTo() const { return action == WalkAction::AdvanceTo; }
+
+  /// Returns the next values to continue the walk with.
+  mlir::ArrayRef<mlir::Value> getNextValues() const { return nextValues; }
+
+private:
+  WalkAction action;
+  /// Owned copy of the next values to continue the walk with.
+  mlir::SmallVector<mlir::Value> nextValues;
+};
+
+/// A callback that is invoked for each value encountered during the walk of the
+/// backward slice. The callback takes the current value and returns the walk
+/// continuation, which determines whether the walk proceeds and, if so, with
+/// which values.
+using WalkCallback = mlir::function_ref<WalkContinuation(mlir::Value)>;
+
+/// Walks the backward slice starting from the `rootValues` using a depth-first
+/// traversal following the use-def chains. Each value is visited at most once;
+/// `walkCallback` is invoked on it and the returned continuation decides how to
+/// proceed. On `advance`, the walk follows the predecessor operands of control
+/// flow that implements RegionBranchOpInterface or BranchOpInterface. Returns
+/// the interrupting continuation if interrupted, and `advance()` otherwise.
+WalkContinuation walkBackwardSlice(mlir::ValueRange rootValues,
+                                   WalkCallback walkCallback);
+
+/// NOTE(review): this alias repeats, token for token, the `WalkCallback` alias
+/// declared a few lines above. The repetition is harmless to the compiler but
+/// redundant; presumably it is a copy-paste leftover that should be dropped -
+/// confirm with the author.
+using WalkCallback = mlir::function_ref<WalkContinuation(mlir::Value)>;
+
+/// NOTE(review): this declaration repeats, token for token, the declaration of
+/// `walkBackwardSlice` a few lines above; refer to that first declaration for
+/// the contract. The repetition is harmless to the compiler but redundant;
+/// presumably it is a copy-paste leftover that should be dropped together with
+/// the repeated alias - confirm with the author. Nothing in this header depends
+/// on the second copy.
+WalkContinuation walkBackwardSlice(mlir::ValueRange rootValues,
+                                   WalkCallback walkCallback);
+
+} // namespace mlir
+
+#endif // MLIR_IR_SLICESUPPORT_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
index 137c1962b100af..fd875fe2ffcac4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
@@ -14,6 +14,7 @@
 #include "LLVMInlining.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/SliceSupport.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -221,86 +222,28 @@ static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs) {
   return ArrayAttr::get(lhs.getContext(), result);
 }
 
-/// Attempts to return the underlying pointer value that `pointerValue` is based
-/// on. This traverses down the chain of operations to the last operation
-/// producing the base pointer and returns it. If it encounters an operation it
-/// cannot further traverse through, returns the operation's result.
-static Value getUnderlyingObject(Value pointerValue) {
-  while (true) {
-    if (auto gepOp = pointerValue.getDefiningOp<LLVM::GEPOp>()) {
-      pointerValue = gepOp.getBase();
-      continue;
-    }
-
-    if (auto addrCast = pointerValue.getDefiningOp<LLVM::AddrSpaceCastOp>()) {
-      pointerValue = addrCast.getOperand();
-      continue;
-    }
-
-    break;
-  }
-
-  return pointerValue;
-}
-
 /// Attempts to return the set of all underlying pointer values that
 /// `pointerValue` is based on. This function traverses through select
 /// operations and block arguments unlike getUnderlyingObject.
 static SmallVector<Value> getUnderlyingObjectSet(Value pointerValue) {
   SmallVector<Value> result;
+  walkBackwardSlice(pointerValue, [&](Value val) {
+    if (auto gepOp = val.getDefiningOp<LLVM::GEPOp>())
+      return WalkContinuation::advanceTo(gepOp.getBase());
 
-  SmallVector<Value> workList{pointerValue};
-  // Avoid dataflow loops.
-  SmallPtrSet<Value, 4> seen;
-  do {
-    Value current = workList.pop_back_val();
-    current = getUnderlyingObject(current);
-
-    if (!seen.insert(current).second)
-      continue;
-
-    if (auto selectOp = current.getDefiningOp<LLVM::SelectOp>()) {
-      workList.push_back(selectOp.getTrueValue());
-      workList.push_back(selectOp.getFalseValue());
-      continue;
-    }
-
-    if (auto blockArg = dyn_cast<BlockArgument>(current)) {
-      Block *parentBlock = blockArg.getParentBlock();
-
-      // Attempt to find all block argument operands for every predecessor.
-      // If any operand to the block argument wasn't found in a predecessor,
-      // conservatively add the block argument to the result set.
-      SmallVector<Value> operands;
-      bool anyUnknown = false;
-      for (auto iter = parentBlock->pred_begin();
-           iter != parentBlock->pred_end(); iter++) {
-        auto branch = dyn_cast<BranchOpInterface>((*iter)->getTerminator());
-        if (!branch) {
-          result.push_back(blockArg);
-          anyUnknown = true;
-          break;
-        }
-
-        Value operand = branch.getSuccessorOperands(
-            iter.getSuccessorIndex())[blockArg.getArgNumber()];
-        if (!operand) {
-          result.push_back(blockArg);
-          anyUnknown = true;
-          break;
-        }
-
-        operands.push_back(operand);
-      }
+    if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
+      return WalkContinuation::advanceTo(addrCast.getOperand());
 
-      if (!anyUnknown)
-        llvm::append_range(workList, operands);
+    // TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
+    if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
+      return WalkContinuation::advanceTo(
+          {selectOp.getTrueValue(), selectOp.getFalseValue()});
 
-      continue;
-    }
+    if (isa<OpResult>(val))
+      result.push_back(val);
 
-    result.push_back(current);
-  } while (!workList.empty());
+    return WalkContinuation::advance();
+  });
 
   return result;
 }
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..bfff29abc824a0 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_library(MLIRIR
   PatternMatch.cpp
   Region.cpp
   RegionKindInterface.cpp
+  SliceSupport.cpp
   SymbolTable.cpp
   TensorEncoding.cpp
   Types.cpp
diff --git a/mlir/lib/IR/SliceSupport.cpp b/mlir/lib/IR/SliceSupport.cpp
new file mode 100644
index 00000000000000..b3bb6a39ddd981
--- /dev/null
+++ b/mlir/lib/IR/SliceSupport.cpp
@@ -0,0 +1,130 @@
+#include "mlir/IR/SliceSupport.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+using namespace mlir;
+
+/// Returns, for every predecessor of `successor` within `regionOp`, the operand
+/// at index `operandNumber` of the operand range it forwards to `successor`.
+static SmallVector<Value>
+getRegionPredecessorOperands(RegionBranchOpInterface regionOp,
+                             RegionSuccessor successor,
+                             unsigned operandNumber) {
+  SmallVector<Value> predecessorOperands;
+
+  // Returns true if `successors` has an entry targeting the same successor.
+  auto isContained = [](ArrayRef<RegionSuccessor> successors,
+                        RegionSuccessor successor) {
+    auto *it = llvm::find_if(successors, [&successor](RegionSuccessor curr) {
+      return curr.getSuccessor() == successor.getSuccessor();
+    });
+    return it != successors.end();
+  };
+
+  // Search the operands forwarded by the region operation itself on entry.
+  SmallVector<Attribute> operandAttributes(regionOp->getNumOperands());
+  SmallVector<RegionSuccessor> successors;
+  regionOp.getEntrySuccessorRegions(operandAttributes, successors);
+  if (isContained(successors, successor)) {
+    OperandRange operands = regionOp.getEntrySuccessorOperands(successor);
+    predecessorOperands.push_back(operands[operandNumber]);
+  }
+
+  // Search the operand ranges on the terminators of all nested blocks.
+  for (Region &region : regionOp->getRegions()) {
+    for (Block &block : region) {
+      auto terminatorOp =
+          dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
+      if (!terminatorOp)
+        continue;
+      SmallVector<Attribute> operandAttributes(terminatorOp->getNumOperands());
+      SmallVector<RegionSuccessor> successors;
+      terminatorOp.getSuccessorRegions(operandAttributes, successors);
+      if (isContained(successors, successor)) {
+        OperandRange operands = terminatorOp.getSuccessorOperands(successor);
+        predecessorOperands.push_back(operands[operandNumber]);
+      }
+    }
+  }
+
+  return predecessorOperands;
+}
+
+/// Returns the predecessor branch operands that are forwarded to `blockArg`.
+static SmallVector<Value> getBlockPredecessorOperands(BlockArgument blockArg) {
+  Block *block = blockArg.getOwner();
+
+  // Visit all predecessor terminators; the `cast` requires BranchOpInterface.
+  SmallVector<Value> predecessorOperands;
+  for (auto it = block->pred_begin(); it != block->pred_end(); ++it) {
+    Block *predecessor = *it;
+    auto branchOp = cast<BranchOpInterface>(predecessor->getTerminator());
+    SuccessorOperands successorOperands =
+        branchOp.getSuccessorOperands(it.getSuccessorIndex());
+    // Store the forwarded operand if there is one; a null value (presumably an
+    // argument produced by the terminator itself) is skipped.
+    if (Value operand = successorOperands[blockArg.getArgNumber()])
+      predecessorOperands.push_back(operand);
+  }
+
+  return predecessorOperands;
+}
+
+mlir::WalkContinuation mlir::walkBackwardSlice(ValueRange rootValues,
+                                               WalkCallback walkCallback) {
+  // Work list of values still to visit, seeded with the root values.
+  SmallVector<Value> workList = rootValues;
+  llvm::SmallDenseSet<Value, 16> seenValues;
+  while (!workList.empty()) {
+    // Take the most recently added value (LIFO order).
+    Value current = workList.pop_back_val();
+
+    // Visit every value at most once; this also terminates on dataflow cycles.
+    if (!seenValues.insert(current).second)
+      continue;
+
+    // Let the callback decide how to continue from the current value.
+    WalkContinuation continuation = walkCallback(current);
+    if (continuation.wasInterrupted())
+      return continuation;
+    if (continuation.wasSkipped())
+      continue;
+
+    // AdvanceTo: continue with the values chosen by the callback only.
+    if (continuation.wasAdvancedTo()) {
+      workList.append(continuation.getNextValues().begin(),
+                      continuation.getNextValues().end());
+      continue;
+    }
+
+    // Advance on an op result: follow region branch ops to their predecessors.
+    if (OpResult opResult = dyn_cast<OpResult>(current)) {
+      auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
+      if (!regionOp)
+        continue;
+      RegionSuccessor region(regionOp->getResults());
+      SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
+          regionOp, region, opResult.getResultNumber());
+      workList.append(predecessorOperands.begin(), predecessorOperands.end());
+      continue;
+    }
+
+    auto blockArg = cast<BlockArgument>(current);
+    Block *block = blockArg.getOwner();
+    // Entry block arguments of region branch ops: use the region predecessors.
+    auto regionBranchOp =
+        dyn_cast<RegionBranchOpInterface>(block->getParentOp());
+    if (block->isEntryBlock() && regionBranchOp) {
+      RegionSuccessor region(blockArg.getParentRegion());
+      SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
+          regionBranchOp, region, blockArg.getArgNumber());
+      workList.append(predecessorOperands.begin(), predecessorOperands.end());
+      continue;
+    }
+    // Other block arguments: use the operands of the predecessor branches.
+    SmallVector<Value> predecessorOperands =
+        getBlockPredecessorOperands(blockArg);
+    workList.append(predecessorOperands.begin(), predecessorOperands.end());
+  }
+
+  return WalkContinuation::advance();
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/103053


More information about the Mlir-commits mailing list