[Mlir-commits] [mlir] [Linalg] Add pattern to push down extract slice through linalg generic op (PR #154162)
Nirvedh Meshram
llvmlistbot at llvm.org
Mon Aug 25 14:07:10 PDT 2025
https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/154162
>From 4eebe2174cc773b213a2f512b7405e14174c4714 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Fri, 8 Aug 2025 14:44:54 -0700
Subject: [PATCH 1/3] [Linalg] Add pattern to push down extract slice through
generic
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 5 +
.../Transforms/DataLayoutPropagation.cpp | 272 ++++++++++++++++++
.../Linalg/data-layout-propagation.mlir | 110 +++++++
.../Linalg/TestDataLayoutPropagation.cpp | 2 +
4 files changed, 389 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d5306dca43e3..680fdffa9e587 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
+/// Patterns to sink a tensor.extract_slice op across other operations.
+void populateExtractSliceSinkingPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation);
+
/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
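
For illustration, here is a minimal sketch of how a client could opt into the
new sinking patterns. The wrapper function name and the always-true control
callback are illustrative only; it mirrors the test pass change at the end of
this patch.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sink tensor.extract_slice ops through linalg.generic ops nested under `root`.
static mlir::LogicalResult sinkExtractSlices(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  // Allow propagation for every candidate operand; a real client can filter here.
  mlir::linalg::populateExtractSliceSinkingPatterns(
      patterns, [](mlir::OpOperand *opOperand) { return true; });
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}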
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..d50ab8cf03714 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
ControlPropagationFn controlFn;
};
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+ OpFoldResult offset;
+ OpFoldResult sliceSize;
+ OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+ OpOperand *sliceOperand = nullptr;
+ unsigned operandIndex;
+ for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp)
+ continue;
+ sliceOperand = operand;
+ operandIndex = idx;
+ break;
+ }
+ if (!sliceOperand) {
+ return failure();
+ }
+ return std::make_tuple(sliceOperand, operandIndex);
+}
+
+// Return a map of dims that have non full slices on them so that other operands
+// can use this information. Also return a bool mentioning if a reduction dim
+// has a non full slice as that can be used to fold the original extract slice.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+ tensor::ExtractSliceOp producerSliceOp) {
+ llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+ bool hasNonZeroReductionDimSlice = false;
+ SmallVector<utils::IteratorType> iterators =
+ genericOp.getIteratorTypesArray();
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+ SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+ producerSliceOp.getSourceType().getShape(),
+ [&](int64_t sz) -> OpFoldResult {
+ return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+ });
+
+ for (auto [idx, expr] : llvm::enumerate(
+ genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+ if (isConstantIntValue(offsets[idx], 0) &&
+ isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+ continue;
+ }
+ if (!isa<AffineDimExpr>(expr)) {
+ return failure();
+ }
+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+ nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+ if (iterators[dimPos] == utils::IteratorType::reduction) {
+ hasNonZeroReductionDimSlice = true;
+ }
+ }
+ // Next check if the dims with non zero slice info are used as non
+ // AffineDimExpr and if they are then bail-out.
+ for (OpOperand &operand : genericOp->getOpOperands()) {
+ if (operand == *sliceOperand) {
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+ if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+ if (isa<AffineDimExpr>(expr)) {
+ return false;
+ }
+ WalkResult status = expr.walk([&](AffineExpr expr) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (status.wasInterrupted()) {
+ return true;
+ }
+ return false;
+ })) {
+ return failure();
+ }
+ }
+ return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp,
+ ControlPropagationFn controlFn) {
+ if (genericOp.getNumResults() != 1)
+ return failure();
+ if (hasGatherSemantics(genericOp))
+ return failure();
+ // Collect the unPacked operand, if present.
+ auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+ if (failed(maybeSliceOperandAndIndex))
+ return failure();
+ OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+ unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+ if (!controlFn(sliceOperand))
+ return failure();
+
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid UnPackOp");
+
+ if (producerSliceOp.getSource().getType().getRank() !=
+ producerSliceOp.getResult().getType().getRank()) {
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+ if (!areAllConstantIntValue(strides, 1))
+ return failure();
+
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+ // Check if we can support the propagation of this extract_slice through
+ // the generic op and, if so, return the dimensions that are only partially
+ // sliced.
+
+ auto maybeNonZeroSliceDimMap =
+ getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+
+ if (failed(maybeNonZeroSliceDimMap)) {
+ return failure();
+ }
+
+ auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+ bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+ // Build a helper that folds (v1 - v2) as an affine.apply; it is used below
+ // to compute the low/high padding amounts for each operand.
+ Location loc = genericOp->getLoc();
+ AffineExpr dim0, dim1;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+ auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+ {v1, v2});
+ };
+
+ MLIRContext *ctx = genericOp.getContext();
+ SmallVector<Value> paddedInputs;
+ for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+ paddedInputs.push_back(producerSliceOp.getSource());
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+ SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+ SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+ operandLowPads[idx] = sliceDimInfo.offset;
+ operandHighPads[idx] =
+ sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ }
+ }
+ auto paddingValue = ub::PoisonOp::create(
+ rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+ auto paddedOperand = tensor::PadOp::create(
+ rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+ paddingValue, /*nofold=*/false);
+ paddedInputs.push_back(paddedOperand);
+ }
+ AffineMap outputIndexingMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+ auto outputShapeType =
+ llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+ SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+ outputShapeType.getShape(),
+ [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+ SmallVector<OpFoldResult> newSizes = OutputShape;
+ SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 1));
+ for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+ SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+ outputLowPads[idx] = sliceDimInfo.offset;
+ outputHighPads[idx] =
+ sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ OutputShape[idx] = sliceDimInfo.outputSize;
+ newSizes[idx] = sliceDimInfo.sliceSize;
+ }
+ }
+ Value newPadOutput;
+ auto outputElType =
+ getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+ if (isGenericOutsNotUsed(genericOp)) {
+ newPadOutput =
+ tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+
+ } else {
+
+ auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+ newPadOutput = tensor::PadOp::create(
+ rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+ outputHighPads, paddingValue, /*nofold=*/false);
+ }
+
+ auto newGenericOp = linalg::GenericOp::create(
+ rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+ genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+ newGenericOp.getRegion().begin());
+
+ auto extractOp = tensor::ExtractSliceOp::create(
+ rewriter, loc,
+ newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+ outputLowPads, newSizes, newStrides);
+ Value extractRes = extractOp.getResult();
+
+ return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+ : public OpRewritePattern<GenericOp> {
+public:
+ PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+ ControlPropagationFn fun)
+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ auto genericAndRepl =
+ pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+ if (failed(genericAndRepl))
+ return failure();
+ rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+ return success();
+ }
+
+private:
+ ControlPropagationFn controlFn;
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation) {
+ patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+ patterns.getContext(), controlPackUnPackPropagation);
+}
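
The per-dimension bookkeeping above boils down to: low pad = slice offset,
high pad = source size - offset - slice size, so padding the sliced operand
(with a poison value) recreates the full iteration domain and the slice is
re-applied to the result of the new generic. A tiny standalone sketch of just
that arithmetic, with made-up numbers:

#include <cassert>
#include <cstdint>

// Mirrors SliceDimInfo for one dimension of the sliced operand.
struct SliceDimInfo {
  int64_t offset;     // where the slice starts in the source dim
  int64_t sliceSize;  // extent of the slice
  int64_t outputSize; // full extent of the source dim
};

int main() {
  // Hypothetical: a 64-element slice starting at offset 3 of a 128-wide dim.
  SliceDimInfo d{/*offset=*/3, /*sliceSize=*/64, /*outputSize=*/128};
  int64_t lowPad = d.offset;                               // 3
  int64_t highPad = d.outputSize - d.offset - d.sliceSize; // 128 - 3 - 64 = 61
  // Padding restores the original extent, so the generic can run on the full
  // domain and an extract_slice with the same offset/size recovers the result.
  assert(lowPad + d.sliceSize + highPad == d.outputSize);
  return 0;
}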
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..723eecb52351b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG1]]
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+module {
+ func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x5x128xbf16>
+ return %0 : tensor<?x5x128xbf16>
+ }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON:.+]] = ub.poison : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %arg1
+// CHECK: tensor.yield %[[POISON]] : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK: return %[[EXTRACT]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x5x128xbf16>
+ return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<128x?x128xbf16>
+ return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK: %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: tensor.yield %[[POISON_F32]] : f32
+// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK: tensor.yield %[[POISON_BF16]] : bf16
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PADDED]]
+// CHECK-SAME: outs(%[[PADDED1]]
+// CHECK: %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK: return %[[EXTRACT1]]
+
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..2cf25d8fc8c19 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6 +34,8 @@ struct TestDataLayoutPropagationPass
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
+ linalg::populateExtractSliceSinkingPatterns(
+ patterns, [](OpOperand *opOperand) { return true; });
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
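
The test pass above opts in every candidate with an always-true callback. For
finer control, any predicate over the candidate operand works; the criterion
below (only fully static slice shapes) is just an example, not part of this
patch:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Only sink extract_slice ops whose result shape is fully static.
static mlir::linalg::ControlPropagationFn onlyStaticSlices =
    [](mlir::OpOperand *opOperand) {
      auto sliceOp = opOperand->get()
                         .getDefiningOp<mlir::tensor::ExtractSliceOp>();
      return sliceOp && sliceOp.getResult().getType().hasStaticShape();
    };
// Then: linalg::populateExtractSliceSinkingPatterns(patterns, onlyStaticSlices);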
>From 1493d56583ee5f5149a4157561486966f74faeaa Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 25 Aug 2025 15:25:01 -0500
Subject: [PATCH 2/3] address reviewer comments
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
.../Transforms/DataLayoutPropagation.cpp | 127 +++++++++---------
1 file changed, 65 insertions(+), 62 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index d50ab8cf03714..40085a2368009 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1247,61 +1247,55 @@ struct SliceDimInfo {
/// Return the first input extract slice operand, if present, for the current
/// generic op.
-static FailureOr<std::tuple<OpOperand *, unsigned>>
-getSliceOperandAndIndex(GenericOp genericOp) {
+static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
OpOperand *sliceOperand = nullptr;
- unsigned operandIndex;
- for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ for (auto operand : genericOp.getDpsInputOperands()) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
sliceOperand = operand;
- operandIndex = idx;
break;
}
if (!sliceOperand) {
return failure();
}
- return std::make_tuple(sliceOperand, operandIndex);
+ return sliceOperand;
}
-// Return a map of dims that have non full slices on them so that other operands
+// Return a map of dims that have partial slices on them so that other operands
// can use this information. Also return a bool mentioning if a reduction dim
// has a non full slice as that can be used to fold the original extract slice.
-static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
-getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
- tensor::ExtractSliceOp producerSliceOp) {
- llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
- bool hasNonZeroReductionDimSlice = false;
- SmallVector<utils::IteratorType> iterators =
- genericOp.getIteratorTypesArray();
+static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
+getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
+ llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> shape = llvm::map_to_vector(
- producerSliceOp.getSourceType().getShape(),
- [&](int64_t sz) -> OpFoldResult {
- return getAsIndexOpFoldResult(genericOp.getContext(), sz);
- });
+ SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
+ genericOp.getContext(), producerSliceOp.getSourceType().getShape());
for (auto [idx, expr] : llvm::enumerate(
genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+ // If we have a full slice in a dimension, then we don't need to add it to
+ // the partial slice map.
if (isConstantIntValue(offsets[idx], 0) &&
isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
continue;
}
+ // We only support partial slices of AffineDimExprs, so bail out if that's
+ // not the case.
if (!isa<AffineDimExpr>(expr)) {
return failure();
}
SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
- nonZeroSliceDimMap[dimPos] = sliceDimInfo;
- if (iterators[dimPos] == utils::IteratorType::reduction) {
- hasNonZeroReductionDimSlice = true;
- }
+ partialSliceDimMap[dimPos] = sliceDimInfo;
}
- // Next check if the dims with non zero slice info are used as non
- // AffineDimExpr and if they are then bail-out.
+ // Next, check whether the dims with partial slice info are used in
+ // non-AffineDimExpr expressions in other operands; if they are, bail out.
for (OpOperand &operand : genericOp->getOpOperands()) {
if (operand == *sliceOperand) {
continue;
@@ -1313,7 +1307,7 @@ getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
}
WalkResult status = expr.walk([&](AffineExpr expr) {
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
- if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+ if (partialSliceDimMap.contains(dimExpr.getPosition())) {
return WalkResult::interrupt();
}
}
@@ -1327,7 +1321,7 @@ getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
return failure();
}
}
- return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+ return partialSliceDimMap;
}
static FailureOr<std::tuple<GenericOp, Value>>
@@ -1335,47 +1329,57 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
GenericOp genericOp,
ControlPropagationFn controlFn) {
if (genericOp.getNumResults() != 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation through multi-result generic is unsupported.");
if (hasGatherSemantics(genericOp))
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation through generic with gather semantics is unsupported.");
+ // Collect the sliced operand, if present.
+ auto maybeSliceOperand = getSliceOperand(genericOp);
+ if (failed(maybeSliceOperand))
return failure();
- // Collect the unPacked operand, if present.
- auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
- if (failed(maybeSliceOperandAndIndex))
- return failure();
- OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
- unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+ OpOperand *sliceOperand = *maybeSliceOperand;
+ unsigned OperandIndex = sliceOperand->getOperandNumber();
if (!controlFn(sliceOperand))
return failure();
tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
- assert(producerSliceOp && "expect a valid UnPackOp");
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
if (producerSliceOp.getSource().getType().getRank() !=
producerSliceOp.getResult().getType().getRank()) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation of rank-reducing extract slice is unsupported.");
}
SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
if (!areAllConstantIntValue(strides, 1))
- return failure();
-
- SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation of strided extract slice is unsupported.");
// Check if we can support the propagation of this extract_slice through
// the generic op and, if so, return the dimensions that are only partially
// sliced.
- auto maybeNonZeroSliceDimMap =
- getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+ auto maybePartialSliceDimMap =
+ getPartialSliceDimInfo(genericOp, sliceOperand);
- if (failed(maybeNonZeroSliceDimMap)) {
+ if (failed(maybePartialSliceDimMap)) {
return failure();
}
- auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
- bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+ auto partialSliceDimMap = *maybePartialSliceDimMap;
+
+ SmallVector<utils::IteratorType> iterators =
+ genericOp.getIteratorTypesArray();
+ bool hasPartialReductionDimSlice =
+ llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
+ int64_t sliceDim = slice.first;
+ return iterators[sliceDim] == utils::IteratorType::reduction;
+ });
// Build a helper that folds (v1 - v2) as an affine.apply; it is used below
// to compute the low/high padding amounts for each operand.
Location loc = genericOp->getLoc();
@@ -1390,7 +1394,7 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
MLIRContext *ctx = genericOp.getContext();
SmallVector<Value> paddedInputs;
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
- if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+ if (idx == OperandIndex && !hasPartialReductionDimSlice) {
paddedInputs.push_back(producerSliceOp.getSource());
continue;
}
@@ -1404,13 +1408,14 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
- if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
- SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
- operandLowPads[idx] = sliceDimInfo.offset;
- operandHighPads[idx] =
- sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
- sliceDimInfo.sliceSize);
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
}
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ operandLowPads[idx] = sliceDimInfo.offset;
+ operandHighPads[idx] =
+ sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
}
auto paddingValue = ub::PoisonOp::create(
rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
@@ -1439,15 +1444,15 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
- if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
- SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
- outputLowPads[idx] = sliceDimInfo.offset;
- outputHighPads[idx] =
- sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
- sliceDimInfo.sliceSize);
- OutputShape[idx] = sliceDimInfo.outputSize;
- newSizes[idx] = sliceDimInfo.sliceSize;
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
}
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ outputLowPads[idx] = sliceDimInfo.offset;
+ outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ OutputShape[idx] = sliceDimInfo.outputSize;
+ newSizes[idx] = sliceDimInfo.sliceSize;
}
Value newPadOutput;
auto outputElType =
@@ -1455,9 +1460,7 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
if (isGenericOutsNotUsed(genericOp)) {
newPadOutput =
tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
-
} else {
-
auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
newPadOutput = tensor::PadOp::create(
rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
>From f08b03cc96077b7c6a7e3a3d20dab4d1bf158f91 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 25 Aug 2025 16:06:58 -0500
Subject: [PATCH 3/3] add shape types for pads
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
mlir/test/Dialect/Linalg/data-layout-propagation.mlir | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 723eecb52351b..0e42027644797 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1470,6 +1470,7 @@ module {
// CHECK: %[[POISON:.+]] = ub.poison : f32
// CHECK: %[[PADDED:.+]] = tensor.pad %arg1
// CHECK: tensor.yield %[[POISON]] : f32
+// CHECK: } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
@@ -1531,9 +1532,11 @@ func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<12
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
// CHECK: tensor.yield %[[POISON_F32]] : f32
+// CHECK: } : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
// CHECK: tensor.yield %[[POISON_BF16]] : bf16
+// CHECK: } : tensor<?xbf16> to tensor<?xbf16>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PADDED]]
// CHECK-SAME: outs(%[[PADDED1]]