[Mlir-commits] [mlir] [Linalg] Add pattern to push down extract slice through linalg generic op (PR #154162)

Nirvedh Meshram llvmlistbot at llvm.org
Mon Aug 25 14:07:10 PDT 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/154162

From 4eebe2174cc773b213a2f512b7405e14174c4714 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Fri, 8 Aug 2025 14:44:54 -0700
Subject: [PATCH 1/3] [Linalg] Add pattern to push down extract slice through
 generic

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +
 .../Transforms/DataLayoutPropagation.cpp      | 272 ++++++++++++++++++
 .../Linalg/data-layout-propagation.mlir       | 110 +++++++
 .../Linalg/TestDataLayoutPropagation.cpp      |   2 +
 4 files changed, 389 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d5306dca43e3..680fdffa9e587 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation);
 
+/// Patterns to sink extract slice across other operations.
+void populateExtractSliceSinkingPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
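
For reference, wiring the new entry point into a pass takes only a few lines; a minimal sketch, assuming `using namespace mlir;`, the greedy rewrite driver, and an accept-all control function (the `sinkExtractSlices` helper is illustrative, not part of this patch):

    static LogicalResult sinkExtractSlices(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      // Sink tensor.extract_slice ops past linalg.generic consumers; the
      // control function decides per operand whether to propagate.
      linalg::populateExtractSliceSinkingPatterns(
          patterns, [](OpOperand *operand) { return true; });
      return applyPatternsGreedily(root, std::move(patterns));
    }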
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..d50ab8cf03714 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  unsigned operandIndex;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    operandIndex = idx;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return std::make_tuple(sliceOperand, operandIndex);
+}
+
+// Return a map of dims that have non-full slices on them so that other operands
+// can use this information. Also return a bool indicating whether a reduction
+// dim has a non-full slice, as that allows folding the original extract slice.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next check if the dims with non zero slice info are used as non
+  // AffineDimExpr and if they are then bail-out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand) {
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the unPacked operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid UnPackOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // Check whether we can propagate this extractSlice through the generic
+  // op and, if so, collect the dims that have partial slices.
+
+  auto maybeNonZeroSliceDimMap =
+      getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Helper to compute pad sizes: highPad = outputSize - offset - sliceSize.
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+        operandLowPads[idx] = sliceDimInfo.offset;
+        operandHighPads[idx] =
+            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                sliceDimInfo.sliceSize);
+      }
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = OutputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+      outputLowPads[idx] = sliceDimInfo.offset;
+      outputHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+      OutputShape[idx] = sliceDimInfo.outputSize;
+      newSizes[idx] = sliceDimInfo.sliceSize;
+    }
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+
+  } else {
+
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
               PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
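
The pad arithmetic above boils down to lowPad = offset and highPad = outputSize - offset - sliceSize for every partially sliced dim; a standalone sketch of the same computation (names are illustrative, not from the patch):

    struct PadSizes {
      int64_t low, high;
    };
    // Padding a dim sliced as [offset, offset + sliceSize) out of srcSize
    // elements restores the full extent: low + sliceSize + high == srcSize.
    static PadSizes computePads(int64_t offset, int64_t sliceSize,
                                int64_t srcSize) {
      return {offset, srcSize - offset - sliceSize};
    }
    // e.g. offset = 2, sliceSize = 5, srcSize = 16 gives {2, 9}: 2 + 5 + 9 == 16.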
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..723eecb52351b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK:         %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG1]]
 // CHECK:         return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %arg1
+// CHECK:           tensor.yield %[[POISON]] : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]   
+// CHECK-SAME:    outs(%[[EMPTY]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK:         return %[[EXTRACT]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x5x128xbf16>
+  return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]          
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<128x?x128xbf16>
+  return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
+
+// -----
+
+func.func @push_reduction_extract_through_generic_with_outs_used_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_reduction_extract_through_generic_with_outs_used_2
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK:         %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           tensor.yield %[[POISON_F32]] : f32
+// CHECK:         %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK:         %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK:           tensor.yield %[[POISON_BF16]] : bf16
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[PADDED]]
+// CHECK-SAME:    outs(%[[PADDED1]]
+// CHECK:         %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK:         return %[[EXTRACT1]]
+
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
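
The last test exercises the rank-reducing bail-out: once unit dims are dropped, the slice's offsets and sizes no longer pair 1:1 with the results of the consumer's indexing map. The guard amounts to the following check (an illustrative restatement of the pattern's early exit, not new API):

    // Rank-reducing tensor.extract_slice ops are not propagated.
    static bool isRankPreserving(tensor::ExtractSliceOp sliceOp) {
      return sliceOp.getSourceType().getRank() ==
             sliceOp.getResultType().getRank();
    }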
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..2cf25d8fc8c19 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6 +34,8 @@ struct TestDataLayoutPropagationPass
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
         patterns, [](OpOperand *opOperand) { return true; });
+    linalg::populateExtractSliceSinkingPatterns(
+        patterns, [](OpOperand *opOperand) { return true; });
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }
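
The ControlPropagationFn hook gives callers a per-operand veto. Beyond the accept-all callback used in the test above, a client could, for example, restrict sinking to slices fed by block arguments; a hypothetical control function, assuming `using namespace mlir;`:

    ControlPropagationFn onlyFromBlockArgs = [](OpOperand *operand) {
      auto sliceOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
      return sliceOp && isa<BlockArgument>(sliceOp.getSource());
    };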

From 1493d56583ee5f5149a4157561486966f74faeaa Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 25 Aug 2025 15:25:01 -0500
Subject: [PATCH 2/3] Address reviewer comments

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Transforms/DataLayoutPropagation.cpp      | 127 +++++++++---------
 1 file changed, 65 insertions(+), 62 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index d50ab8cf03714..40085a2368009 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1247,61 +1247,55 @@ struct SliceDimInfo {
 
 /// Return the first input extract slice operand, if present, for the current
 /// generic op.
-static FailureOr<std::tuple<OpOperand *, unsigned>>
-getSliceOperandAndIndex(GenericOp genericOp) {
+static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
   OpOperand *sliceOperand = nullptr;
-  unsigned operandIndex;
-  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+  for (auto operand : genericOp.getDpsInputOperands()) {
     auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
     if (!extractOp)
       continue;
     sliceOperand = operand;
-    operandIndex = idx;
     break;
   }
   if (!sliceOperand) {
     return failure();
   }
-  return std::make_tuple(sliceOperand, operandIndex);
+  return sliceOperand;
 }
 
-// Return a map of dims that have non-full slices on them so that other operands
+// Return a map of dims that have partial slices on them so that other operands
 // can use this information. Also return a bool indicating whether a reduction
 // dim has a non-full slice, as that allows folding the original extract slice.
-static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
-getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
-                       tensor::ExtractSliceOp producerSliceOp) {
-  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
-  bool hasNonZeroReductionDimSlice = false;
-  SmallVector<utils::IteratorType> iterators =
-      genericOp.getIteratorTypesArray();
+static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
+getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
   SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
   SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
 
-  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
-      producerSliceOp.getSourceType().getShape(),
-      [&](int64_t sz) -> OpFoldResult {
-        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
-      });
+  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
+      genericOp.getContext(), producerSliceOp.getSourceType().getShape());
 
   for (auto [idx, expr] : llvm::enumerate(
            genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    // If we have a full slice in a dimension, we don't need to add it to
+    // the partial slice map.
     if (isConstantIntValue(offsets[idx], 0) &&
         isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
       continue;
     }
+    // We only support partial slices on AffineDimExprs, so bail out if that's
+    // not the case.
     if (!isa<AffineDimExpr>(expr)) {
       return failure();
     }
     SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
     int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
-    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
-    if (iterators[dimPos] == utils::IteratorType::reduction) {
-      hasNonZeroReductionDimSlice = true;
-    }
+    partialSliceDimMap[dimPos] = sliceDimInfo;
   }
-  // Next check if the dims with non zero slice info are used as non
-  // AffineDimExpr and if they are then bail-out.
+  // Next check whether the dims with partial slice info are used in non-
+  // AffineDimExpr positions in other operands; if they are, bail out.
   for (OpOperand &operand : genericOp->getOpOperands()) {
     if (operand == *sliceOperand) {
       continue;
@@ -1313,7 +1307,7 @@ getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
           }
           WalkResult status = expr.walk([&](AffineExpr expr) {
             if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
-              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+              if (partialSliceDimMap.contains(dimExpr.getPosition())) {
                 return WalkResult::interrupt();
               }
             }
@@ -1327,7 +1321,7 @@ getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
       return failure();
     }
   }
-  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+  return partialSliceDimMap;
 }
 
 static FailureOr<std::tuple<GenericOp, Value>>
@@ -1335,47 +1329,57 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
                                        GenericOp genericOp,
                                        ControlPropagationFn controlFn) {
   if (genericOp.getNumResults() != 1)
-    return failure();
+    return rewriter.notifyMatchFailure(
+        genericOp, "propagation through multi-result generic is unsupported.");
   if (hasGatherSemantics(genericOp))
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "propagation through generic with gather semantics is unsupported.");
+  // Collect the sliced operand, if present.
+  auto maybeSliceOperand = getSliceOperand(genericOp);
+  if (failed(maybeSliceOperand))
     return failure();
-  // Collect the unPacked operand, if present.
-  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
-  if (failed(maybeSliceOperandAndIndex))
-    return failure();
-  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
-  unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+  OpOperand *sliceOperand = *maybeSliceOperand;
+  unsigned OperandIndex = sliceOperand->getOperandNumber();
 
   if (!controlFn(sliceOperand))
     return failure();
 
   tensor::ExtractSliceOp producerSliceOp =
       sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
-  assert(producerSliceOp && "expect a valid UnPackOp");
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
 
   if (producerSliceOp.getSource().getType().getRank() !=
       producerSliceOp.getResult().getType().getRank()) {
-    return failure();
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "propagation of rank-reducing extract slice is unsupported.");
   }
 
   SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
   if (!areAllConstantIntValue(strides, 1))
-    return failure();
-
-  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+    return rewriter.notifyMatchFailure(
+        genericOp, "propagation of strided extract slice is unsupported.");
 
   // Check whether we can propagate this extractSlice through the generic
   // op and, if so, collect the dims that have partial slices.
 
-  auto maybeNonZeroSliceDimMap =
-      getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+  auto maybePartialSliceDimMap =
+      getPartialSliceDimInfo(genericOp, sliceOperand);
 
-  if (failed(maybeNonZeroSliceDimMap)) {
+  if (failed(maybePartialSliceDimMap)) {
     return failure();
   }
 
-  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
-  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+  auto partialSliceDimMap = *maybePartialSliceDimMap;
+
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  bool hasPartialReductionDimSlice =
+      llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
+        int64_t sliceDim = slice.first;
+        return iterators[sliceDim] == utils::IteratorType::reduction;
+      });
 
   // Helper to compute pad sizes: highPad = outputSize - offset - sliceSize.
   Location loc = genericOp->getLoc();
@@ -1390,7 +1394,7 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
   MLIRContext *ctx = genericOp.getContext();
   SmallVector<Value> paddedInputs;
   for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
-    if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+    if (idx == OperandIndex && !hasPartialReductionDimSlice) {
       paddedInputs.push_back(producerSliceOp.getSource());
       continue;
     }
@@ -1404,13 +1408,14 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
         continue;
       }
       AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
-      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
-        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
-        operandLowPads[idx] = sliceDimInfo.offset;
-        operandHighPads[idx] =
-            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
-                sliceDimInfo.sliceSize);
+      if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+        continue;
       }
+      SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+      operandLowPads[idx] = sliceDimInfo.offset;
+      operandHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
     }
     auto paddingValue = ub::PoisonOp::create(
         rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
@@ -1439,15 +1444,15 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
       continue;
     }
     AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
-    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
-      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
-      outputLowPads[idx] = sliceDimInfo.offset;
-      outputHighPads[idx] =
-          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
-              sliceDimInfo.sliceSize);
-      OutputShape[idx] = sliceDimInfo.outputSize;
-      newSizes[idx] = sliceDimInfo.sliceSize;
+    if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+      continue;
     }
+    SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+    outputLowPads[idx] = sliceDimInfo.offset;
+    outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                              sliceDimInfo.sliceSize);
+    OutputShape[idx] = sliceDimInfo.outputSize;
+    newSizes[idx] = sliceDimInfo.sliceSize;
   }
   Value newPadOutput;
   auto outputElType =
@@ -1455,9 +1460,7 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
   if (isGenericOutsNotUsed(genericOp)) {
     newPadOutput =
         tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
-
   } else {
-
     auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
     newPadOutput = tensor::PadOp::create(
         rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,

From f08b03cc96077b7c6a7e3a3d20dab4d1bf158f91 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 25 Aug 2025 16:06:58 -0500
Subject: [PATCH 3/3] Add shape types for pads

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 mlir/test/Dialect/Linalg/data-layout-propagation.mlir | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 723eecb52351b..0e42027644797 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1470,6 +1470,7 @@ module {
 // CHECK:         %[[POISON:.+]] = ub.poison : f32
 // CHECK:         %[[PADDED:.+]] = tensor.pad %arg1
 // CHECK:           tensor.yield %[[POISON]] : f32
+// CHECK:         } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]   
@@ -1531,9 +1532,11 @@ func.func @push_reduction_extract_through_generic_with_outs_used_2(%arg0: tensor<12
 // CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
 // CHECK:         %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
 // CHECK:           tensor.yield %[[POISON_F32]] : f32
+// CHECK:         } : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK:         %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
 // CHECK:         %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
 // CHECK:           tensor.yield %[[POISON_BF16]] : bf16
+// CHECK:         } : tensor<?xbf16> to tensor<?xbf16>
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[PADDED]]
 // CHECK-SAME:    outs(%[[PADDED1]]


