[Mlir-commits] [mlir] [mlir][tensor] Add bubble up extract_slice pattern for TilingInterfac… (PR #179967)
Tomer Solomon
llvmlistbot at llvm.org
Thu Feb 5 08:02:29 PST 2026
https://github.com/recursion-man created https://github.com/llvm/llvm-project/pull/179967
Add a new pattern `populateBubbleUpExtractSliceThroughTilingInterfacePatterns`
that bubbles up `tensor.extract_slice` operations through any producer
implementing `TilingInterface`.
This pattern is more general than the existing Linalg-specific bubble up
pattern (`populateBubbleUpExtractSliceOpPatterns`) as it:
1. Works with any operation implementing TilingInterface, not just Linalg ops.
2. Supports multiple non-overlapping extract_slice consumers of the same
producer, creating separate tiled operations for each slice.
3. Prevents recomputation by verifying that consumers are strictly non-overlapping.
4. Provides an optional control function to filter which slices should be
transformed.
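The non-overlap requirement in item 3 can be illustrated with a small standalone model. The patch itself delegates this check to `ValueBoundsConstraintSet::areOverlappingSlices`, which also reasons about dynamic offsets, sizes, and strides. The sketch below is only an approximation under simplifying assumptions: all offsets and sizes are static and every stride is 1. `StaticSlice`, `slicesOverlap`, and `anyOverlap` are hypothetical names introduced for illustration and are not part of MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a unit-stride slice: one static offset and one
// static size per dimension. Not an MLIR type.
struct StaticSlice {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Two unit-stride hyperrectangles overlap only if their half-open index
// ranges [offset, offset + size) intersect in every dimension.
bool slicesOverlap(const StaticSlice &a, const StaticSlice &b) {
  assert(a.offsets.size() == a.sizes.size() && "malformed slice a");
  assert(b.offsets.size() == b.sizes.size() && "malformed slice b");
  assert(a.offsets.size() == b.offsets.size() && "rank mismatch");
  for (std::size_t d = 0; d < a.offsets.size(); ++d) {
    // An empty slice cannot overlap anything.
    if (a.sizes[d] == 0 || b.sizes[d] == 0)
      return false;
    int64_t aEnd = a.offsets[d] + a.sizes[d];
    int64_t bEnd = b.offsets[d] + b.sizes[d];
    // Disjoint ranges in a single dimension make the slices disjoint.
    if (aEnd <= b.offsets[d] || bEnd <= a.offsets[d])
      return false;
  }
  return true;
}

// Pairwise check over all consumers, mirroring the O(n^2) loop structure
// of a check over every pair of extract_slice users.
bool anyOverlap(const std::vector<StaticSlice> &slices) {
  for (std::size_t i = 0; i < slices.size(); ++i)
    for (std::size_t j = i + 1; j < slices.size(); ++j)
      if (slicesOverlap(slices[i], slices[j]))
        return true;
  return false;
}
```

With offsets `[0, 0]` and `[0, 8]` and sizes `[4, 8]` for both slices, the column ranges `[0, 8)` and `[8, 16)` are disjoint, so the model reports no overlap and the pattern would be allowed to proceed.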
The transformation reduces computation by creating smaller/tiled operations
that compute only the slices actually needed by consumers. For example:
Before:
%0 = linalg.generic ... -> tensor<16x16xf32>
%1 = tensor.extract_slice %0[0, 0] [4, 8] [1, 1] -> tensor<4x8xf32>
%2 = tensor.extract_slice %0[0, 8] [4, 8] [1, 1] -> tensor<4x8xf32>
After:
%0 = linalg.generic ... -> tensor<4x8xf32> // tiled for slice 1
%1 = linalg.generic ... -> tensor<4x8xf32> // tiled for slice 2
This is useful for scenarios where a large tensor operation is followed by
multiple slice extractions, allowing each slice to be computed independently
with reduced memory and compute requirements.
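To make the savings in the example concrete, the following standalone sketch counts output elements before and after the transformation. It assumes the producer's cost is proportional to the number of output elements, which holds for elementwise ops but not, for instance, for a matmul with a reduction dimension. `numElements` and `totalElements` are hypothetical helpers, not MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of elements in a statically shaped tensor.
int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t count = 1;
  for (int64_t dim : shape) {
    assert(dim >= 0 && "static dimensions must be non-negative");
    count *= dim;
  }
  return count;
}

// Total number of elements across several result shapes.
int64_t totalElements(const std::vector<std::vector<int64_t>> &shapes) {
  int64_t total = 0;
  for (const std::vector<int64_t> &shape : shapes)
    total += numElements(shape);
  return total;
}
```

For the example above, `numElements({16, 16})` is 256, while `totalElements({{4, 8}, {4, 8}})` is 64, so under this cost model the two tiled ops produce a quarter of the elements of the original op.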
From 8e7c8a54eeb0210ce5417ec2e8b8957454d2aeef Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomer.solomon at mobileye.com>
Date: Thu, 5 Feb 2026 17:49:33 +0200
Subject: [PATCH] [mlir][tensor] Add bubble up extract_slice pattern for
TilingInterface ops
Add a new pattern `populateBubbleUpExtractSliceThroughTilingInterfacePatterns`
that bubbles up `tensor.extract_slice` operations through any producer
implementing `TilingInterface`.

This pattern is more general than the existing Linalg-specific bubble up
pattern (`populateBubbleUpExtractSliceOpPatterns`) as it:
1. Works with any operation implementing TilingInterface, not just Linalg ops
2. Supports multiple non-overlapping extract_slice consumers of the same
   producer, creating separate tiled operations for each slice
3. Uses `ValueBoundsConstraintSet::areOverlappingSlices` to verify slices
   don't overlap before transformation
4. Provides an optional control function to filter which slices should be
   transformed

The transformation reduces computation by creating smaller/tiled operations
that compute only the slices actually needed by consumers. For example:

Before:
  %0 = linalg.generic ... -> tensor<16x16xf32>
  %1 = tensor.extract_slice %0[0, 0] [4, 8] [1, 1] -> tensor<4x8xf32>
  %2 = tensor.extract_slice %0[0, 8] [4, 8] [1, 1] -> tensor<4x8xf32>

After:
  %0 = linalg.generic ... -> tensor<4x8xf32>  // tiled for slice 1
  %1 = linalg.generic ... -> tensor<4x8xf32>  // tiled for slice 2

This is useful for scenarios where a large tensor operation is followed by
multiple slice extractions, allowing each slice to be computed independently
with reduced memory and compute requirements.
---
.../Dialect/Tensor/Transforms/Transforms.h | 17 ++
...leUpExtractSliceThroughTilingInterface.cpp | 217 ++++++++++++++++++
.../Dialect/Tensor/Transforms/CMakeLists.txt | 1 +
...xtract-slice-through-tiling-interface.mlir | 149 ++++++++++++
.../Dialect/Tensor/TestTensorTransforms.cpp | 15 ++
5 files changed, 399 insertions(+)
create mode 100644 mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp
create mode 100644 mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 3e4da94bd714e..767789ac8597c 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
@@ -68,6 +69,22 @@ void populateMergeConsecutiveInsertExtractSlicePatterns(
/// tiling interface.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+/// Appends patterns to bubble up `tensor.extract_slice` through any operation
+/// that implements the `TilingInterface`. This pattern handles multiple
+/// non-overlapping extract_slice consumers and uses
+/// `TilingInterface::generateResultTileValue` to create tiled implementations.
+///
+/// The optional `controlFn` can be used to filter which slices should be
+/// transformed. It is called for each extract_slice with the producer op and
+/// the slice op. Return success() to allow transformation, failure() to skip.
+///
+/// This is more general than the Linalg-specific bubble up pattern as it
+/// works with any TilingInterface operation, not just Linalg ops.
+void populateBubbleUpExtractSliceThroughTilingInterfacePatterns(
+ RewritePatternSet &patterns,
+ function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+ controlFn = nullptr);
+
/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
/// rank expansions.
void populateDropRedundantInsertSliceRankExpansionPatterns(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp
new file mode 100644
index 0000000000000..9643d3542596b
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSliceThroughTilingInterface.cpp
@@ -0,0 +1,217 @@
+//===- BubbleUpExtractSliceThroughTilingInterface.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to bubble up `tensor.extract_slice` operations
+// through producers that implement the `TilingInterface`. Unlike the Linalg-
+// specific bubble up pattern, this works with any operation implementing
+// TilingInterface and supports multiple non-overlapping extract_slice consumers.
+//
+// The transformation reduces computation by creating smaller/tiled operations
+// that compute only the slices actually needed by consumers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "bubble-up-extract-slice-tiling-interface"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Check if any two slices in the list overlap.
+/// Returns failure() if overlap analysis cannot be determined.
+/// Returns false if slices are guaranteed non-overlapping.
+/// Returns true if slices may overlap.
+static FailureOr<bool>
+hasOverlappingSlices(MLIRContext *ctx,
+ SmallVectorImpl<tensor::ExtractSliceOp> &slices) {
+ for (size_t i = 0; i < slices.size(); ++i) {
+ for (size_t j = i + 1; j < slices.size(); ++j) {
+ HyperrectangularSlice slice1(
+ cast<OffsetSizeAndStrideOpInterface>(slices[i].getOperation()));
+ HyperrectangularSlice slice2(
+ cast<OffsetSizeAndStrideOpInterface>(slices[j].getOperation()));
+
+ FailureOr<bool> overlapping =
+ ValueBoundsConstraintSet::areOverlappingSlices(ctx, slice1, slice2);
+ if (failed(overlapping)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Could not determine if slices overlap at indices " << i
+ << " and " << j << "\n");
+ return failure();
+ }
+ if (*overlapping) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Found overlapping slices at indices " << i << " and "
+ << j << "\n");
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+/// Pattern to bubble up extract_slice through operations implementing
+/// TilingInterface.
+///
+/// Matches: TilingInterface op whose output is consumed only by non-overlapping
+/// extract_slice ops.
+///
+/// Transforms to: tiled operations (via TilingInterface::generateResultTileValue)
+/// that each compute only a specific output slice, with extract_slice operations
+/// applied to the inputs.
+///
+/// For example:
+///
+/// Before:
+/// ```mlir
+/// %0 = "some.tiling_interface_op"(%input) : (tensor<1x9450x256xf32>)
+/// -> tensor<1x9450x256xf32>
+/// %1 = tensor.extract_slice %0[0, 0, 0] [1, 7200, 256] [1, 1, 1]
+/// : tensor<1x9450x256xf32> to tensor<1x7200x256xf32>
+/// %2 = tensor.extract_slice %0[0, 7200, 0] [1, 2250, 256] [1, 1, 1]
+/// : tensor<1x9450x256xf32> to tensor<1x2250x256xf32>
+/// ```
+///
+/// After:
+/// ```mlir
+/// %input0 = tensor.extract_slice %input[0, 0, 0] [1, 7200, 256] [1, 1, 1]
+/// : tensor<1x9450x256xf32> to tensor<1x7200x256xf32>
+/// %0 = "some.tiling_interface_op"(%input0) : (tensor<1x7200x256xf32>)
+/// -> tensor<1x7200x256xf32>
+/// %input1 = tensor.extract_slice %input[0, 7200, 0] [1, 2250, 256] [1, 1, 1]
+/// : tensor<1x9450x256xf32> to tensor<1x2250x256xf32>
+/// %1 = "some.tiling_interface_op"(%input1) : (tensor<1x2250x256xf32>)
+/// -> tensor<1x2250x256xf32>
+/// ```
+struct BubbleUpExtractSliceThroughTilingInterface
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ BubbleUpExtractSliceThroughTilingInterface(
+ MLIRContext *context,
+ function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+ controlFn = nullptr,
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ controlFn(controlFn) {}
+
+ LogicalResult matchAndRewrite(TilingInterface producerOp,
+ PatternRewriter &rewriter) const override {
+ // Only support operations with a single result for now.
+ if (producerOp->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(producerOp,
+ "expected single result operation");
+
+ OpResult output = producerOp->getResult(0);
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
+ if (!outputType)
+ return rewriter.notifyMatchFailure(producerOp,
+ "expected ranked tensor result");
+
+ LLVM_DEBUG(llvm::dbgs() << "Checking TilingInterface op: " << *producerOp
+ << "\n");
+
+ // Collect all extract_slice users.
+ SmallVector<tensor::ExtractSliceOp> extractSlices;
+ for (Operation *user : output.getUsers()) {
+ auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(user);
+ if (!extractSlice)
+ return rewriter.notifyMatchFailure(
+ producerOp, "result has non-extract_slice consumer");
+ extractSlices.push_back(extractSlice);
+ }
+
+ // Sort slices by their position in the block to ensure deterministic
+ // processing order. This maintains the correspondence between original
+ // extract_slice ops and their replacements.
+ llvm::sort(extractSlices, [](tensor::ExtractSliceOp a,
+ tensor::ExtractSliceOp b) {
+ return a->isBeforeInBlock(b);
+ });
+
+ if (extractSlices.empty())
+ return rewriter.notifyMatchFailure(producerOp,
+ "no extract_slice consumers");
+
+ LLVM_DEBUG(llvm::dbgs() << "Found " << extractSlices.size()
+ << " extract_slice consumers\n");
+
+ // Check for overlapping slices when there are multiple consumers.
+ // Overlapping slices would cause redundant computation.
+ if (extractSlices.size() > 1) {
+ FailureOr<bool> hasOverlaps =
+ hasOverlappingSlices(rewriter.getContext(), extractSlices);
+ if (failed(hasOverlaps))
+ return rewriter.notifyMatchFailure(
+ producerOp, "could not determine slice overlaps");
+ if (*hasOverlaps)
+ return rewriter.notifyMatchFailure(
+ producerOp, "extract_slices have overlapping regions");
+
+ LLVM_DEBUG(llvm::dbgs() << "No overlapping slices detected\n");
+ }
+
+ // Apply the control function to verify each slice is suitable for tiling.
+ if (controlFn) {
+ for (tensor::ExtractSliceOp extractSlice : extractSlices) {
+ if (failed(controlFn(producerOp, extractSlice)))
+ return rewriter.notifyMatchFailure(
+ producerOp, "slice rejected by control function");
+ }
+ }
+
+ // For each extract_slice, create a tiled producer.
+ // Collect all results first before replacing to avoid iterator invalidation.
+ SmallVector<std::pair<tensor::ExtractSliceOp, Value>> replacements;
+ for (tensor::ExtractSliceOp extractSlice : extractSlices) {
+ FailureOr<TilingResult> tilingResult =
+ tensor::replaceExtractSliceWithTiledProducer(rewriter, extractSlice,
+ output);
+ if (failed(tilingResult))
+ return rewriter.notifyMatchFailure(
+ producerOp, "failed to generate tiled implementation");
+
+ if (tilingResult->tiledValues.empty())
+ return rewriter.notifyMatchFailure(
+ producerOp, "tiling produced no values");
+
+ replacements.emplace_back(extractSlice, tilingResult->tiledValues[0]);
+ }
+
+ // Replace all extract_slices with tiled values.
+ for (auto &[extractSlice, tiledValue] : replacements) {
+ rewriter.replaceOp(extractSlice, tiledValue);
+ LLVM_DEBUG(llvm::dbgs() << "Replaced extract_slice with tiled value\n");
+ }
+
+ return success();
+ }
+
+private:
+ /// Optional callback to control which slices should be transformed.
+ /// Called for each extract_slice with the producer op and the slice op.
+ /// Return success() to allow the transformation, failure() to skip.
+ function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+ controlFn;
+};
+
+} // namespace
+
+void mlir::tensor::populateBubbleUpExtractSliceThroughTilingInterfacePatterns(
+ RewritePatternSet &patterns,
+ function_ref<LogicalResult(TilingInterface, tensor::ExtractSliceOp)>
+ controlFn) {
+ patterns.add<BubbleUpExtractSliceThroughTilingInterface>(
+ patterns.getContext(), controlFn);
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 99e1c4fec8467..261c0e88de972 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTensorTransforms
+ BubbleUpExtractSliceThroughTilingInterface.cpp
BufferizableOpInterfaceImpl.cpp
ConcatOpPatterns.cpp
EmptyOpPatterns.cpp
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir
new file mode 100644
index 0000000000000..223166fb7216b
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-through-tiling-interface.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-bubble-up-extract-slice-through-tiling-interface %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func.func @bubble_single_slice(
+// CHECK-DAG: %[[SLICE0:.*]] = tensor.extract_slice %arg0[4, 8] [4, 4] [1, 1]
+// CHECK-DAG: %[[SLICE1:.*]] = tensor.extract_slice %arg1[8] [4] [1]
+// CHECK-DAG: %[[SLICE2:.*]] = tensor.extract_slice %arg0[4, 8] [4, 4] [1, 1]
+// CHECK: %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x4xf32>, tensor<4xf32>)
+// CHECK-SAME: outs(%[[SLICE2]] : tensor<4x4xf32>)
+// CHECK: return %[[GENERIC]]
+func.func @bubble_single_slice(%arg0: tensor<16x16xf32>, %arg1: tensor<16xf32>) -> tensor<4x4xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+ outs(%arg0 : tensor<16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.addf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<16x16xf32>
+ %1 = tensor.extract_slice %0[4, 8] [4, 4] [1, 1] : tensor<16x16xf32> to tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func.func @bubble_multiple_non_overlapping_slices(
+// CHECK: %[[S0:.*]] = tensor.extract_slice %arg0[0, 0] [4, 8] [1, 1]
+// CHECK: %[[S1:.*]] = tensor.extract_slice %arg1[0] [8] [1]
+// CHECK: %[[S2:.*]] = tensor.extract_slice %arg0[0, 0] [4, 8] [1, 1]
+// CHECK: %[[GENERIC1:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[S0]], %[[S1]] : tensor<4x8xf32>, tensor<8xf32>)
+// CHECK-SAME: outs(%[[S2]] : tensor<4x8xf32>)
+// CHECK: %[[S3:.*]] = tensor.extract_slice %arg0[0, 8] [4, 8] [1, 1]
+// CHECK: %[[S4:.*]] = tensor.extract_slice %arg1[8] [8] [1]
+// CHECK: %[[S5:.*]] = tensor.extract_slice %arg0[0, 8] [4, 8] [1, 1]
+// CHECK: %[[GENERIC2:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[S3]], %[[S4]] : tensor<4x8xf32>, tensor<8xf32>)
+// CHECK-SAME: outs(%[[S5]] : tensor<4x8xf32>)
+// CHECK: return %[[GENERIC1]], %[[GENERIC2]]
+func.func @bubble_multiple_non_overlapping_slices(%arg0: tensor<16x16xf32>, %arg1: tensor<16xf32>)
+ -> (tensor<4x8xf32>, tensor<4x8xf32>) {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+ outs(%arg0 : tensor<16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.addf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<16x16xf32>
+ %1 = tensor.extract_slice %0[0, 0] [4, 8] [1, 1] : tensor<16x16xf32> to tensor<4x8xf32>
+ %2 = tensor.extract_slice %0[0, 8] [4, 8] [1, 1] : tensor<16x16xf32> to tensor<4x8xf32>
+ return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_through_matmul(
+// CHECK-DAG: %[[LHS_SLICE:.*]] = tensor.extract_slice %arg0[2, 0] [4, 8] [1, 1]
+// CHECK-DAG: %[[RHS_SLICE:.*]] = tensor.extract_slice %arg1[0, 2] [8, 4] [1, 1]
+// CHECK-DAG: %[[DST_SLICE:.*]] = tensor.extract_slice %arg2[2, 2] [4, 4] [1, 1]
+// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS_SLICE]], %[[RHS_SLICE]] :
+// CHECK: return %[[MATMUL]]
+func.func @bubble_through_matmul(%lhs: tensor<8x8xf32>, %rhs: tensor<8x8xf32>,
+ %dst: tensor<8x8xf32>) -> tensor<4x4xf32> {
+ %0 = linalg.matmul ins(%lhs, %rhs : tensor<8x8xf32>, tensor<8x8xf32>)
+ outs(%dst : tensor<8x8xf32>) -> tensor<8x8xf32>
+ %1 = tensor.extract_slice %0[2, 2] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_through_fill(
+// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[EMPTY:.*]] = tensor.empty()
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[EMPTY]][4, 4] [4, 4] [1, 1]
+// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SLICE]] : tensor<4x4xf32>)
+// CHECK: return %[[FILL]]
+func.func @bubble_through_fill() -> tensor<4x4xf32> {
+ %cst = arith.constant 1.0 : f32
+ %empty = tensor.empty() : tensor<16x16xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %slice = tensor.extract_slice %fill[4, 4] [4, 4] [1, 1] : tensor<16x16xf32> to tensor<4x4xf32>
+ return %slice : tensor<4x4xf32>
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: func.func @bubble_with_dynamic_dims(
+// CHECK: %[[SLICE0:.*]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+// CHECK: %[[SLICE1:.*]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1]
+// CHECK: %[[SLICE2:.*]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+// CHECK: %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>)
+// CHECK-SAME: outs(%[[SLICE2]] : tensor<?x?xf32>)
+// CHECK: return %[[GENERIC]]
+func.func @bubble_with_dynamic_dims(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>,
+ %off0: index, %off1: index,
+ %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.addf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[%off0, %off1] [%sz0, %sz1] [1, 1]
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+/// Negative test: result has non-extract_slice consumer.
+
+// CHECK-LABEL: func.func @no_bubble_non_slice_consumer(
+// CHECK: %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME: ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+// CHECK-SAME: outs(%arg0 : tensor<16x16xf32>)
+// CHECK: return %[[GENERIC]] : tensor<16x16xf32>
+func.func @no_bubble_non_slice_consumer(%arg0: tensor<16x16xf32>,
+ %arg1: tensor<16xf32>) -> tensor<16x16xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<16x16xf32>, tensor<16xf32>)
+ outs(%arg0 : tensor<16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.addf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<16x16xf32>
+ return %0 : tensor<16x16xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 687473ebe6d60..9c2145aa7188c 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -93,6 +93,12 @@ struct TestTensorTransforms
*this, "test-tracking-listener",
llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
llvm::cl::init(false)};
+
+ Option<bool> testBubbleUpExtractSliceThroughTilingInterface{
+ *this, "test-bubble-up-extract-slice-through-tiling-interface",
+ llvm::cl::desc("Test bubbling up tensor.extract_slice through operations "
+ "implementing TilingInterface"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -143,6 +149,13 @@ static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}
+static void
+applyBubbleUpExtractSliceThroughTilingInterfacePatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateBubbleUpExtractSliceThroughTilingInterfacePatterns(patterns);
+ (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -397,6 +410,8 @@ void TestTensorTransforms::runOnOperation() {
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();
+ if (testBubbleUpExtractSliceThroughTilingInterface)
+ applyBubbleUpExtractSliceThroughTilingInterfacePatterns(rootOp);
}
namespace mlir {