[Mlir-commits] [mlir] [mlir][tensor] Add bubble up extract_slice pattern for TilingInterfac… (PR #179967)

Tomer Solomon llvmlistbot at llvm.org
Thu Feb 5 08:02:29 PST 2026


https://github.com/recursion-man created https://github.com/llvm/llvm-project/pull/179967

Add a new pattern `populateBubbleUpExtractSliceThroughTilingInterfacePatterns`
that bubbles up `tensor.extract_slice` operations through any producer
implementing `TilingInterface`.

This pattern is more general than the existing Linalg-specific bubble up
pattern (`populateBubbleUpExtractSliceOpPatterns`) as it:

1. Works with any operation implementing TilingInterface, not just Linalg ops.
2. Supports multiple non-overlapping extract_slice consumers of the same
   producer, creating separate tiled operations for each slice.
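3. Prevents redundant computation by verifying that the extract_slice consumers
   are strictly non-overlapping; if this cannot be proven, the pattern does not
   apply.
4. Provides an optional control function that is queried for every
   extract_slice consumer; if it rejects any of them, the producer is left
   unchanged.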
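The match preconditions behind items 2-4 can be modeled outside MLIR. This is a simplified sketch, not the patch's code: it assumes static offsets and sizes with unit strides, whereas the patch delegates the overlap query to `ValueBoundsConstraintSet::areOverlappingSlices`, which can also reason about dynamic values. `StaticSlice`, `slicesOverlap`, and `shouldBubbleUp` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical model of a unit-stride tensor.extract_slice with static
// offsets and sizes, one entry per dimension. Sizes are assumed positive.
struct StaticSlice {
  std::vector<std::int64_t> offsets;
  std::vector<std::int64_t> sizes;
};

// Two unit-stride slices of the same tensor overlap iff their half-open
// intervals [offset, offset + size) intersect in every dimension.
bool slicesOverlap(const StaticSlice &a, const StaticSlice &b) {
  for (std::size_t d = 0; d < a.offsets.size(); ++d) {
    std::int64_t aEnd = a.offsets[d] + a.sizes[d];
    std::int64_t bEnd = b.offsets[d] + b.sizes[d];
    if (aEnd <= b.offsets[d] || bEnd <= a.offsets[d])
      return false; // Disjoint along dimension d, hence disjoint overall.
  }
  return true;
}

// Mirrors the order of checks in the pattern: at least one slice, every pair
// disjoint, and the optional control function accepting every slice. A single
// rejection means nothing is rewritten.
bool shouldBubbleUp(const std::vector<StaticSlice> &slices,
                    const std::function<bool(const StaticSlice &)> &controlFn) {
  if (slices.empty())
    return false;
  for (std::size_t i = 0; i < slices.size(); ++i)
    for (std::size_t j = i + 1; j < slices.size(); ++j)
      if (slicesOverlap(slices[i], slices[j]))
        return false;
  if (controlFn)
    for (const StaticSlice &slice : slices)
      if (!controlFn(slice))
        return false;
  return true;
}
```

The pairwise loop is quadratic in the number of consumers, which matches the structure of `hasOverlappingSlices` in the patch; for the handful of slices typically taken from one producer this is unlikely to matter.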

The transformation reduces computation by creating smaller/tiled operations
that compute only the slices actually needed by consumers. For example:

Before:
  %0 = linalg.generic ... -> tensor<16x16xf32>
  %1 = tensor.extract_slice %0[0, 0] [4, 8] [1, 1] : tensor<16x16xf32> to tensor<4x8xf32>
  %2 = tensor.extract_slice %0[0, 8] [4, 8] [1, 1] : tensor<16x16xf32> to tensor<4x8xf32>

After:
  %0 = linalg.generic ... -> tensor<4x8xf32>  // tiled for slice 1
  %1 = linalg.generic ... -> tensor<4x8xf32>  // tiled for slice 2
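For an elementwise `linalg.generic` like the one in this PR's tests (indexing maps `(d0, d1) -> (d0, d1)` for the 2-D input and output, `(d0, d1) -> (d1)` for the 1-D input), the input slices each tiled op reads follow from projecting the requested result slice through the indexing maps. The sketch below models only that special case (all-parallel iterators, identity output map, projected-permutation input maps); in the patch the real work is done by the producer's `TilingInterface::generateResultTileValue`. `IndexingMap`, `Tile`, and `operandTile` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical encoding of a projected-permutation indexing map: entry k is
// the loop dimension that operand dimension k reads. For example,
// (d0, d1) -> (d0, d1) is {0, 1} and (d0, d1) -> (d1) is {1}.
using IndexingMap = std::vector<std::size_t>;

// A unit-stride slice: one offset and one size per dimension.
struct Tile {
  std::vector<std::int64_t> offsets;
  std::vector<std::int64_t> sizes;
};

// Assumes an all-parallel op whose output indexing map is the identity, so the
// requested result slice is exactly the iteration-space tile. Each operand
// slice is then the projection of that tile through the operand's map.
Tile operandTile(const Tile &resultTile, const IndexingMap &map) {
  Tile tile;
  for (std::size_t loopDim : map) {
    tile.offsets.push_back(resultTile.offsets[loopDim]);
    tile.sizes.push_back(resultTile.sizes[loopDim]);
  }
  return tile;
}
```

For the second slice of the example, `%0[0, 8] [4, 8]`, this yields `[0, 8] [4, 8]` for the 2-D input (and the `outs` operand) and `[8] [8]` for the 1-D input, which is what the CHECK lines of `@bubble_multiple_non_overlapping_slices` expect.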

This is useful when a large tensor operation is followed by several slice
extractions: each slice is computed independently, and the full-size result
no longer has to be materialized, which reduces memory and compute requirements.
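Counting only the output elements each version computes in the example above (and ignoring operand reads), the savings work out as:

$$\underbrace{16 \times 16}_{\text{before}} = 256, \qquad \underbrace{2 \times (4 \times 8)}_{\text{after}} = 64, \qquad \frac{256}{64} = 4$$

That is, the two tiled ops do a quarter of the original work, because the two slices cover only 64 of the 256 result elements.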

From 8e7c8a54eeb0210ce5417ec2e8b8957454d2aeef Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomer.solomon at mobileye.com>
Date: Thu, 5 Feb 2026 17:49:33 +0200
Subject: [PATCH] [mlir][tensor] Add bubble up extract_slice pattern for
 TilingInterface ops

Add a new pattern `populateBubbleUpExtractSliceThroughTilingInterfacePatterns`
that bubbles up `tensor.extract_slice` operations through any producer
implementing `TilingInterface`.

This pattern is more general than the existing Linalg-specific bubble up
pattern (`populateBubbleUpExtractSliceOpPatterns`) as it:

1. Works with any operation implementing TilingInterface, not just Linalg ops
2. Supports multiple non-overlapping extract_slice consumers of the same
   producer, creating separate tiled operations for each slice
3. Uses `ValueBoundsConstraintSet::areOverlappingSlices` to verify slices
   don't overlap before transformation
4. Provides an optional control function to filter which slices should be
   transformed

The transformation reduces computation by creating smaller/tiled operations
that compute only the slices actually needed by consumers. For example:

Before:
  %0 = linalg.generic ... -> tensor<16x16xf32>
  %1 = tensor.extract_slice %0[0, 0] [4, 8] [1, 1] -> tensor<4x8xf32>
  %2 = tensor.extract_slice %0[0, 8] [4, 8] [1, 1] -> tensor<4x8xf32>

After:
  %0 = linalg.generic ... -> tensor<4x8xf32>  // tiled for slice 1
  %1 = linalg.generic ... -> tensor<4x8xf32>  // tiled for slice 2

This is useful for scenarios where a large tensor operation is followed by
multiple slice extractions, allowing each slice to be computed independently
with reduced memory and compute requirements.
---
 .../Dialect/Tensor/Transforms/Transforms.h    |  17 ++
 ...leUpExtractSliceThroughTilingInterface.cpp | 217 ++++++++++++++++++
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   1 +
 ...xtract-slice-through-tiling-interface.mlir | 149 ++++++++++++
 .../Dialect/Tensor/TestTensorTransforms.cpp   |  15 ++
 5 files changed, 399 insertions(+)
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp
 create mode 100644 mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 3e4da94bd714e..767789ac8597c 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
@@ -68,6 +69,22 @@ void populateMergeConsecutiveInsertExtractSlicePatterns(
 /// tiling interface.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns to bubble up `tensor.extract_slice` through any operation
+/// that implements the `TilingInterface`. This pattern handles multiple
+/// non-overlapping extract_slice consumers and uses
+/// `TilingInterface::generateResultTileValue` to create tiled implementations.
+///
+/// The optional `controlFn` is called for each extract_slice with the producer
+/// op and the slice op. If it returns failure() for any slice, the producer is
+/// not rewritten. `controlFn` is non-owning and must outlive the patterns.
+///
+/// This is more general than the Linalg-specific bubble up pattern as it
+/// works with any TilingInterface operation, not just Linalg ops.
+void populateBubbleUpExtractSliceThroughTilingInterfacePatterns(
+    RewritePatternSet &patterns,
+    function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+        controlFn = nullptr);
+
 /// Populates `patterns` with patterns that drop redundant tensor.insert_slice
 /// rank expansions.
 void populateDropRedundantInsertSliceRankExpansionPatterns(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp
new file mode 100644
index 0000000000000..9643d3542596b
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp
@@ -0,0 +1,217 @@
+//===- BubbleUpExtractSliceThroughTilingInterface.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to bubble up `tensor.extract_slice` operations
+// through producers that implement the `TilingInterface`. Unlike the Linalg-
+// specific bubble up pattern, this works with any operation implementing
+// TilingInterface and supports multiple non-overlapping extract_slice consumers.
+//
+// The transformation reduces computation by creating smaller/tiled operations
+// that compute only the slices actually needed by consumers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "bubble-up-extract-slice-tiling-interface"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Check if any two slices in the list overlap.
+/// Returns failure() if overlap cannot be determined for some pair.
+/// Returns false if the slices are guaranteed to be pairwise disjoint.
+/// Returns true if at least one pair of slices is known to overlap.
+static FailureOr<bool>
+hasOverlappingSlices(MLIRContext *ctx,
+                     SmallVectorImpl<tensor::ExtractSliceOp> &slices) {
+  for (size_t i = 0; i < slices.size(); ++i) {
+    for (size_t j = i + 1; j < slices.size(); ++j) {
+      HyperrectangularSlice slice1(
+          cast<OffsetSizeAndStrideOpInterface>(slices[i].getOperation()));
+      HyperrectangularSlice slice2(
+          cast<OffsetSizeAndStrideOpInterface>(slices[j].getOperation()));
+
+      FailureOr<bool> overlapping =
+          ValueBoundsConstraintSet::areOverlappingSlices(ctx, slice1, slice2);
+      if (failed(overlapping)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Could not determine if slices overlap at indices " << i
+                   << " and " << j << "\n");
+        return failure();
+      }
+      if (*overlapping) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Found overlapping slices at indices " << i << " and "
+                   << j << "\n");
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+/// Pattern to bubble up extract_slice through operations implementing
+/// TilingInterface.
+///
+/// Matches: TilingInterface op whose output is consumed only by non-overlapping
+/// extract_slice ops.
+///
+/// Transforms to: tiled operations (via TilingInterface::generateResultTileValue)
+/// that each compute only a specific output slice, with extract_slice operations
+/// applied to the inputs.
+///
+/// For example:
+///
+/// Before:
+/// ```mlir
+/// %0 = "some.tiling_interface_op"(%input) : (tensor<1x9450x256xf32>)
+///                                           -> tensor<1x9450x256xf32>
+/// %1 = tensor.extract_slice %0[0, 0, 0] [1, 7200, 256] [1, 1, 1]
+///      : tensor<1x9450x256xf32> to tensor<1x7200x256xf32>
+/// %2 = tensor.extract_slice %0[0, 7200, 0] [1, 2250, 256] [1, 1, 1]
+///      : tensor<1x9450x256xf32> to tensor<1x2250x256xf32>
+/// ```
+///
+/// After:
+/// ```mlir
+/// %input0 = tensor.extract_slice %input[0, 0, 0] [1, 7200, 256] [1, 1, 1]
+///           : tensor<1x9450x256xf32> to tensor<1x7200x256xf32>
+/// %0 = "some.tiling_interface_op"(%input0) : (tensor<1x7200x256xf32>)
+///                                            -> tensor<1x7200x256xf32>
+/// %input1 = tensor.extract_slice %input[0, 7200, 0] [1, 2250, 256] [1, 1, 1]
+///           : tensor<1x9450x256xf32> to tensor<1x2250x256xf32>
+/// %1 = "some.tiling_interface_op"(%input1) : (tensor<1x2250x256xf32>)
+///                                            -> tensor<1x2250x256xf32>
+/// ```
+struct BubbleUpExtractSliceThroughTilingInterface
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  BubbleUpExtractSliceThroughTilingInterface(
+      MLIRContext *context,
+      function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+          controlFn = nullptr,
+      PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        controlFn(controlFn) {}
+
+  LogicalResult matchAndRewrite(TilingInterface producerOp,
+                                PatternRewriter &rewriter) const override {
+    // Only support operations with a single result for now.
+    if (producerOp->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(producerOp,
+                                         "expected single result operation");
+
+    OpResult output = producerOp->getResult(0);
+    auto outputType = dyn_cast<RankedTensorType>(output.getType());
+    if (!outputType)
+      return rewriter.notifyMatchFailure(producerOp,
+                                         "expected ranked tensor result");
+
+    LLVM_DEBUG(llvm::dbgs() << "Checking TilingInterface op: " << *producerOp
+                            << "\n");
+
+    // Collect all extract_slice users.
+    SmallVector<tensor::ExtractSliceOp> extractSlices;
+    for (Operation *user : output.getUsers()) {
+      auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(user);
+      if (!extractSlice || extractSlice->getBlock() != producerOp->getBlock())
+        return rewriter.notifyMatchFailure(
+            producerOp, "expected extract_slice consumers in the same block");
+      extractSlices.push_back(extractSlice);
+    }
+
+    // Sort slices by their position in the block so that the tiled ops are
+    // created in a deterministic order; use-list order does not necessarily
+    // follow program order.
+    llvm::sort(extractSlices, [](tensor::ExtractSliceOp a,
+                                 tensor::ExtractSliceOp b) {
+      return a->isBeforeInBlock(b);
+    });
+
+    if (extractSlices.empty())
+      return rewriter.notifyMatchFailure(producerOp,
+                                         "no extract_slice consumers");
+
+    LLVM_DEBUG(llvm::dbgs() << "Found " << extractSlices.size()
+                            << " extract_slice consumers\n");
+
+    // Check for overlapping slices when there are multiple consumers.
+    // Overlapping slices would cause redundant computation.
+    if (extractSlices.size() > 1) {
+      FailureOr<bool> hasOverlaps =
+          hasOverlappingSlices(rewriter.getContext(), extractSlices);
+      if (failed(hasOverlaps))
+        return rewriter.notifyMatchFailure(
+            producerOp, "could not determine slice overlaps");
+      if (*hasOverlaps)
+        return rewriter.notifyMatchFailure(
+            producerOp, "extract_slices have overlapping regions");
+
+      LLVM_DEBUG(llvm::dbgs() << "No overlapping slices detected\n");
+    }
+
+    // Apply the control function to verify each slice is suitable for tiling.
+    if (controlFn) {
+      for (tensor::ExtractSliceOp extractSlice : extractSlices) {
+        if (failed(controlFn(producerOp, extractSlice)))
+          return rewriter.notifyMatchFailure(
+              producerOp, "slice rejected by control function");
+      }
+    }
+
+    // For each extract_slice, create a tiled producer. All tiled values are
+    // generated before any extract_slice is replaced.
+    SmallVector<std::pair<tensor::ExtractSliceOp, Value>> replacements;
+    for (tensor::ExtractSliceOp extractSlice : extractSlices) {
+      rewriter.setInsertionPoint(extractSlice);
+      auto tilingResult = tensor::replaceExtractSliceWithTiledProducer(
+          rewriter, extractSlice, output);
+      if (failed(tilingResult))
+        return rewriter.notifyMatchFailure(
+            producerOp, "failed to generate tiled implementation");
+
+      if (tilingResult->tiledValues.empty())
+        return rewriter.notifyMatchFailure(
+            producerOp, "tiling produced no values");
+
+      replacements.emplace_back(extractSlice, tilingResult->tiledValues[0]);
+    }
+
+    // Replace all extract_slices with tiled values.
+    for (auto &[extractSlice, tiledValue] : replacements) {
+      rewriter.replaceOp(extractSlice, tiledValue);
+      LLVM_DEBUG(llvm::dbgs() << "Replaced extract_slice with tiled value\n");
+    }
+
+    return success();
+  }
+
+private:
+  /// Optional callback to control whether the producer is rewritten. Called for
+  /// each extract_slice with the producer op and the slice op. If it returns
+  /// failure() for any slice, the producer and all of its slices are skipped.
+  function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+      controlFn;
+};
+
+} // namespace
+
+void mlir::tensor::populateBubbleUpExtractSliceThroughTilingInterfacePatterns(
+    RewritePatternSet &patterns,
+    function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+        controlFn) {
+  patterns.add<BubbleUpExtractSliceThroughTilingInterface>(
+      patterns.getContext(), controlFn);
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 99e1c4fec8467..261c0e88de972 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRTensorTransforms
+  BubbleUpExtractSliceThroughTilingInterface.cpp
   BufferizableOpInterfaceImpl.cpp
   ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir
new file mode 100644
index 0000000000000..223166fb7216b
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-bubble-up-extract-slice-through-tiling-interface %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func.func @bubble_single_slice(
+//   CHECK-DAG:   %[[SLICE0:.*]] = tensor.extract_slice %arg0[4, 8] [4, 4] [1, 1]
+//   CHECK-DAG:   %[[SLICE1:.*]] = tensor.extract_slice %arg1[8] [4] [1]
+//   CHECK-DAG:   %[[SLICE2:.*]] = tensor.extract_slice %arg0[4, 8] [4, 4] [1, 1]
+//       CHECK:   %[[GENERIC:.*]] = linalg.generic
+//  CHECK-SAME:       ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x4xf32>, tensor<4xf32>)
+//  CHECK-SAME:       outs(%[[SLICE2]] : tensor<4x4xf32>)
+//       CHECK:   return %[[GENERIC]]
+func.func @bubble_single_slice(%arg0: tensor<16x16xf32>, %arg1: tensor<16xf32>) -> tensor<4x4xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [#map, #map1, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+      outs(%arg0 : tensor<16x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.addf %in, %in_0 : f32
+      linalg.yield %1 : f32
+  } -> tensor<16x16xf32>
+  %1 = tensor.extract_slice %0[4, 8] [4, 4] [1, 1] : tensor<16x16xf32> to tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func.func @bubble_multiple_non_overlapping_slices(
+//       CHECK:   %[[S0:.*]] = tensor.extract_slice %arg0[0, 0] [4, 8] [1, 1]
+//       CHECK:   %[[S1:.*]] = tensor.extract_slice %arg1[0] [8] [1]
+//       CHECK:   %[[S2:.*]] = tensor.extract_slice %arg0[0, 0] [4, 8] [1, 1]
+//       CHECK:   %[[GENERIC1:.*]] = linalg.generic
+//  CHECK-SAME:       ins(%[[S0]], %[[S1]] : tensor<4x8xf32>, tensor<8xf32>)
+//  CHECK-SAME:       outs(%[[S2]] : tensor<4x8xf32>)
+//       CHECK:   %[[S3:.*]] = tensor.extract_slice %arg0[0, 8] [4, 8] [1, 1]
+//       CHECK:   %[[S4:.*]] = tensor.extract_slice %arg1[8] [8] [1]
+//       CHECK:   %[[S5:.*]] = tensor.extract_slice %arg0[0, 8] [4, 8] [1, 1]
+//       CHECK:   %[[GENERIC2:.*]] = linalg.generic
+//  CHECK-SAME:       ins(%[[S3]], %[[S4]] : tensor<4x8xf32>, tensor<8xf32>)
+//  CHECK-SAME:       outs(%[[S5]] : tensor<4x8xf32>)
+//       CHECK:   return %[[GENERIC1]], %[[GENERIC2]]
+func.func @bubble_multiple_non_overlapping_slices(%arg0: tensor<16x16xf32>, %arg1: tensor<16xf32>)
+    -> (tensor<4x8xf32>, tensor<4x8xf32>) {
+  %0 = linalg.generic {
+      indexing_maps = [#map, #map1, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+      outs(%arg0 : tensor<16x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.addf %in, %in_0 : f32
+      linalg.yield %1 : f32
+  } -> tensor<16x16xf32>
+  %1 = tensor.extract_slice %0[0, 0] [4, 8] [1, 1] : tensor<16x16xf32> to tensor<4x8xf32>
+  %2 = tensor.extract_slice %0[0, 8] [4, 8] [1, 1] : tensor<16x16xf32> to tensor<4x8xf32>
+  return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_through_matmul(
+//   CHECK-DAG:   %[[LHS_SLICE:.*]] = tensor.extract_slice %arg0[2, 0] [4, 8] [1, 1]
+//   CHECK-DAG:   %[[RHS_SLICE:.*]] = tensor.extract_slice %arg1[0, 2] [8, 4] [1, 1]
+//   CHECK-DAG:   %[[DST_SLICE:.*]] = tensor.extract_slice %arg2[2, 2] [4, 4] [1, 1]
+//       CHECK:   %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS_SLICE]], %[[RHS_SLICE]] :
+//       CHECK:   return %[[MATMUL]]
+func.func @bubble_through_matmul(%lhs: tensor<8x8xf32>, %rhs: tensor<8x8xf32>,
+                                 %dst: tensor<8x8xf32>) -> tensor<4x4xf32> {
+  %0 = linalg.matmul ins(%lhs, %rhs : tensor<8x8xf32>, tensor<8x8xf32>)
+                     outs(%dst : tensor<8x8xf32>) -> tensor<8x8xf32>
+  %1 = tensor.extract_slice %0[2, 2] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_through_fill(
+//       CHECK:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//       CHECK:   %[[EMPTY:.*]] = tensor.empty()
+//       CHECK:   %[[SLICE:.*]] = tensor.extract_slice %[[EMPTY]][4, 4] [4, 4] [1, 1]
+//       CHECK:   %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SLICE]] : tensor<4x4xf32>)
+//       CHECK:   return %[[FILL]]
+func.func @bubble_through_fill() -> tensor<4x4xf32> {
+  %cst = arith.constant 1.0 : f32
+  %empty = tensor.empty() : tensor<16x16xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %slice = tensor.extract_slice %fill[4, 4] [4, 4] [1, 1] : tensor<16x16xf32> to tensor<4x4xf32>
+  return %slice : tensor<4x4xf32>
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func.func @bubble_with_dynamic_dims(
+//       CHECK:   %[[SLICE0:.*]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+//       CHECK:   %[[SLICE1:.*]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1]
+//       CHECK:   %[[SLICE2:.*]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+//       CHECK:   %[[GENERIC:.*]] = linalg.generic
+//  CHECK-SAME:       ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>)
+//  CHECK-SAME:       outs(%[[SLICE2]] : tensor<?x?xf32>)
+//       CHECK:   return %[[GENERIC]]
+func.func @bubble_with_dynamic_dims(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>,
+                                    %off0: index, %off1: index,
+                                    %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [#map, #map1, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.addf %in, %in_0 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[%off0, %off1] [%sz0, %sz1] [1, 1]
+      : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+/// Negative test: result has non-extract_slice consumer.
+
+// CHECK-LABEL: func.func @no_bubble_non_slice_consumer(
+//       CHECK:   %[[GENERIC:.*]] = linalg.generic
+//  CHECK-SAME:       ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+//  CHECK-SAME:       outs(%arg0 : tensor<16x16xf32>)
+//       CHECK:   return %[[GENERIC]] : tensor<16x16xf32>
+func.func @no_bubble_non_slice_consumer(%arg0: tensor<16x16xf32>,
+                                        %arg1: tensor<16xf32>) -> tensor<16x16xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [#map, #map1, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+      outs(%arg0 : tensor<16x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.addf %in, %in_0 : f32
+      linalg.yield %1 : f32
+  } -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 687473ebe6d60..9c2145aa7188c 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -93,6 +93,12 @@ struct TestTensorTransforms
       *this, "test-tracking-listener",
       llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
       llvm::cl::init(false)};
+
+  Option<bool> testBubbleUpExtractSliceThroughTilingInterface{
+      *this, "test-bubble-up-extract-slice-through-tiling-interface",
+      llvm::cl::desc("Test bubbling up tensor.extract_slice through operations "
+                     "implementing TilingInterface"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -143,6 +149,13 @@ static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
   (void)applyPatternsGreedily(rootOp, std::move(patterns));
 }
 
+static void
+applyBubbleUpExtractSliceThroughTilingInterfacePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateBubbleUpExtractSliceThroughTilingInterfacePatterns(patterns);
+  (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -397,6 +410,8 @@ void TestTensorTransforms::runOnOperation() {
   if (testTrackingListener)
     if (failed(testTrackingListenerReplacements(rootOp)))
       return signalPassFailure();
+  if (testBubbleUpExtractSliceThroughTilingInterface)
+    applyBubbleUpExtractSliceThroughTilingInterfacePatterns(rootOp);
 }
 
 namespace mlir {


