[Mlir-commits] [mlir] 6e59282 - [MLIR] Add pattern to bubble up tensor.extract_slice (#126898)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Mar  3 10:20:53 PST 2025
    
    
  
Author: ofri frishman
Date: 2025-03-03T18:20:50Z
New Revision: 6e59282235b2ba7b5bbae968cafb15bab9656cff
URL: https://github.com/llvm/llvm-project/commit/6e59282235b2ba7b5bbae968cafb15bab9656cff
DIFF: https://github.com/llvm/llvm-project/commit/6e59282235b2ba7b5bbae968cafb15bab9656cff.diff
LOG: [MLIR] Add pattern to bubble up tensor.extract_slice (#126898)
Add a pattern that bubbles up tensor.extract_slice through
tensor.expand_shape, and add a transform op to the tensor dialect
to directly use this pattern.
When added as a cleanup pattern of the tile-and-fuse utility, this
pattern enables tiling and fusing op chains that contain
tensor.expand_shape.
Without this pattern, that would not be possible, as
tensor.expand_shape does not implement the tiling interface.
In addition, this pattern is registered as a cleanup pattern for
transform.structured.fuse.
The pattern was first implemented in the IREE project by
Quinn Dawkins and is being upstreamed.
---------
Co-authored-by: Quinn Dawkins <quinn.dawkins at gmail.com>
Added: 
    mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
Modified: 
    mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index fcb10f55d556d..9f6387d985f77 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -101,6 +101,17 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyBubbleUpExtractSlicePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.bubble_up_extract_slice",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.extract_slice should be bubbled up above its
+    tensor.expand_shape producer, which then operates on the slice result.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.rewrite_as_constant",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 905ab0577ccc1..18981337742eb 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Appends patterns that bubble up a tensor.extract_slice op above its
+/// producer. When used as cleanup patterns of tile-and-fuse, these patterns
+/// enable fusing the producer with the consumer even if the producer does not
+/// implement the tiling interface.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that drop redundant tensor.insert_slice
 /// rank expansions.
 void populateDropRedundantInsertSliceRankExpansionPatterns(
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2f54e780093a2..ef7b4757a04b4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -592,6 +592,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
     RewritePatternSet patterns(context);
     tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
     tileAndFuseOptions.cleanupPatterns = std::move(patterns);
   }
 
diff  --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index f3560d08ff769..723731b8bed61 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -120,6 +120,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
 }
 
+void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
+}
+
 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 5edd7a02bc42b..ae8e3528b02e0 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -6,10 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -210,6 +214,214 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+///
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. A slice is defined as
+/// fully contiguous within a reassociation group if after flattening the
+/// reassociation group to a single 1D range, then the slice taken out of the
+/// group could be defined as a single contiguous subrange within that range.
+///
+/// Rank-reducing slices and non-unit strides are not supported.
+///
+/// Example:
+/// The transformation is possible because each reassociation group has a
+/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
+/// ```
+/// BEFORE:
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> into tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> into tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note - this pattern could be extended to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+                                                 rewriter)
+            .failed())
+      return failure();
+
+    // The tensor.extract_slice before applying the pattern works on the result
+    // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+    // referring to the state before applying the pattern are named with the
+    // prefix "expanded", and ones referring to the state after applying the
+    // pattern are named with the prefix "collapsed".
+    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> expandedShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    // Helpers for accumulating the (possibly dynamic) size values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1;
+    bindDims(rewriter.getContext(), d0, d1);
+    // Multiply two index values, folding the result when possible.
+    AffineMap mulMap = AffineMap::get(2, 0, {d0 * d1});
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank of
+    // ReassociationIndices.size(). In the loop, a single offset, size, and
+    // stride value is computed per reassociation group.
+    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
+        collapsedStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      // collapsedSize will hold the size of the single dim that represents the
+      // reassociation group in the non-expanded tensor.
+      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
+      // The reassocGroupSizes and reassocGroupOffsets are used to create an
+      // affine.linearize_index op to linearize the single offset value required
+      // for this reassociation group.
+      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+      for (int64_t expandedDim : indices) {
+        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+        // from the expanded state, but the collapsed size requires calculation
+        // as it did not previously exist.
+        reassocGroupSizes.push_back(expandedShape[expandedDim]);
+        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+      }
+
+      SmallVector<Value> offsetVals =
+          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+          });
+      OpFoldResult collapsedOffset =
+          rewriter
+              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
+                                                      reassocGroupSizes,
+                                                      /*disjoint=*/true)
+              .getResult();
+      collapsedOffsets.push_back(collapsedOffset);
+      collapsedSizes.push_back(collapsedSize);
+
+      // Only unit strides are supported (see the precondition check).
+      collapsedStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    // The shape of the result can be obtained from the sizes passed in.
+    SmallVector<Value> dynDims;
+    SmallVector<int64_t> shape;
+    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
+    RankedTensorType resultType = RankedTensorType::get(
+        shape, expandShapeOp.getResultType().getElementType());
+
+    // Create a new ExtractSliceOp and ExpandShapeOp.
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
+        collapsedStrides);
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        sliceOp, resultType, newSliceOp,
+        expandShapeOp.getReassociationIndices(), expandedSizes);
+    return success();
+  }
+
+  // Helper function to check if all the required conditions for the
+  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
+  // met.
+  LogicalResult
+  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
+                                           tensor::ExpandShapeOp expandShapeOp,
+                                           PatternRewriter &rewriter) const {
+
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    auto isZeroOffsetAndFullSize = [](OpFoldResult offset,
+                                      OpFoldResult sliceSize,
+                                      OpFoldResult size) {
+      if (!isConstantIntValue(offset, 0))
+        return false;
+      FailureOr<bool> maybeEqual =
+          ValueBoundsConstraintSet::areEqual(sliceSize, size);
+      return llvm::succeeded(maybeEqual) && maybeEqual.value();
+    };
+
+    // Check that the slice is contiguous within each reassociation group.
+    // The slice is contiguous only if, after the first dimension where a
+    // non-unit slice is taken, the slice size on all subsequent dimensions of
+    // the group is equal to the entire size of the dimension.
+    // Examples of contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+    // Examples of non-contiguous slices:
+    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with non-unit extracted
+      // size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
+      // Verify that all subsequent dimensions extract the full size of the
+      // expanded tensor, starting at offset zero.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "Not a contiguous slice of the expanded tensor.");
+        }
+      }
+    }
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +439,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+}
diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 20019424e8d3c..9bcc125ce1ba9 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -278,3 +278,163 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:     scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+func.func @bubble_up_extract_slice_through_expand_shape(%0: tensor<60xf32>) -> tensor<2x3x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %empty = tensor.empty() : tensor<2x3x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32>
+  return %exp : tensor<2x3x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape_full_inner_dim
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]]{{.*}} by (3, 4, 10)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+func.func @bubble_up_extract_slice_through_expand_shape_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_expand_shape_non_contiguous
+//     CHECK: tensor.expand_shape
+//     CHECK: scf.for
+//     CHECK:   scf.for
+//     CHECK:     scf.for
+//     CHECK:       linalg.exp
+func.func @no_bubble_up_extract_slice_through_expand_shape_non_contiguous(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims
+//     CHECK: %[[C0:.+]] = arith.constant 0 : index
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:     scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       scf.for %[[W:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:         %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10)
+//     CHECK:         %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[Z]], %[[W]]] by (7, 8)
+//     CHECK:         %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32>
+//     CHECK:         %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
+//     CHECK:         linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims(%0: tensor<120x56xf32>) -> tensor<3x4x10x7x8xf32> {
+    %expand = tensor.expand_shape %0 [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32>
+    %empty = tensor.empty() : tensor<3x4x10x7x8xf32>
+    %exp = linalg.exp ins(%expand : tensor<3x4x10x7x8xf32>) outs(%empty : tensor<3x4x10x7x8xf32>) -> tensor<3x4x10x7x8xf32>
+    return %exp : tensor<3x4x10x7x8xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape_and_fuse_with_expand_producer
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:    %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], {{.*}} by (8, 32)
+//     CHECK:    %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[0, 0, %[[LINEAR_IDX]]] [1, 1800, 32] [1, 1, 1] : tensor<1x1800x256xf32> to tensor<1x1800x32xf32>
+//     CHECK:    %[[ABS:.+]] = linalg.abs ins(%[[SLICE]]
+//     CHECK:    %[[EXPAND:.+]] = tensor.expand_shape %[[ABS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1800, 1, 32]
+//     CHECK:    linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @bubble_up_extract_slice_through_expand_shape_and_fuse_with_expand_producer(%0: tensor<1x1800x256xf32>) -> tensor<1x1800x8x32xf32> {
+    %empty1 = tensor.empty() : tensor<1x1800x256xf32>
+    %exp1 = linalg.abs ins(%0 : tensor<1x1800x256xf32>) outs(%empty1 : tensor<1x1800x256xf32>) -> tensor<1x1800x256xf32>
+    %expand = tensor.expand_shape %exp1 [[0], [1], [2, 3]] output_shape [1, 1800, 8, 32] : tensor<1x1800x256xf32> into tensor<1x1800x8x32xf32>
+    %empty2 = tensor.empty() : tensor<1x1800x8x32xf32>
+    %exp2 = linalg.exp ins(%expand : tensor<1x1800x8x32xf32>) outs(%empty2 : tensor<1x1800x8x32xf32>) -> tensor<1x1800x8x32xf32>
+    return %exp2 : tensor<1x1800x8x32xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_expand_shape_on_cleanup_false
+//     CHECK: %[[EXPAND:.+]] = tensor.expand_shape {{.*}} {{\[\[}}0, 1, 2]] output_shape [2, 3, 10]
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:     scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]]{{.*}} [1, 1, 5] [1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
+//     CHECK:       linalg.exp ins(%[[SLICE]]
+func.func @no_bubble_up_extract_slice_through_expand_shape_on_cleanup_false(%0: tensor<60xf32>) -> tensor<2x3x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %empty = tensor.empty() : tensor<2x3x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32>
+  return %exp : tensor<2x3x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = false :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff  --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
new file mode 100644
index 0000000000000..252e7494bff79
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -0,0 +1,124 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_expand_shape(
+// CHECK-SAME:                %[[SRC:.*]]: tensor<60xf32>) -> tensor<1x1x5xf32> {
+// CHECK:           %[[C5:.+]] = arith.constant 5 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C5]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5] : tensor<5xf32> into tensor<1x1x5xf32>
+// CHECK:           return %[[EXPAND]] : tensor<1x1x5xf32>
+
+func.func @bubble_up_extract_slice_through_expand_shape(%src: tensor<60xf32>) -> tensor<1x1x5xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
+  return %extract : tensor<1x1x5xf32>
+}
+
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_on_non_contiguous(
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK:           return %[[EXTRACT]]
+
+func.func @no_bubble_up_extract_slice_on_non_contiguous(%src: tensor<60xf32>) -> tensor<1x2x5xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5][1, 2, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x2x5xf32>
+  return %extract : tensor<1x2x5xf32>
+}
+
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_on_stride(
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK:           return %[[EXTRACT]]
+
+func.func @no_bubble_up_extract_slice_on_stride(%src: tensor<60xf32>) -> tensor<1x1x5xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 2] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
+  return %extract : tensor<1x1x5xf32>
+}
+
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_on_rank_reducing(
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK:           return %[[EXTRACT]]
+
+func.func @no_bubble_up_extract_slice_on_rank_reducing(%src: tensor<60xf32>) -> tensor<1x5xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x5xf32>
+  return %extract : tensor<1x5xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims(
+// CHECK-SAME:                  %[[SRC:.*]]: tensor<120x56xf32>) -> tensor<1x2x10x1x4xf32> {
+// CHECK:           %[[C0:.+]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C0]], %[[C0]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4] : tensor<20x4xf32> into tensor<1x2x10x1x4xf32>
+// CHECK:           return %[[EXPAND]] : tensor<1x2x10x1x4xf32>
+
+func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims(%src: tensor<120x56xf32>) -> tensor<1x2x10x1x4xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 0, 0, 0][1, 2, 10, 1, 4][1, 1, 1, 1, 1] : tensor<3x4x10x7x8xf32> to tensor<1x2x10x1x4xf32>
+  return %extract : tensor<1x2x10x1x4xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_with_trailing_full_dims(
+// CHECK-SAME:                %[[SRC:.*]]: tensor<60xf32>) -> tensor<2x5x2xf32> {
+// CHECK:           %[[C0:.+]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C0]]] [20] [1] : tensor<60xf32> to tensor<20xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape [2, 5, 2] : tensor<20xf32> into tensor<2x5x2xf32>
+// CHECK:           return %[[EXPAND]] : tensor<2x5x2xf32>
+func.func @bubble_up_extract_slice_with_trailing_full_dims(%src: tensor<60xf32>) -> tensor<2x5x2xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [6, 5, 2] : tensor<60xf32> into tensor<6x5x2xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 0][2, 5, 2][1, 1, 1] : tensor<6x5x2xf32> to tensor<2x5x2xf32>
+  return %extract : tensor<2x5x2xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_dont_fold_linearize_index(
+// CHECK-SAME:                 %[[SRC:.*]]: tensor<60xf32>,
+// CHECK-SAME:                 %[[OFFSET_0:.*]]: index,
+// CHECK-SAME:                 %[[OFFSET_1:.*]]: index) -> tensor<1x1x5xf32> {
+// CHECK:           %[[C5:.+]] = arith.constant 5 : index
+// CHECK:           %[[LINEARIZE:.*]] = affine.linearize_index disjoint {{\[}}%[[OFFSET_0]], %[[OFFSET_1]], %[[C5]]] by (2, 3, 10) : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[LINEARIZE]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5] : tensor<5xf32> into tensor<1x1x5xf32>
+// CHECK:           return %[[EXPAND]] : tensor<1x1x5xf32>
+func.func @bubble_up_extract_slice_dont_fold_linearize_index(%src: tensor<60xf32>, %offset_0 : index, %offset_1 : index) -> tensor<1x1x5xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[%offset_0, %offset_1, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
+  return %extract : tensor<1x1x5xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_not_all_dims_expanded(
+// CHECK-SAME:                %[[SRC:.*]]: tensor<60x12xf32>) -> tensor<1x1x5x12xf32> {
+// CHECK-DAG:       %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[C5]], %[[C0]]] [5, 12] [1, 1] : tensor<60x12xf32> to tensor<5x12xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2], [3]] output_shape [1, 1, 5, 12] : tensor<5x12xf32> into tensor<1x1x5x12xf32>
+// CHECK:           return %[[EXPAND]] : tensor<1x1x5x12xf32>
+func.func @bubble_up_extract_slice_not_all_dims_expanded(%src: tensor<60x12xf32>) -> tensor<1x1x5x12xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2], [3]] output_shape [2, 3, 10, 12] : tensor<60x12xf32> into tensor<2x3x10x12xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5, 0][1, 1, 5, 12][1, 1, 1, 1] : tensor<2x3x10x12xf32> to tensor<1x1x5x12xf32>
+  return %extract : tensor<1x1x5x12xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_affine_apply_not_folded(
+// CHECK-SAME:                   %[[SRC:.*]]: tensor<60xf32>,
+// CHECK-SAME:                   %[[SLICE_SIZE:.*]]: index) -> tensor<?x5x2xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[AFFINE_APPLY:.*]] = affine.apply #map(){{\[}}%[[SLICE_SIZE]]]
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[C0]]] {{\[}}%[[AFFINE_APPLY]]] [1] : tensor<60xf32> to tensor<?xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape {{\[}}%[[SLICE_SIZE]], 5, 2] : tensor<?xf32> into tensor<?x5x2xf32>
+// CHECK:           return %[[EXPAND]] : tensor<?x5x2xf32>
+func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>, %slice_size : index) -> tensor<?x5x2xf32> {
+  %expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [6, 5, 2] : tensor<60xf32> into tensor<6x5x2xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 0][%slice_size, 5, 2][1, 1, 1] : tensor<6x5x2xf32> to tensor<?x5x2xf32>
+  return %extract : tensor<?x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.bubble_up_extract_slice
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
        
    
    
More information about the Mlir-commits
mailing list