[Mlir-commits] [mlir] 6a8dde0 - [MLIR] Change getBackwardSlice to return a LogicalResult rather than crash (#140961)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 22 12:13:36 PDT 2025
Author: William Moses
Date: 2025-05-22T14:13:32-05:00
New Revision: 6a8dde04a07287f837bbabeb93e23e47af366d3d
URL: https://github.com/llvm/llvm-project/commit/6a8dde04a07287f837bbabeb93e23e47af366d3d
DIFF: https://github.com/llvm/llvm-project/commit/6a8dde04a07287f837bbabeb93e23e47af366d3d.diff
LOG: [MLIR] Change getBackwardSlice to return a LogicalResult rather than crash (#140961)
The current implementation of getBackwardSlice crashes (via an assertion)
if a value in the dependency chain is a block argument of an operation
with multiple regions or blocks. Crashing is bad: it prevents many
analyses from using getBackwardSlice, and it makes existing users of
getBackwardSlice fail on IR with this property.
This PR instead makes the analysis return failure, rather than crash, in
the cases where it cannot compute the full slice.
---------
Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
Added:
Modified:
mlir/include/mlir/Analysis/SliceAnalysis.h
mlir/include/mlir/Query/Matcher/SliceMatchers.h
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Transforms/Utils/RegionUtils.cpp
mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
mlir/test/lib/IR/TestSlicing.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 3b731e8bb1c22..d082d2d9f758b 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -138,13 +138,20 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
/// Assuming all local orders match the numbering order:
/// {1, 2, 5, 3, 4, 6}
///
-void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options = {});
+/// Returns success if the complete backward slice could be computed, and
+/// failure otherwise (for example, when the slice reaches a block argument
+/// whose parent operation has more than one region or block). On failure,
+/// `backwardSlice` may hold an incomplete slice.
+LogicalResult getBackwardSlice(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options = {});
/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
/// slices for the op defining or owning the value `root`.
+/// Fails under the same conditions as the Operation-rooted overload.
-void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options = {});
+LogicalResult getBackwardSlice(Value root,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options = {});
/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `SetVector<Operation *>` which
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 1b0e4c32dbe94..40a39d23ca695 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -112,7 +112,9 @@ bool BackwardSliceMatcher<Matcher>::matches(
}
return true;
};
- getBackwardSlice(rootOp, &backwardSlice, options);
+ LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
+ assert(result.succeeded() && "expected backward slice to succeed");
+ (void)result;
return options.inclusive ? backwardSlice.size() > 1
: backwardSlice.size() >= 1;
}
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 5aebb19e3a86e..12b9d3adb49fa 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -80,25 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
forwardSlice->insert(v.rbegin(), v.rend());
}
-static void getBackwardSliceImpl(Operation *op,
- SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options) {
+static LogicalResult getBackwardSliceImpl(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options) {
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
- return;
+ return success();
// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive backwardSlice in the current scope.
if (options.filter && !options.filter(op))
- return;
+ return success();
auto processValue = [&](Value value) {
if (auto *definingOp = value.getDefiningOp()) {
if (backwardSlice->count(definingOp) == 0)
- getBackwardSliceImpl(definingOp, backwardSlice, options);
+ return getBackwardSliceImpl(definingOp, backwardSlice, options);
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (options.omitBlockArguments)
- return;
+ return success();
Block *block = blockArg.getOwner();
Operation *parentOp = block->getParentOp();
@@ -106,15 +106,21 @@ static void getBackwardSliceImpl(Operation *op,
// blocks of parentOp, which are not technically backward unless they flow
// into us. For now, just bail.
if (parentOp && backwardSlice->count(parentOp) == 0) {
- assert(parentOp->getNumRegions() == 1 &&
- llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
- getBackwardSliceImpl(parentOp, backwardSlice, options);
+ if (parentOp->getNumRegions() == 1 &&
+ llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
+ return getBackwardSliceImpl(parentOp, backwardSlice, options);
+ }
+ // Recursing into parentOp is only supported when it has a single region
+ // holding a single block; otherwise the slice cannot be completed.
+ return failure();
}
- } else {
- llvm_unreachable("No definingOp and not a block argument.");
}
+ // Nothing left to visit: the owner is already in the slice, or the block
+ // argument has no parent operation.
+ return success();
};
+ bool succeeded = true;
+
if (!options.omitUsesFromAbove) {
llvm::for_each(op->getRegions(), [&](Region &region) {
// Walk this region recursively to collect the regions that descend from
@@ -125,36 +131,43 @@ static void getBackwardSliceImpl(Operation *op,
region.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
if (!descendents.contains(operand.get().getParentRegion()))
- processValue(operand.get());
+ if (failed(processValue(operand.get())))
+ succeeded = false;
}
});
});
}
- llvm::for_each(op->getOperands(), processValue);
+ // Keep visiting the remaining operands after a failure so that the caller
+ // still receives as much of the slice as could be computed.
+ for (Value operand : op->getOperands())
+ if (failed(processValue(operand)))
+ succeeded = false;
backwardSlice->insert(op);
+ return success(succeeded);
}
-void mlir::getBackwardSlice(Operation *op,
- SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options) {
- getBackwardSliceImpl(op, backwardSlice, options);
+LogicalResult mlir::getBackwardSlice(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options) {
+ LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);
if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
backwardSlice->remove(op);
}
+ return result;
}
-void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
- const BackwardSliceOptions &options) {
+LogicalResult mlir::getBackwardSlice(Value root,
+ SetVector<Operation *> *backwardSlice,
+ const BackwardSliceOptions &options) {
if (Operation *definingOp = root.getDefiningOp()) {
- getBackwardSlice(definingOp, backwardSlice, options);
- return;
+ return getBackwardSlice(definingOp, backwardSlice, options);
}
Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
- getBackwardSlice(bbAargOwner, backwardSlice, options);
+ return getBackwardSlice(bbAargOwner, backwardSlice, options);
}
SetVector<Operation *>
@@ -170,7 +183,10 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
- getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ LogicalResult result =
+ getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ assert(result.succeeded());
+ (void)result;
slice.insert_range(backwardSlice);
// Compute and insert the forwardSlice starting from currentOp.
@@ -193,7 +209,9 @@ static bool dependsOnCarriedVals(Value value,
sliceOptions.filter = [&](Operation *op) {
return !ancestorOp->isAncestor(op);
};
- getBackwardSlice(value, &slice, sliceOptions);
+ LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
+ assert(result.succeeded());
+ (void)result;
// Check that none of the operands of the operations in the backward slice are
// loop iteration arguments, and neither is the value itself.
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 8b16da387457d..0ec9ddc25ff8d 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -317,7 +317,10 @@ getSliceContract(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
- getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ LogicalResult result =
+ getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
+ assert(result.succeeded() && "expected a backward slice");
+ (void)result;
slice.insert_range(backwardSlice);
// Compute and insert the forwardSlice starting from currentOp.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index d33a17af63459..2c98bd3ba93af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -124,10 +124,15 @@ static void computeBackwardSlice(tensor::PadOp padOp,
getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
valuesDefinedAbove);
for (Value v : valuesDefinedAbove) {
- getBackwardSlice(v, &backwardSlice, sliceOptions);
+ LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions);
+ assert(result.succeeded() && "expected a backward slice");
+ (void)result;
}
// Then, add the backward slice from padOp itself.
- getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
+ LogicalResult result =
+ getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
+ assert(result.succeeded() && "expected a backward slice");
+ (void)result;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 75dbe0becf80d..1046f5798ecd4 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -290,8 +290,11 @@ static void getPipelineStages(
});
options.inclusive = true;
for (Operation &op : forOp.getBody()->getOperations()) {
- if (stage0Ops.contains(&op))
- getBackwardSlice(&op, &dependencies, options);
+ if (stage0Ops.contains(&op)) {
+ LogicalResult result = getBackwardSlice(&op, &dependencies, options);
+ assert(result.succeeded() && "expected a backward slice");
+ (void)result;
+ }
}
for (Operation &op : forOp.getBody()->getOperations()) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 719e2c6fa459e..9e3d3f8b10a13 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1772,7 +1772,9 @@ checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
};
llvm::SetVector<Operation *> slice;
for (auto operand : consumerOp->getOperands()) {
- getBackwardSlice(operand, &slice, options);
+ LogicalResult result = getBackwardSlice(operand, &slice, options);
+ assert(result.succeeded() && "expected a backward slice");
+ (void)result;
}
if (!slice.empty()) {
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4985d718c1780..c136ff92255cd 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1094,7 +1094,10 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
};
llvm::SetVector<Operation *> slice;
- getBackwardSlice(op, &slice, options);
+ // An incomplete slice could miss dependencies that also need to move, so
+ // give up instead of moving only part of them.
+ if (failed(getBackwardSlice(op, &slice, options)))
+ return failure();
// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
@@ -1159,7 +1162,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
- getBackwardSlice(value, &slice, options);
+ // An incomplete slice could miss definitions that also need to move, so
+ // give up instead of moving only part of them.
+ if (failed(getBackwardSlice(value, &slice, options)))
+ return failure();
}
// If the slice contains `insertionPoint` cannot move the dependencies.
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index f26058f30ad7b..145acd99e6616 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -154,7 +154,10 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
SetVector<Operation *> backwardSlice;
- getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
+ LogicalResult result =
+ getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
+ assert(result.succeeded() && "expected a backward slice");
+ (void)result;
outs << "\nmatched: " << *m.getMatchedOperation()
<< " backward static slice: ";
for (auto *op : backwardSlice)
diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp
index e99d5976d6d9d..ad99be2b9d0c9 100644
--- a/mlir/test/lib/IR/TestSlicing.cpp
+++ b/mlir/test/lib/IR/TestSlicing.cpp
@@ -41,7 +41,9 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
options.omitBlockArguments = omitBlockArguments;
// TODO: Make this default.
options.omitUsesFromAbove = false;
- getBackwardSlice(op, &slice, options);
+ LogicalResult result = getBackwardSlice(op, &slice, options);
+ assert(result.succeeded() && "expected a backward slice");
+ (void)result;
for (Operation *slicedOp : slice)
builder.clone(*slicedOp, mapper);
builder.create<func::ReturnOp>(loc);
More information about the Mlir-commits
mailing list