[Mlir-commits] [mlir] [Linalg] Add pattern to push down extract slice through linalg generic op (PR #154162)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 20 07:03:59 PDT 2025
================
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
ControlPropagationFn controlFn;
};
+/// Describes how a `tensor.extract_slice` restricts a single dimension of its
+/// source tensor. Strides are not recorded.
+struct SliceDimInfo {
+  /// Offset of the slice along this dimension.
+  OpFoldResult offset;
+  /// Size of the extracted slice along this dimension.
+  OpFoldResult sliceSize;
+  /// Full (un-sliced) size of the source tensor along this dimension.
+  OpFoldResult outputSize;
+};
+
+/// Return the first DPS input operand of `genericOp` whose value is produced by
+/// a `tensor.extract_slice`, together with its position among the DPS input
+/// operands (not its overall operand number). Fails if no DPS input is defined
+/// by an extract_slice.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    // Early exit on the first match; the cast keeps the tuple type identical
+    // to the declared FailureOr payload (enumerate yields a size_t index).
+    if (operand->get().getDefiningOp<tensor::ExtractSliceOp>())
+      return std::make_tuple(operand, static_cast<unsigned>(idx));
+  }
+  return failure();
+}
+
+// Return a map of dims that have non full slices on them so that other operands
+// can use this information. Also return a bool mentioning if a reduction dim
+// has a non full slice as that can be used to fold the original extract slice.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+ tensor::ExtractSliceOp producerSliceOp) {
+ llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+ bool hasNonZeroReductionDimSlice = false;
+ SmallVector<utils::IteratorType> iterators =
+ genericOp.getIteratorTypesArray();
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+ SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+ producerSliceOp.getSourceType().getShape(),
+ [&](int64_t sz) -> OpFoldResult {
+ return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+ });
+
+ for (auto [idx, expr] : llvm::enumerate(
+ genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+ if (isConstantIntValue(offsets[idx], 0) &&
+ isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+ continue;
+ }
+ if (!isa<AffineDimExpr>(expr)) {
+ return failure();
+ }
+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+ nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+ if (iterators[dimPos] == utils::IteratorType::reduction) {
+ hasNonZeroReductionDimSlice = true;
+ }
+ }
+ // Next check if the dims with non zero slice info are used as non
+ // AffineDimExpr and if they are then bail-out.
+ for (OpOperand &operand : genericOp->getOpOperands()) {
+ if (operand == *sliceOperand) {
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+ if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+ if (isa<AffineDimExpr>(expr)) {
+ return false;
+ }
+ WalkResult status = expr.walk([&](AffineExpr expr) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+ return WalkResult::interrupt();
----------------
Max191 wrote:
nit: Use `nonZeroSliceDimMap.contains(dimExpr.getPosition())`?
https://github.com/llvm/llvm-project/pull/154162
More information about the Mlir-commits
mailing list