[Mlir-commits] [mlir] de6f750 - [Linalg] Add pattern to push down extract slice through linalg generic op (#154162)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 27 12:47:51 PDT 2025


Author: Nirvedh Meshram
Date: 2025-08-27T14:47:47-05:00
New Revision: de6f7504c83d8a61e78372ea4da92414f95efa8f

URL: https://github.com/llvm/llvm-project/commit/de6f7504c83d8a61e78372ea4da92414f95efa8f
DIFF: https://github.com/llvm/llvm-project/commit/de6f7504c83d8a61e78372ea4da92414f95efa8f.diff

LOG: [Linalg] Add pattern to push down extract slice through linalg generic op (#154162)

This PR adds a data layout propagation pattern that pushes an extract slice
down through a generic op. The pattern is registered through a separate
populate function because a user may want the other propagation patterns
without this one, e.g. when extract slice is treated as a special op for
tiling.

---------

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir
    mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d7b892cb69ef..64d3a2448b409 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation);
 
+/// Patterns to sink tensor.extract_slice past uses the control function accepts.
+void populateExtractSliceSinkingPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..40085a2368009 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// Per-dimension extract_slice info: offset, slice size and source dim size.
+struct SliceDimInfo {
+  OpFoldResult offset;     // Slice offset along the dim.
+  OpFoldResult sliceSize;  // Slice size along the dim.
+  OpFoldResult outputSize; // Size of the dim in the slice source.
+};
+
+/// Return the first DPS input operand of `genericOp` that is defined by a
+/// tensor.extract_slice op, or failure if no input has such a producer.
+static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  for (auto operand : genericOp.getDpsInputOperands()) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return sliceOperand;
+}
+
+// Return a map from iteration dim to slice info for the dims that the producer
+// of `sliceOperand` slices only partially. Fails if a partially sliced result
+// is not an AffineDimExpr, or if such a dim feeds a compound expr elsewhere.
+static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
+getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
+      genericOp.getContext(), producerSliceOp.getSourceType().getShape());
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    // If we have a full slice in a dimension then we don't need to add it to
+    // the partial slice map.
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    // We only support partial slices of AffineDimExprs so bail-out if that's
+    // not the case.
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    partialSliceDimMap[dimPos] = sliceDimInfo;
+  }
+  // Next check if the dims with partial slice info are used in non
+  // AffineDimExpr in other operands and if they are then bail-out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand) {
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (partialSliceDimMap.contains(dimExpr.getPosition())) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return partialSliceDimMap;
+}
+/// Rewrite generic(extract_slice(x), ...) as extract_slice(generic(x, pad(...))).
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return rewriter.notifyMatchFailure(
+        genericOp, "propagation through multi-result generic is unsupported.");
+  if (hasGatherSemantics(genericOp))
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "propagation through generic with gather semantics is unsupported.");
+  // Collect the sliced operand, if present.
+  auto maybeSliceOperand = getSliceOperand(genericOp);
+  if (failed(maybeSliceOperand))
+    return failure();
+  OpOperand *sliceOperand = *maybeSliceOperand;
+  unsigned OperandIndex = sliceOperand->getOperandNumber();
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "propagation of rank-reducing extract slice is unsupported.");
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return rewriter.notifyMatchFailure(
+        genericOp, "propagation of strided extract slice is unsupported.");
+
+  // Check if we can support the propagation of this extract slice through
+  // the generic op and, if so, collect the iteration dims that are only
+  // partially sliced.
+  auto maybePartialSliceDimMap =
+      getPartialSliceDimInfo(genericOp, sliceOperand);
+
+  if (failed(maybePartialSliceDimMap)) {
+    return failure();
+  }
+
+  auto partialSliceDimMap = *maybePartialSliceDimMap;
+
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  bool hasPartialReductionDimSlice =
+      llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
+        int64_t sliceDim = slice.first;
+        return iterators[sliceDim] == utils::IteratorType::reduction;
+      });
+
+  // Helper computing `v1 - v2` as a folded affine apply; used for high pads.
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+  // Use the slice source directly when possible; poison-pad all other inputs.
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == OperandIndex && !hasPartialReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+        continue;
+      }
+      SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+      operandLowPads[idx] = sliceDimInfo.offset;
+      operandHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = OutputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+      continue;
+    }
+    SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+    outputLowPads[idx] = sliceDimInfo.offset;
+    outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                              sliceDimInfo.sliceSize);
+    OutputShape[idx] = sliceDimInfo.outputSize;
+    newSizes[idx] = sliceDimInfo.sliceSize;
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+  // Recreate the generic on the new operands, reusing its maps and body.
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+  // Extract the originally computed region from the enlarged result.
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+/// Rewrite pattern wrapping pushDownExtractSliceOpThroughGenericOp.
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
               PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}

diff  --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..0e42027644797 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,116 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK:         %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG1]]
 // CHECK:         return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG1]]
+// CHECK:           tensor.yield %[[POISON]] : f32
+// CHECK:         } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME:    outs(%[[EMPTY]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK:         return %[[EXTRACT]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x5x128xbf16>
+  return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]          
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<128x?x128xbf16>
+  return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
+
+// -----
+
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK:         %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           tensor.yield %[[POISON_F32]] : f32
+// CHECK:         } : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:         %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK:         %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK:           tensor.yield %[[POISON_BF16]] : bf16
+// CHECK:         } : tensor<?xbf16> to tensor<?xbf16>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[PADDED]]
+// CHECK-SAME:    outs(%[[PADDED1]]
+// CHECK:         %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK:         return %[[EXTRACT1]]
+
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   

diff  --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..2cf25d8fc8c19 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6 +34,8 @@ struct TestDataLayoutPropagationPass
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
         patterns, [](OpOperand *opOperand) { return true; });
+    linalg::populateExtractSliceSinkingPatterns(
+        patterns, [](OpOperand *opOperand) { return true; });
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }


        


More information about the Mlir-commits mailing list