[Mlir-commits] [mlir] 1415365 - [MLIR][LLVM]: Add an IR utility to perform slice walking (#103053)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 15 01:30:51 PDT 2024
Author: Christian Ulmann
Date: 2024-08-15T10:30:44+02:00
New Revision: 141536544f4ec1d1bf24256157f4ff1a3bc07dae
URL: https://github.com/llvm/llvm-project/commit/141536544f4ec1d1bf24256157f4ff1a3bc07dae
DIFF: https://github.com/llvm/llvm-project/commit/141536544f4ec1d1bf24256157f4ff1a3bc07dae.diff
LOG: [MLIR][LLVM]: Add an IR utility to perform slice walking (#103053)
This commit introduces a slicing utility that can be used to walk
arbitrary IR slices. It additionally ships logic to determine control
flow predecessors, which allows users to walk backward slices without
dealing with both `RegionBranchOpInterface` and `BranchOpInterface`.
This utility is used to improve the `noalias` propagation in the LLVM
dialect's inliner interface. Before this change, it broke down as soon
as pointers were passed through region control flow operations.
Added:
mlir/include/mlir/Analysis/SliceWalk.h
mlir/lib/Analysis/SliceWalk.cpp
Modified:
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/SliceWalk.h b/mlir/include/mlir/Analysis/SliceWalk.h
new file mode 100644
index 00000000000000..481c5690c533ba
--- /dev/null
+++ b/mlir/include/mlir/Analysis/SliceWalk.h
@@ -0,0 +1,98 @@
+//===- SliceWalk.h - Helpers for performing IR slice walks ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_SLICEWALK_H
+#define MLIR_ANALYSIS_SLICEWALK_H
+
+#include "mlir/IR/ValueRange.h"
+
+namespace mlir {
+
+/// Describes how a slice walk proceeds after its callback visited a value:
+/// - Interrupt: Stops the walk.
+/// - AdvanceTo: Continues the walk to user-specified values.
+/// - Skip: Continues the walk, but skips the predecessors of the current value.
+class WalkContinuation {
+public:
+ enum class WalkAction {
+ /// Stops the walk.
+ Interrupt,
+ /// Continues the walk to user-specified values.
+ AdvanceTo,
+ /// Continues the walk, but skips the predecessors of the current value.
+ Skip
+ };
+
+ WalkContinuation(WalkAction action, mlir::ValueRange nextValues)
+ : action(action), nextValues(nextValues) {}
+
+ /// Creates an interrupting continuation; the diagnostic is not stored.
+ explicit WalkContinuation(mlir::Diagnostic &&)
+ : action(WalkAction::Interrupt) {}
+
+ /// Creates an interrupting continuation; the diagnostic is not stored.
+ explicit WalkContinuation(mlir::InFlightDiagnostic &&)
+ : action(WalkAction::Interrupt) {}
+
+ /// Creates a continuation that interrupts the walk.
+ static WalkContinuation interrupt() {
+ return WalkContinuation(WalkAction::Interrupt, {});
+ }
+
+ /// Creates a continuation that adds the user-specified `nextValues` to the
+ /// work list and advances the walk.
+ static WalkContinuation advanceTo(mlir::ValueRange nextValues) {
+ return WalkContinuation(WalkAction::AdvanceTo, nextValues);
+ }
+
+ /// Creates a continuation that advances the walk without adding any
+ /// predecessor values to the work list.
+ static WalkContinuation skip() {
+ return WalkContinuation(WalkAction::Skip, {});
+ }
+
+ /// Returns true if the action of this continuation is Interrupt.
+ bool wasInterrupted() const { return action == WalkAction::Interrupt; }
+
+ /// Returns true if the action of this continuation is Skip.
+ bool wasSkipped() const { return action == WalkAction::Skip; }
+
+ /// Returns true if the action of this continuation is AdvanceTo.
+ bool wasAdvancedTo() const { return action == WalkAction::AdvanceTo; }
+
+ /// Returns a view of the values the walk should continue with.
+ mlir::ArrayRef<mlir::Value> getNextValues() const { return nextValues; }
+
+private:
+ WalkAction action;
+ /// Values to continue with; empty for interrupt(), skip(), and diagnostics.
+ mlir::SmallVector<mlir::Value> nextValues;
+};
+
+/// A callback that is invoked for each value encountered during the walk of the
+/// slice. The callback takes the current value, and returns the walk
+/// continuation, which determines if the walk should proceed and if yes, with
+/// which values.
+using WalkCallback = mlir::function_ref<WalkContinuation(mlir::Value)>;
+
+/// Walks the slice starting from the `rootValues` using a depth-first
+/// traversal. The walk calls the provided `walkCallback` for each value
+/// encountered in the slice and uses the returned walk continuation to
+/// determine how to proceed.
+WalkContinuation walkSlice(mlir::ValueRange rootValues,
+ WalkCallback walkCallback);
+
+/// Computes a vector of all control predecessors of `value`. Relies on
+/// RegionBranchOpInterface and BranchOpInterface to determine predecessors.
+/// Returns nullopt if `value` has no predecessors or when the relevant
+/// operations are missing the interface implementations.
+std::optional<SmallVector<Value>> getControlFlowPredecessors(Value value);
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_SLICEWALK_H
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 38d8415d81c72d..609cb34309829e 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_library(MLIRAnalysis
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
+ SliceWalk.cpp
TopologicalSortUtils.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
new file mode 100644
index 00000000000000..9d770639dc53ca
--- /dev/null
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -0,0 +1,139 @@
+#include "mlir/Analysis/SliceWalk.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+using namespace mlir;
+
+WalkContinuation mlir::walkSlice(ValueRange rootValues,
+                                 WalkCallback walkCallback) {
+  // Explicit stack that drives the traversal; it starts out with the roots.
+  SmallVector<Value> pending = rootValues;
+  llvm::SmallDenseSet<Value, 16> visited;
+  while (!pending.empty()) {
+    // Continue with the value that was pushed most recently.
+    Value value = pending.pop_back_val();
+    // A value that was visited before is not handed to the callback again.
+    const bool firstVisit = visited.insert(value).second;
+    if (!firstVisit)
+      continue;
+
+    // Let the callback decide how the walk proceeds from `value`.
+    WalkContinuation next = walkCallback(value);
+    if (next.wasInterrupted())
+      return next;
+    if (next.wasSkipped())
+      continue;
+
+    assert(next.wasAdvancedTo());
+    // Schedule the values the callback asked to advance to.
+    ArrayRef<Value> selected = next.getNextValues();
+    pending.append(selected.begin(), selected.end());
+  }
+
+  return WalkContinuation::skip();
+}
+
+/// Returns the operand at index `operandNumber` of every operand range that is
+/// forwarded to the `successor` region of `regionOp`, considering the entry
+/// operands of `regionOp` as well as its region branch terminators.
+static SmallVector<Value>
+getRegionPredecessorOperands(RegionBranchOpInterface regionOp,
+                             RegionSuccessor successor,
+                             unsigned operandNumber) {
+  SmallVector<Value> predecessorOperands;
+
+  // Returns true if `successors` contains `successor`.
+  auto isContained = [](ArrayRef<RegionSuccessor> successors,
+                        RegionSuccessor successor) {
+    return llvm::any_of(successors, [&successor](RegionSuccessor curr) {
+      return curr.getSuccessor() == successor.getSuccessor();
+    });
+  };
+
+  // Search the operand ranges on the region operation itself.
+  SmallVector<Attribute> entryAttributes(regionOp->getNumOperands());
+  SmallVector<RegionSuccessor> entrySuccessors;
+  regionOp.getEntrySuccessorRegions(entryAttributes, entrySuccessors);
+  if (isContained(entrySuccessors, successor)) {
+    OperandRange operands = regionOp.getEntrySuccessorOperands(successor);
+    predecessorOperands.push_back(operands[operandNumber]);
+  }
+
+  // Search the operand ranges on region terminators.
+  for (Region &region : regionOp->getRegions()) {
+    for (Block &block : region) {
+      auto terminatorOp =
+          dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
+      if (!terminatorOp)
+        continue;
+      SmallVector<Attribute> operandAttributes(terminatorOp->getNumOperands());
+      SmallVector<RegionSuccessor> successors;
+      terminatorOp.getSuccessorRegions(operandAttributes, successors);
+      if (isContained(successors, successor)) {
+        OperandRange operands = terminatorOp.getSuccessorOperands(successor);
+        predecessorOperands.push_back(operands[operandNumber]);
+      }
+    }
+  }
+
+  return predecessorOperands;
+}
+
+/// Collects the branch operands that predecessors forward to `blockArg`, or
+/// nullopt if a predecessor terminator does not implement BranchOpInterface.
+static std::optional<SmallVector<Value>>
+getBlockPredecessorOperands(BlockArgument blockArg) {
+  Block *owner = blockArg.getOwner();
+  SmallVector<Value> forwarded;
+
+  // Visit each predecessor edge; the iterator knows the successor index.
+  for (auto pred = owner->pred_begin(); pred != owner->pred_end(); ++pred) {
+    Operation *terminator = (*pred)->getTerminator();
+    auto branch = dyn_cast<BranchOpInterface>(terminator);
+    if (!branch)
+      return std::nullopt;
+    // A null entry means the terminator itself produces the argument; such
+    // entries are left out of the result.
+    Value operand = branch.getSuccessorOperands(
+        pred.getSuccessorIndex())[blockArg.getArgNumber()];
+    if (operand)
+      forwarded.push_back(operand);
+  }
+
+  return forwarded;
+}
+
+/// Computes the control flow predecessors of `value` for region results,
+/// region entry block arguments, and arguments of other blocks. Returns
+/// nullopt when the required control flow interface is not implemented.
+std::optional<SmallVector<Value>>
+mlir::getControlFlowPredecessors(Value value) {
+  if (OpResult opResult = dyn_cast<OpResult>(value)) {
+    auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
+    // If the interface is not implemented, there are no control flow
+    // predecessors to work with.
+    if (!regionOp)
+      return std::nullopt;
+    // Collect the operands that are forwarded to the parent operation result.
+    RegionSuccessor region(regionOp->getResults());
+    return getRegionPredecessorOperands(regionOp, region,
+                                        opResult.getResultNumber());
+  }
+
+  auto blockArg = cast<BlockArgument>(value);
+  Block *block = blockArg.getOwner();
+  // Search the region predecessor operands for structured control flow.
+  if (block->isEntryBlock()) {
+    if (auto regionBranchOp =
+            dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
+      RegionSuccessor region(blockArg.getParentRegion());
+      return getRegionPredecessorOperands(regionBranchOp, region,
+                                          blockArg.getArgNumber());
+    }
+    // If the interface is not implemented, there are no control flow
+    // predecessors to work with.
+    return std::nullopt;
+  }
+
+  // Search the block predecessor operands for unstructured control flow.
+  return getBlockPredecessorOperands(blockArg);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 8eba76a9abee8d..504f63b48c9433 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Analysis/SliceWalk.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
@@ -221,86 +222,45 @@ static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs) {
return ArrayAttr::get(lhs.getContext(), result);
}
-/// Attempts to return the underlying pointer value that `pointerValue` is based
-/// on. This traverses down the chain of operations to the last operation
-/// producing the base pointer and returns it. If it encounters an operation it
-/// cannot further traverse through, returns the operation's result.
-static Value getUnderlyingObject(Value pointerValue) {
- while (true) {
- if (auto gepOp = pointerValue.getDefiningOp<LLVM::GEPOp>()) {
- pointerValue = gepOp.getBase();
- continue;
- }
-
- if (auto addrCast = pointerValue.getDefiningOp<LLVM::AddrSpaceCastOp>()) {
- pointerValue = addrCast.getOperand();
- continue;
- }
-
- break;
- }
-
- return pointerValue;
-}
-
/// Attempts to return the set of all underlying pointer values that
/// `pointerValue` is based on. This function traverses through select
-/// operations and block arguments unlike getUnderlyingObject.
-static SmallVector<Value> getUnderlyingObjectSet(Value pointerValue) {
+/// operations and block arguments.
+static FailureOr<SmallVector<Value>>
+getUnderlyingObjectSet(Value pointerValue) {
SmallVector<Value> result;
-
- SmallVector<Value> workList{pointerValue};
- // Avoid dataflow loops.
- SmallPtrSet<Value, 4> seen;
- do {
- Value current = workList.pop_back_val();
- current = getUnderlyingObject(current);
-
- if (!seen.insert(current).second)
- continue;
-
- if (auto selectOp = current.getDefiningOp<LLVM::SelectOp>()) {
- workList.push_back(selectOp.getTrueValue());
- workList.push_back(selectOp.getFalseValue());
- continue;
+ WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
+ if (auto gepOp = val.getDefiningOp<LLVM::GEPOp>())
+ return WalkContinuation::advanceTo(gepOp.getBase());
+
+ if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
+ return WalkContinuation::advanceTo(addrCast.getOperand());
+
+ // TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
+ if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
+ return WalkContinuation::advanceTo(
+ {selectOp.getTrueValue(), selectOp.getFalseValue()});
+
+ // Attempt to advance to control flow predecessors.
+ std::optional<SmallVector<Value>> controlFlowPredecessors =
+ getControlFlowPredecessors(val);
+ if (controlFlowPredecessors)
+ return WalkContinuation::advanceTo(*controlFlowPredecessors);
+
+ // For all non-control flow results, consider `val` an underlying object.
+ if (isa<OpResult>(val)) {
+ result.push_back(val);
+ return WalkContinuation::skip();
}
- if (auto blockArg = dyn_cast<BlockArgument>(current)) {
- Block *parentBlock = blockArg.getParentBlock();
-
- // Attempt to find all block argument operands for every predecessor.
- // If any operand to the block argument wasn't found in a predecessor,
- // conservatively add the block argument to the result set.
- SmallVector<Value> operands;
- bool anyUnknown = false;
- for (auto iter = parentBlock->pred_begin();
- iter != parentBlock->pred_end(); iter++) {
- auto branch = dyn_cast<BranchOpInterface>((*iter)->getTerminator());
- if (!branch) {
- result.push_back(blockArg);
- anyUnknown = true;
- break;
- }
-
- Value operand = branch.getSuccessorOperands(
- iter.getSuccessorIndex())[blockArg.getArgNumber()];
- if (!operand) {
- result.push_back(blockArg);
- anyUnknown = true;
- break;
- }
-
- operands.push_back(operand);
- }
-
- if (!anyUnknown)
- llvm::append_range(workList, operands);
-
- continue;
- }
+ // If this place is reached, `val` is a block argument that is not
+ // understood. Therefore, we conservatively interrupt.
+ // Note: Dealing with function arguments is not necessary, as the slice
+ // would have to go through an SSACopyOp first.
+ return WalkContinuation::interrupt();
+ });
- result.push_back(current);
- } while (!workList.empty());
+ if (walkResult.wasInterrupted())
+ return failure();
return result;
}
@@ -363,9 +323,14 @@ static void createNewAliasScopesFromNoAliasParameter(
// Find the set of underlying pointers that this pointer is based on.
SmallPtrSet<Value, 4> basedOnPointers;
- for (Value pointer : pointerArgs)
- llvm::copy(getUnderlyingObjectSet(pointer),
+ for (Value pointer : pointerArgs) {
+ FailureOr<SmallVector<Value>> underlyingObjectSet =
+ getUnderlyingObjectSet(pointer);
+ if (failed(underlyingObjectSet))
+ return;
+ llvm::copy(*underlyingObjectSet,
std::inserter(basedOnPointers, basedOnPointers.begin()));
+ }
bool aliasesOtherKnownObject = false;
// Go through the based on pointers and check that they are either:
diff --git a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
index 0b8b60e963bb01..a91b991c5ed2b9 100644
--- a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
@@ -296,6 +296,60 @@ llvm.func @bar(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
llvm.func @random() -> i1
+llvm.func @region_branch(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr {llvm.noalias}) {
+ %0 = llvm.mlir.constant(5 : i64) : i32
+ test.region_if %arg0: !llvm.ptr -> !llvm.ptr then {
+ ^bb0(%arg2: !llvm.ptr):
+ test.region_if_yield %arg0 : !llvm.ptr
+ } else {
+ ^bb0(%arg2: !llvm.ptr):
+ test.region_if_yield %arg0 : !llvm.ptr
+ } join {
+ ^bb0(%arg2: !llvm.ptr):
+ llvm.store %0, %arg2 : i32, !llvm.ptr
+ test.region_if_yield %arg0 : !llvm.ptr
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: llvm.func @region_branch_inlining
+// CHECK: llvm.store
+// CHECK-SAME: alias_scopes = [#[[$ARG0_SCOPE]]]
+// CHECK-SAME: noalias_scopes = [#[[$ARG1_SCOPE]]]
+llvm.func @region_branch_inlining(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+ llvm.call @region_branch(%arg0, %arg2) : (!llvm.ptr, !llvm.ptr) -> ()
+ llvm.return
+}
+
+// -----
+
+llvm.func @missing_region_branch(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr {llvm.noalias}) {
+ %0 = llvm.mlir.constant(5 : i64) : i32
+ "test.one_region_op"() ({
+ ^bb0(%arg2: !llvm.ptr):
+ llvm.store %0, %arg2 : i32, !llvm.ptr
+ "test.terminator"() : () -> ()
+ }) : () -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: llvm.func @missing_region_branch_inlining
+// CHECK: llvm.store
+// CHECK-NOT: alias_scopes
+// CHECK-NOT: noalias_scopes
+llvm.func @missing_region_branch_inlining(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+ llvm.call @missing_region_branch(%arg0, %arg2) : (!llvm.ptr, !llvm.ptr) -> ()
+ llvm.return
+}
+
+// -----
+
+// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
+// CHECK-DAG: #[[$ARG0_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
+// CHECK-DAG: #[[$ARG1_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
+
+llvm.func @random() -> i1
+
llvm.func @block_arg(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr {llvm.noalias}) {
%0 = llvm.mlir.constant(5 : i64) : i32
%1 = llvm.mlir.constant(1 : i64) : i64
More information about the Mlir-commits
mailing list