[Mlir-commits] [mlir] 641b12e - [mlir][SliceAnalysis] Add an options object to forward and backward slice.
Mahesh Ravishankar
llvmlistbot at llvm.org
Thu Jun 8 11:40:44 PDT 2023
Author: Mahesh Ravishankar
Date: 2023-06-08T18:40:20Z
New Revision: 641b12e94b8a4e7befbda691364554c186a61639
URL: https://github.com/llvm/llvm-project/commit/641b12e94b8a4e7befbda691364554c186a61639
DIFF: https://github.com/llvm/llvm-project/commit/641b12e94b8a4e7befbda691364554c186a61639.diff
LOG: [mlir][SliceAnalysis] Add an options object to forward and backward slice.
Add an options object to allow control of the slice computation (for
both forward and backward slice). This keeps the API stable as new
options are added, and also provides a way to avoid an assert that makes
the slice analysis unusable for operations with multiple blocks.
Reviewed By: hanchung, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D151520
Added:
mlir/test/IR/slice_multiple_blocks.mlir
Modified:
mlir/include/mlir/Analysis/SliceAnalysis.h
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/lib/IR/TestSlicing.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 4445b645ba718..d170b512e3531 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -21,11 +21,27 @@ class BlockArgument;
class Operation;
class Value;
-/// Type of the condition to limit the propagation of transitive use-defs.
-/// This can be used in particular to limit the propagation to a given Scope or
-/// to avoid passing through certain types of operation in a configurable
-/// manner.
-using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
+struct SliceOptions {
+ /// Type of the condition to limit the propagation of transitive use-defs.
+ /// This can be used in particular to limit the propagation to a given Scope
+ /// or to avoid passing through certain types of operation in a configurable
+ /// manner.
+ using TransitiveFilter = std::function<bool(Operation *)>;
+ TransitiveFilter filter = nullptr;
+
+ /// Include the top level op in the slice.
+ bool inclusive = false;
+};
+
+struct BackwardSliceOptions : public SliceOptions {
+ /// When omitBlockArguments is true, the backward slice computation omits
+ /// traversing any block arguments. When omitBlockArguments is false, the
+ /// backward slice computation traverses block arguments and asserts that the
+ /// parent op has a single region with a single block.
+ bool omitBlockArguments = false;
+};
+
+using ForwardSliceOptions = SliceOptions;
/// Fills `forwardSlice` with the computed forward slice (i.e. all
/// the transitive uses of op), **without** including that operation.
@@ -69,14 +85,12 @@ using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
///
void getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
- TransitiveFilter filter = nullptr /* pass-through*/,
- bool inclusive = false);
+ ForwardSliceOptions options = {});
/// Value-rooted version of `getForwardSlice`. Return the union of all forward
/// slices for the uses of the value `root`.
void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
- TransitiveFilter filter = nullptr /* pass-through*/,
- bool inclusive = false);
+ ForwardSliceOptions options = {});
/// Fills `backwardSlice` with the computed backward slice (i.e.
/// all the transitive defs of op), **without** including that operation.
@@ -113,14 +127,12 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
/// {1, 2, 5, 3, 4, 6}
///
void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
- TransitiveFilter filter = nullptr /* pass-through*/,
- bool inclusive = false);
+ BackwardSliceOptions options = {});
/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
/// slices for the op defining or owning the value `root`.
void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
- TransitiveFilter filter = nullptr /* pass-through*/,
- bool inclusive = false);
+ BackwardSliceOptions options = {});
/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `SetVector<Operation *>` which
@@ -199,11 +211,9 @@ void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
/// and keep things ordered but this is still hand-wavy and not worth the
/// trouble for now: punt to a simple worklist-based solution.
///
-SetVector<Operation *>
-getSlice(Operation *op,
- TransitiveFilter backwardFilter = nullptr /* pass-through*/,
- TransitiveFilter forwardFilter = nullptr /* pass-through*/,
- bool inclusive = false);
+SetVector<Operation *> getSlice(Operation *op,
+ BackwardSliceOptions backwardSliceOptions = {},
+ ForwardSliceOptions forwardSliceOptions = {});
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 7af6a65cef99a..9a5821da6343d 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -24,9 +24,9 @@
using namespace mlir;
-static void getForwardSliceImpl(Operation *op,
- SetVector<Operation *> *forwardSlice,
- TransitiveFilter filter) {
+static void
+getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
+ SliceOptions::TransitiveFilter filter = nullptr) {
if (!op)
return;
@@ -51,9 +51,9 @@ static void getForwardSliceImpl(Operation *op,
}
void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
- TransitiveFilter filter, bool inclusive) {
- getForwardSliceImpl(op, forwardSlice, filter);
- if (!inclusive) {
+ ForwardSliceOptions options) {
+ getForwardSliceImpl(op, forwardSlice, options.filter);
+ if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
forwardSlice->remove(op);
@@ -67,9 +67,9 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
}
void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
- TransitiveFilter filter, bool inclusive) {
+ SliceOptions options) {
for (Operation *user : root.getUsers())
- getForwardSliceImpl(user, forwardSlice, filter);
+ getForwardSliceImpl(user, forwardSlice, options.filter);
// Reverse to get back the actual topological order.
// std::reverse does not work out of the box on SetVector and I want an
@@ -80,22 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
static void getBackwardSliceImpl(Operation *op,
SetVector<Operation *> *backwardSlice,
- TransitiveFilter filter) {
+ BackwardSliceOptions options) {
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
return;
// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive backwardSlice in the current scope.
- if (filter && !filter(op))
+ if (options.filter && !options.filter(op))
return;
for (const auto &en : llvm::enumerate(op->getOperands())) {
auto operand = en.value();
if (auto *definingOp = operand.getDefiningOp()) {
if (backwardSlice->count(definingOp) == 0)
- getBackwardSliceImpl(definingOp, backwardSlice, filter);
+ getBackwardSliceImpl(definingOp, backwardSlice, options);
} else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+ if (options.omitBlockArguments)
+ continue;
+
Block *block = blockArg.getOwner();
Operation *parentOp = block->getParentOp();
// TODO: determine whether we want to recurse backward into the other
@@ -104,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
parentOp->getRegion(0).getBlocks().size() == 1);
- getBackwardSliceImpl(parentOp, backwardSlice, filter);
+ getBackwardSliceImpl(parentOp, backwardSlice, options);
}
} else {
llvm_unreachable("No definingOp and not a block argument.");
@@ -116,10 +119,10 @@ static void getBackwardSliceImpl(Operation *op,
void mlir::getBackwardSlice(Operation *op,
SetVector<Operation *> *backwardSlice,
- TransitiveFilter filter, bool inclusive) {
- getBackwardSliceImpl(op, backwardSlice, filter);
+ BackwardSliceOptions options) {
+ getBackwardSliceImpl(op, backwardSlice, options);
- if (!inclusive) {
+ if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
backwardSlice->remove(op);
@@ -127,19 +130,18 @@ void mlir::getBackwardSlice(Operation *op,
}
void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
- TransitiveFilter filter, bool inclusive) {
+ BackwardSliceOptions options) {
if (Operation *definingOp = root.getDefiningOp()) {
- getBackwardSlice(definingOp, backwardSlice, filter, inclusive);
+ getBackwardSlice(definingOp, backwardSlice, options);
return;
}
Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
- getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive);
+ getBackwardSlice(bbAargOwner, backwardSlice, options);
}
SetVector<Operation *> mlir::getSlice(Operation *op,
- TransitiveFilter backwardFilter,
- TransitiveFilter forwardFilter,
- bool inclusive) {
+ BackwardSliceOptions backwardSliceOptions,
+ ForwardSliceOptions forwardSliceOptions) {
SetVector<Operation *> slice;
slice.insert(op);
@@ -150,12 +152,12 @@ SetVector<Operation *> mlir::getSlice(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
- getBackwardSlice(currentOp, &backwardSlice, backwardFilter, inclusive);
+ getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
slice.insert(backwardSlice.begin(), backwardSlice.end());
// Compute and insert the forwardSlice starting from currentOp.
forwardSlice.clear();
- getForwardSlice(currentOp, &forwardSlice, forwardFilter, inclusive);
+ getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
}
@@ -225,8 +227,11 @@ static bool dependsOnCarriedVals(Value value,
Operation *ancestorOp) {
// Compute the backward slice of the value.
SetVector<Operation *> slice;
- getBackwardSlice(value, &slice,
- [&](Operation *op) { return !ancestorOp->isAncestor(op); });
+ BackwardSliceOptions sliceOptions;
+ sliceOptions.filter = [&](Operation *op) {
+ return !ancestorOp->isAncestor(op);
+ };
+ getBackwardSlice(value, &slice, sliceOptions);
// Check that none of the operands of the operations in the backward slice are
// loop iteration arguments, and neither is the value itself.
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a78402eb16428..08c5214aae3db 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -303,9 +303,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice elements
/// that are part of the use/def chain.
-static SetVector<Operation *> getSliceContract(Operation *op,
- TransitiveFilter backwardFilter,
- TransitiveFilter forwardFilter) {
+static SetVector<Operation *>
+getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions,
+ ForwardSliceOptions forwardSliceOptions) {
SetVector<Operation *> slice;
slice.insert(op);
unsigned currentIndex = 0;
@@ -315,7 +315,7 @@ static SetVector<Operation *> getSliceContract(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
- getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+ getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
slice.insert(backwardSlice.begin(), backwardSlice.end());
// Compute and insert the forwardSlice starting from currentOp.
@@ -326,11 +326,11 @@ static SetVector<Operation *> getSliceContract(Operation *op,
// converted to matrix type.
if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
for (Value forOpResult : forOp.getResults())
- getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
+ getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
for (BlockArgument &arg : forOp.getRegionIterArgs())
- getForwardSlice(arg, &forwardSlice, forwardFilter);
+ getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
} else {
- getForwardSlice(currentOp, &forwardSlice, forwardFilter);
+ getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
}
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
@@ -346,16 +346,22 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
return llvm::any_of(op->getResultTypes(),
[](Type t) { return isa<VectorType>(t); });
};
+ BackwardSliceOptions backwardSliceOptions;
+ backwardSliceOptions.filter = hasVectorDest;
+
auto hasVectorSrc = [](Operation *op) {
return llvm::any_of(op->getOperandTypes(),
[](Type t) { return isa<VectorType>(t); });
};
+ ForwardSliceOptions forwardSliceOptions;
+ forwardSliceOptions.filter = hasVectorSrc;
+
SetVector<Operation *> opToConvert;
op->walk([&](vector::ContractionOp contract) {
if (opToConvert.contains(contract.getOperation()))
return;
SetVector<Operation *> dependentOps =
- getSliceContract(contract, hasVectorDest, hasVectorSrc);
+ getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
// If any instruction cannot use MMA matrix type drop the whole
// chain. MMA matrix are stored in an opaque type so they cannot be used
// by all operations.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 21d83d225d705..21bc0554e7176 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -111,20 +111,22 @@ static void computeBackwardSlice(tensor::PadOp padOp,
scf::ForOp outermostEnclosingForOp,
SetVector<Operation *> &backwardSlice) {
DominanceInfo domInfo(outermostEnclosingForOp);
- auto filter = [&](Operation *op) {
+ BackwardSliceOptions sliceOptions;
+ sliceOptions.filter = [&](Operation *op) {
return domInfo.dominates(outermostEnclosingForOp, op) &&
!padOp->isProperAncestor(op);
};
+ sliceOptions.inclusive = true;
+
// First, add the ops required to compute the region to the backwardSlice.
SetVector<Value> valuesDefinedAbove;
getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
valuesDefinedAbove);
for (Value v : valuesDefinedAbove) {
- getBackwardSlice(v, &backwardSlice, filter, /*inclusive=*/true);
+ getBackwardSlice(v, &backwardSlice, sliceOptions);
}
// Then, add the backward slice from padOp itself.
- getBackwardSlice(padOp.getOperation(), &backwardSlice, filter,
- /*inclusive=*/true);
+ getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e16e2881185a9..34225fd133b2e 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -797,9 +797,11 @@ void mlir::collapseParallelLoops(
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
SetVector<Operation *> forwardSlice;
- getForwardSlice(
- outer.getInductionVar(), &forwardSlice,
- [&inner](Operation *op) { return op != inner.getOperation(); });
+ ForwardSliceOptions options;
+ options.filter = [&inner](Operation *op) {
+ return op != inner.getOperation();
+ };
+ getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
LogicalResult status = success();
SmallVector<Operation *, 8> toHoist;
for (auto &op : outer.getBody()->without_terminator()) {
diff --git a/mlir/test/IR/slice_multiple_blocks.mlir b/mlir/test/IR/slice_multiple_blocks.mlir
new file mode 100644
index 0000000000000..395a4e970d5d4
--- /dev/null
+++ b/mlir/test/IR/slice_multiple_blocks.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(slice-analysis-test{omit-block-arguments=true})" %s | FileCheck %s
+
+func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
+ %a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
+ %b = memref.alloc(%arg2, %arg1) : memref<?x?xf32>
+ cf.br ^bb1
+^bb1() :
+ %c = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+ %d = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+ linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c : memref<?x?xf32>)
+ linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%d : memref<?x?xf32>)
+ memref.dealloc %c : memref<?x?xf32>
+ memref.dealloc %b : memref<?x?xf32>
+ memref.dealloc %a : memref<?x?xf32>
+ memref.dealloc %d : memref<?x?xf32>
+ return
+}
+// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
+// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK: return
+
+// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
+// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK: return
diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
index 3388b972880e3..c3d0d151c6d24 100644
--- a/mlir/test/lib/IR/TestSlicing.cpp
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -24,7 +24,8 @@ using namespace mlir;
/// Create a function with the same signature as the parent function of `op`
/// with name being the function name and a `suffix`.
static LogicalResult createBackwardSliceFunction(Operation *op,
- StringRef suffix) {
+ StringRef suffix,
+ bool omitBlockArguments) {
func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>();
OpBuilder builder(parentFuncOp);
Location loc = op->getLoc();
@@ -36,7 +37,9 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
for (const auto &arg : enumerate(parentFuncOp.getArguments()))
mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
SetVector<Operation *> slice;
- getBackwardSlice(op, &slice);
+ BackwardSliceOptions options;
+ options.omitBlockArguments = omitBlockArguments;
+ getBackwardSlice(op, &slice, options);
for (Operation *slicedOp : slice)
builder.clone(*slicedOp, mapper);
builder.create<func::ReturnOp>(loc);
@@ -53,6 +56,13 @@ struct SliceAnalysisTestPass
StringRef getDescription() const final {
return "Test Slice analysis functionality.";
}
+
+ Option<bool> omitBlockArguments{
+ *this, "omit-block-arguments",
+ llvm::cl::desc("Test Slice analysis with multiple blocks but slice "
+ "omiting block arguments"),
+ llvm::cl::init(true)};
+
void runOnOperation() override;
SliceAnalysisTestPass() = default;
SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
@@ -64,11 +74,6 @@ void SliceAnalysisTestPass::runOnOperation() {
auto funcOps = module.getOps<func::FuncOp>();
unsigned opNum = 0;
for (auto funcOp : funcOps) {
- if (!llvm::hasSingleElement(funcOp.getBody())) {
- funcOp->emitOpError("Does not support functions with multiple blocks");
- signalPassFailure();
- return;
- }
// TODO: For now this is just looking for Linalg ops. It can be generalized
// to look for other ops using flags.
funcOp.walk([&](Operation *op) {
@@ -76,7 +81,7 @@ void SliceAnalysisTestPass::runOnOperation() {
return WalkResult::advance();
std::string append =
std::string("__backward_slice__") + std::to_string(opNum);
- (void)createBackwardSliceFunction(op, append);
+ (void)createBackwardSliceFunction(op, append, omitBlockArguments);
opNum++;
return WalkResult::advance();
});
More information about the Mlir-commits
mailing list