[Mlir-commits] [mlir] 5711957 - [mlir][Tensor] Add rewrites to extract slices through `tensor.collapse_shape`

Christopher Bate llvmlistbot at llvm.org
Fri Sep 2 10:29:14 PDT 2022


Author: Christopher Bate
Date: 2022-09-02T11:29:04-06:00
New Revision: 5711957875738c1318f89afd7bf4be388f85a087

URL: https://github.com/llvm/llvm-project/commit/5711957875738c1318f89afd7bf4be388f85a087
DIFF: https://github.com/llvm/llvm-project/commit/5711957875738c1318f89afd7bf4be388f85a087.diff

LOG: [mlir][Tensor] Add rewrites to extract slices through `tensor.collapse_shape`

This change adds a set of utilities to replace the result of a
`tensor.collapse_shape -> tensor.extract_slice` chain with the
equivalent result formed by aggregating slices of the
`tensor.collapse_shape` source. In general, it is not possible to
commute `extract_slice` and `collapse_shape` if linearized dimensions
are sliced. The i-th dimension of the `tensor.collapse_shape`
result is a "linearized sliced dimension" if:

1) The reassociation group of `tensor.collapse_shape` at the i-th position
   has more than one index (multiple dimensions of the input are collapsed)
2) The i-th dimension is sliced by `tensor.extract_slice`.

We can work around this by stitching together the result of
`tensor.extract_slice` by iterating over any linearized sliced dimensions.
This is equivalent to "tiling" the linearized-and-sliced dimensions of
the `tensor.collapse_shape` operation in order to manifest the result
tile (the result of the `tensor.extract_slice`). The user of the
utilities must provide the mechanism to create the tiling (e.g. a loop).
In the tests, it is demonstrated how to apply the utilities using either
`scf.for` or `scf.foreach_thread`.

The below example illustrates the pattern using `scf.for`:

```
%0 = linalg.generic ... -> tensor<3x7x11x10xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<231x10xf32>
%2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
```

We can construct %2 by generating the following IR:

```
%dest = linalg.init_tensor() : tensor<10x10xf32>
%2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
   // Step 1: Map this output idx (%iv) to a multi-index for the input (%3):
   %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv)
   %3:3 = affine.delinearize_index %linear_index into (3, 7, 11)
   // Step 2: Extract the slice from the input
   %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
         tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
   %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
         tensor<1x1x1x10xf32> into tensor<1x10xf32>
   // Step 3: Insert the slice into the destination
   %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
         tensor<1x10xf32> into tensor<10x10xf32>
   scf.yield %6 : tensor<10x10xf32>
}
```

The pattern was discussed in the RFC here: https://discourse.llvm.org/t/rfc-tensor-extracting-slices-from-tensor-collapse-shape/64034

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129699

Added: 
    mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
    mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
    mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Utils/CMakeLists.txt
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/lib/Dialect/Tensor/CMakeLists.txt
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9a0bbf690cf0..2a47f9bdbb11 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -283,7 +283,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     // Build an ExtractSliceOp with dynamic entries and inferred result type.
     OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
       "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractSliceOp with mixed static and dynamic entries packed in
+    // a Range vector.
+    OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,    
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -601,6 +605,11 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     // Build a InsertSliceOp with dynamic entries.
     OpBuilder<(ins "Value":$source, "Value":$dest,
       "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertSliceOp with mixed static and dynamic entries packed in
+    // a Range vector.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<Range>":$ranges,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 
@@ -1199,7 +1208,11 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
       "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-
+    // Build a ParallelInsertSliceOp with mixed static and dynamic entries
+    // packed into a Range vector.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<Range>":$ranges,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a ParallelInsertSliceOp with dynamic entries.
     OpBuilder<(ins "Value":$source, "Value":$dest,
       "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,

diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
new file mode 100644
index 000000000000..2ca556275af1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
@@ -0,0 +1,210 @@
+//===- TransformUtils.h - Tensor Transformation Utilities--------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
+#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tensor {
+
+//===----------------------------------------------------------------------===//
+// Extract slice from `tensor.collapse_shape`
+//===----------------------------------------------------------------------===//
+
+/// This class assists with generating IR required to materialize an
+/// arbitrary-sized slice from the result of a CollapseShapeOp. In order to
+/// accomplish this, a loop nest or similar operation must be created by the
+/// caller. The purpose of the loop nest is to generate a "tiling by 1" of all
+/// sliced dimensions. The "tiling by 1" assembles all elements of the result
+/// tile over dimensions that would have been impossible to directly slice.
+///
+/// The class provides three methods:
+/// 1. `ExtractSliceFromCollapseHelper::create`: emits IR that should
+/// appear before the loop nest and populates the internal state.
+/// 2. `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`: returns
+/// parameters used by the caller to construct the loop nest.
+/// 3. `ExtractSliceFromCollapseHelper::emitLoopNestBody`:
+/// emits IR to construct a "size-1 tile" of the desired result and returns a
+/// set of ranges where the tile should be inserted into the destination
+/// tensor.
+///
+/// ### Intended usage:
+///
+/// The caller should first call `ExtractSliceFromCollapseHelper::create` and
+/// then create a destination tensor that is the same size as the desired slice.
+/// The caller then creates a loop nest that iterates over the multi-dimensional
+/// iteration space defined by `[0, ub[0]) x [0, ub[1]) x ... x [0, ub[N-1])`
+/// where `ub` is the upper bound given by
+/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. Inside the body of
+/// the loop nest, the caller should call
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody` and provide the induction
+/// variables. This returns a sub-tile and a set of ranges that describe where
+/// this tile should be inserted into the result by the caller. For a complete
+/// example of usage, see the examples in the TestTensorTransforms pass.
+///
+/// ### Example:
+/// Consider the following IR:
+/// ```
+/// %0 = linalg.generic ... -> tensor<3x?x?x11x?xf32>
+/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]]
+///        : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
+/// %2 = tensor.extract_slice %1 [%offt0, %offt1][%size0, %size1][1, 1]
+///        : tensor<?x?xf32> to tensor<?x?xf32>
+/// ```
+///
+/// We can construct %2 by generating the following, which only uses `%0`:
+///
+/// ```
+/// %dest = linalg.init_tensor [%size0, %size1] : tensor<?x?xf32>
+/// %1 = tensor.dim %0, %c1 : tensor<3x?x?x11x?xf32>
+/// %2 = tensor.dim %0, %c2 : tensor<3x?x?x11x?xf32>
+/// %3 = tensor.dim %0, %c4 : tensor<3x?x?x11x?xf32>
+///
+/// %result = scf.for %iv0 = %c0 to %arg2 step %c1 iter_args(%arg6 = %dest) ->
+///                                                  (tensor<?x?xf32>) {
+///   %5 = scf.for %iv1 = %c0 to %arg4 step %c1 iter_args(%arg8 = %arg6)
+///                                                  -> (tensor<?x?xf32>) {
+///     %lin0 = (affine.apply) %iv0 + %offt0
+///     %lin1 = (affine.apply) %iv1 + %offt1
+///
+///     %mi0:3 = affine.delinearize_index %lin0 into (%c3, %1, %2)
+///     %mi1:2 = affine.delinearize_index %lin1 into (%c11, %3)
+///
+///     %sub_tile = tensor.extract_slice %0
+///                    [%mi0#0, %mi0#1, %mi0#2, %mi1#0, %mi1#1]
+///                    [1, 1, 1, 1, 1]
+///                    [1, 1, 1, 1, 1]
+///            : tensor<3x?x?x11x?xf32> to tensor<1x1x1x1x1xf32>
+///     %sub_tile_collapsed = tensor.collapse_shape %sub_tile
+///             [[0, 1, 2], [3, 4]]
+///            : tensor<1x1x1x1x1xf32> into tensor<1x1xf32>
+///
+///     %12 = tensor.insert_slice %sub_tile_collapsed into
+///             %arg8[%iv0, %iv1] [1, 1] [1, 1]
+///             : tensor<1x1xf32> into tensor<?x?xf32>
+///     scf.yield %12 : tensor<?x?xf32>
+///   }
+///   scf.yield %5 : tensor<?x?xf32>
+/// }
+/// ```
+///
+/// ### Explanation of example:
+///
+/// Each step above is explained below.
+///
+/// #### Step 0: Create %dest and materialization of shapes.
+/// This step is self-explanatory and performed by the caller. It can be
+/// done before or after calling `ExtractSliceFromCollapseHelper::create`,
+/// which materializes the source shape (`%1, %2, %3`).
+///
+/// #### Step 1: Create loop nest.
+///
+/// The caller creates the loop nest (depicted here is `scf.for`, but any other
+/// similar op can be used). The iteration should start at zero and proceed with
+/// step size 1 to the upper bounds given by
+/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. This forms the
+/// basis for the "tiling by 1".
+///
+/// #### Step 2: Transform (%iv0, %iv1) from the index space of %2 to the index
+/// space of %0.
+///
+/// This step is performed by
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
+///
+/// The induction variables `%iv0` and `%iv1` live in the
+/// index space of %2 (for dimensions 0 and 1, respectively). `%lin0` and
+/// `%lin1` are the result of inverting (resolving) the index space
+/// transformation represented by the slice operation, accounting for offset and
+/// stride. Subsequently, `%mi0` and `%mi1` are the result of applying the
+/// inverse index space transformation represented by `tensor.collapse_shape`.
+/// This is accomplished using `affine.delinearize_index`. Note that %iv0
+/// and %iv1 now correspond to multi-indices `%mi0:3` and `%mi1:2`.
+///
+/// #### Step 3: Extract a sub-tile slice from the source.
+///
+/// This step is also performed by
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
+///
+/// The indices `%mi0` and `%mi1` are used to extract a slice from %0.  This
+/// slice is then collapsed down to match the result rank.
+///
+/// #### Step 4: Insert sub-tile into the destination
+///
+/// This step is performed by the caller using the results of
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
+///
+/// In the above example, the slice insertion parameters are straightforward,
+/// but in other possible situations, the slice parameters are more complicated,
+/// which is why this helper calculates them for the caller. These other
+/// situations correspond to:
+/// 1. The presence of linearized dimensions that are not sliced
+/// 2. The presence of non-linearized dimensions that are sliced.
+class ExtractSliceFromCollapseHelper {
+public:
+  /// Given a CollapseShapeOp `op` and a set of ranges `sliceParams` describing
+  /// the desired slice of its result, emits IR to materialize the shapes of
+  /// the input and output tensors, and returns an instance of the initialized
+  /// class. Returns failure if the slice is rank-reducing.
+  static FailureOr<ExtractSliceFromCollapseHelper>
+  create(OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef<Range> sliceParams);
+
+  /// Given a CollapseShapeOp and an ExtractSliceOp acting on its result, emits
+  /// IR to materialize the shapes of the input and output tensors of the
+  /// CollapseShapeOp, and returns an instance of the initialized class. Returns
+  /// failure if the slice is rank-reducing or if `extractOp` does not slice
+  /// the result of `collapseOp`.
+  static FailureOr<ExtractSliceFromCollapseHelper>
+  create(OpBuilder &b, tensor::CollapseShapeOp collapseOp,
+         tensor::ExtractSliceOp extractOp);
+
+  /// Construct the helper directly from already-computed state. Every
+  /// argument is copied into the corresponding member; no IR is emitted.
+  /// Prefer the `create` factory functions, which compute these values.
+  ExtractSliceFromCollapseHelper(
+      tensor::CollapseShapeOp collapseShapeOp,
+      ArrayRef<OpFoldResult> collapseShapeInputShape,
+      ArrayRef<OpFoldResult> collapseShapeOutputShape,
+      ArrayRef<Range> extractSliceParams,
+      const llvm::SmallBitVector &linearizedDimensions,
+      const llvm::SmallBitVector &slicedDimensions, ArrayRef<Value> tiledSizes)
+      : collapseShapeOp(collapseShapeOp),
+        collapseShapeInputShape(collapseShapeInputShape),
+        collapseShapeOutputShape(collapseShapeOutputShape),
+        sliceParams(extractSliceParams),
+        linearizedDimensions(linearizedDimensions),
+        slicedDimensions(slicedDimensions), tiledSizes(tiledSizes) {}
+
+  /// Return the upper bounds of the iteration space (with 0 offset and stride
+  /// 1) required to create the desired slice. Note that this is not the same
+  /// as the `sizes` parameters of the ExtractSliceOp because not all dimensions
+  /// of the slice are required to be tiled to form the result. This accessor
+  /// does not modify the helper, so it is callable on const instances.
+  const SmallVector<Value> &getIterationSpaceSizes() const {
+    return tiledSizes;
+  }
+
+  /// Generates the IR inside of the caller's loop nest for 1) inverting the
+  /// index mappings of the ExtractSliceOp->CollapseShapeOp chain and 2)
+  /// extracting the CollapseShapeOp source tensor tile for this specified
+  /// iteration space point `tileInductionVars` and 3) calculating where to
+  /// insert the extracted tile. The returned pair consists of the results of
+  /// (2) and (3) and should be used by the caller to insert into the
+  /// destination tensor.
+  std::pair<Value, SmallVector<Range>>
+  emitLoopNestBody(OpBuilder &builder, Location loc,
+                   ValueRange tileInductionVars);
+
+private:
+  /// The collapse_shape op whose result is being sliced.
+  tensor::CollapseShapeOp collapseShapeOp;
+  /// Sizes of the collapse_shape source (input) dimensions.
+  SmallVector<OpFoldResult> collapseShapeInputShape;
+  /// Sizes of the collapse_shape result (output) dimensions.
+  SmallVector<OpFoldResult> collapseShapeOutputShape;
+  /// Offset/size/stride of the desired slice, one per result dimension.
+  SmallVector<Range> sliceParams;
+  /// Bit i is set if result dimension i collapses more than one source dim.
+  llvm::SmallBitVector linearizedDimensions;
+  /// Bit i is set if result dimension i is (conservatively) sliced.
+  llvm::SmallBitVector slicedDimensions;
+  /// Iteration-space upper bounds, one per linearized-and-sliced dimension.
+  SmallVector<Value> tiledSizes;
+};
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H

diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61f7aea9f526..2196a64ac49e 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringRef.h"
 
@@ -373,6 +374,90 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
   }
 };
 
+/// The input parameter `sliceParams` (offset, size, stride per dimension)
+/// specifies a rectangular non rank-reducing slice of the collapse_shape
+/// output. Try to find which dimensions have been sliced and which dimensions
+/// are not sliced (offset = 0, size = dim, stride = 1). Note that this is
+/// conservative as it cannot detect if a dynamic size corresponds to the full
+/// tensor dimension or not.
+llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
+                                         ArrayRef<Range> sliceParams);
+
+/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
+/// inspecting its reassociation indices.
+llvm::SmallBitVector
+getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
+
+/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
+/// chain and reified source and result shapes of the CollapseShapeOp, this
+/// class provides two functions that assist with directly forming the result
+/// of the extract slice by "tiling the CollapseShapeOp by 1".
+/// Example:
+// clang-format off
+/// ```
+/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
+/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<231x10xf32>
+/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
+/// ```
+/// This class helps build the below IR to replace %2:
+/// ```
+/// %dest = linalg.init_tensor() : tensor<10x10xf32>
+/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
+///    %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv)
+///    %3:3 = affine.delinearize_index %linear_index into (3, 7, 11)
+///
+///    // This function takes %3 (multiIndices) and the parameters for the slice below.
+///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
+///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
+///
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
+///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
+///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
+///          tensor<1x10xf32> into tensor<10x10xf32>
+///    scf.yield %6 : tensor<10x10xf32>
+/// }
+/// ```
+// clang-format on
+class SliceFromCollapseHelper {
+public:
+  /// Copies the given reassociation indices, shapes and slice parameters, and
+  /// derives the linearized / sliced dimension masks from them via
+  /// `getLinearizedDimensions` and `getSlicedDimensions`.
+  SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
+                          ArrayRef<OpFoldResult> collapseShapeInputShape,
+                          ArrayRef<OpFoldResult> collapseShapeOutputShape,
+                          ArrayRef<Range> extractSliceParams)
+      : reassociationIndices(reassociationIndices),
+        collapseShapeInputShape(collapseShapeInputShape),
+        collapseShapeOutputShape(collapseShapeOutputShape),
+        sliceParams(extractSliceParams),
+        linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
+        slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
+                                             extractSliceParams)) {}
+
+  /// This function takes multi-indices and maps them to ExtractSlice parameters
+  /// in the index space of the CollapseShape's source tensor. This function's
+  /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
+  /// strides)` where `n` is the number of "tiled dimensions", which are the
+  /// dimensions of the output that are linearized by the collapse shape op and
+  /// are also sliced. Each `D_i` is a tuple that must represent a valid
+  /// multi-index for the `i-th` tiled dimension. In the example above, there is
+  /// only one tiled dimension (D_0) and the delinearize-index op produces the
+  /// multi-index (%3) that would be passed to this function to generate the
+  /// parameters for the `tensor.extract_slice` op (%4).
+  SmallVector<Range> getExtractSliceParams(ArrayRef<ValueRange> multiIndices);
+
+  /// This function takes indices in the index space of the "tiled dimensions"
+  /// described above and returns a set of Range variables that describe how the
+  /// slice should be inserted into the destination. In the example above, `%iv`
+  /// would be passed to this function to generate the parameters for the
+  /// `tensor.insert_slice` op producing %6.
+  SmallVector<Range> getInsertSliceParams(ValueRange tileIndices);
+
+private:
+  /// Reassociation indices of the collapse_shape op.
+  SmallVector<ReassociationIndices> reassociationIndices;
+  /// Sizes of the collapse_shape source (input) dimensions.
+  SmallVector<OpFoldResult> collapseShapeInputShape;
+  /// Sizes of the collapse_shape result (output) dimensions.
+  SmallVector<OpFoldResult> collapseShapeOutputShape;
+  /// Offset/size/stride of the slice taken from the collapse_shape result.
+  SmallVector<Range> sliceParams;
+  /// Mask computed by `getLinearizedDimensions` from the reassociation.
+  llvm::SmallBitVector linearizedDimensions;
+  /// Mask computed by `getSlicedDimensions` from output shape and slice.
+  llvm::SmallBitVector slicedDimensions;
+};
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 4bb13ce75060..ea50092ea5c9 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -30,6 +30,13 @@ struct Range {
   OpFoldResult stride;
 };
 
+/// Given an array of Range values, return a tuple of (offset vector, sizes
+/// vector, and strides vector) formed by separating out the individual elements
+/// of each range.
+std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+           SmallVector<OpFoldResult>>
+getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
+
 /// Return a vector of OpFoldResults given the special value
 /// that indicates whether of the value is dynamic or not.
 SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd4c4b9f03a5..fdbbcdabcd83 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1124,6 +1124,15 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
 }
 
+/// Build an ExtractSliceOp from a vector of Range values, each carrying the
+/// mixed static/dynamic offset, size and stride of one dimension. A null
+/// RankedTensorType is forwarded as the result type.
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+                           ArrayRef<Range> ranges,
+                           ArrayRef<NamedAttribute> attrs) {
+  // Split the ranges into the three parallel vectors the mixed builder takes.
+  auto [mixedOffsets, mixedSizes, mixedStrides] =
+      getOffsetsSizesAndStrides(ranges);
+  build(b, result, RankedTensorType(), source, mixedOffsets, mixedSizes,
+        mixedStrides, attrs);
+}
+
 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
 /// type passed is nullptr, it is inferred.
 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
@@ -1511,6 +1520,15 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   result.addAttributes(attrs);
 }
 
+/// Build an InsertSliceOp from a vector of Range values, each carrying the
+/// mixed static/dynamic offset, size and stride of one dimension.
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+                          Value dest, ArrayRef<Range> ranges,
+                          ArrayRef<NamedAttribute> attrs) {
+  // Split the ranges into the three parallel vectors the mixed builder takes.
+  auto [mixedOffsets, mixedSizes, mixedStrides] =
+      getOffsetsSizesAndStrides(ranges);
+  build(b, result, source, dest, mixedOffsets, mixedSizes, mixedStrides,
+        attrs);
+}
+
 // Build a InsertSliceOp with dynamic entries.
 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           Value dest, ValueRange offsets, ValueRange sizes,
@@ -2273,6 +2291,16 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   result.addAttributes(attrs);
 }
 
+/// Build a ParallelInsertSliceOp from a vector of Range values, each carrying
+/// the mixed static/dynamic offset, size and stride of one dimension.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest,
+                                  ArrayRef<Range> ranges,
+                                  ArrayRef<NamedAttribute> attrs) {
+  // Split the ranges into the three parallel vectors the mixed builder takes.
+  auto [mixedOffsets, mixedSizes, mixedStrides] =
+      getOffsetsSizesAndStrides(ranges);
+  build(b, result, source, dest, mixedOffsets, mixedSizes, mixedStrides,
+        attrs);
+}
+
 // Build a ParallelInsertSliceOp with dynamic entries.
 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                   Value source, Value dest, ValueRange offsets,

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 66e4cc906f23..0b200e03226d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ExtractSliceFromReshape.cpp
   SplitPadding.cpp
   SwapExtractSliceWithProducer.cpp
 

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
new file mode 100644
index 000000000000..4809cc40c661
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
@@ -0,0 +1,181 @@
+//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that replace slices of reshape results with
+// aggregated slices of the reshape source.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Get the size of dimension `dimIdx` of `rankedTensor`, whose type must be a
+/// RankedTensorType. A static size is returned as an index attribute; a
+/// dynamic size is returned as the (possibly folded) result of a `tensor.dim`
+/// op created at `loc`. This helper is local to this file.
+static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
+                                    Value rankedTensor, int64_t dimIdx) {
+  RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
+  if (!tensorType.isDynamicDim(dimIdx))
+    return b.getIndexAttr(tensorType.getDimSize(dimIdx));
+  // Dynamic dimension: query the size at runtime.
+  Value idxValue = b.create<arith::ConstantIndexOp>(loc, dimIdx);
+  return b.createOrFold<tensor::DimOp>(loc, rankedTensor, idxValue);
+}
+
+/// Collect the size of every dimension of `rankedTensor` (a value of
+/// RankedTensorType), in dimension order, by calling `getShapeDimSize` on each
+/// dimension index.
+static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
+                                                  Value rankedTensor) {
+  auto tensorType = rankedTensor.getType().cast<RankedTensorType>();
+  const int64_t rank = tensorType.getRank();
+  SmallVector<OpFoldResult> dimSizes;
+  dimSizes.reserve(rank);
+  for (int64_t dim = 0; dim < rank; ++dim)
+    dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, dim));
+  return dimSizes;
+}
+
+/// A tuple that represents (dimension number, dimension value).
+using DimAndIndex = std::tuple<unsigned, Value>;
+
+/// Map an index in the output index space of a (non-rank-reducing) slice back
+/// to the slice's input index space. `dimAndIndex` holds the dimension number
+/// and the output index; the returned tuple holds the same dimension number
+/// and the value `offset + index * stride`, where offset and stride are taken
+/// from `sliceParams` for that dimension.
+static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
+                                       ArrayRef<Range> sliceParams,
+                                       const DimAndIndex &dimAndIndex) {
+  const unsigned dim = std::get<0>(dimAndIndex);
+  Value outputIndex = std::get<1>(dimAndIndex);
+  assert(dim < sliceParams.size() && "slice should be non rank-reducing");
+
+  // One dimension (the output index) and two symbols (offset and stride).
+  AffineExpr indexDim, offsetSym, strideSym;
+  bindDims(b.getContext(), indexDim);
+  bindSymbols(b.getContext(), offsetSym, strideSym);
+
+  // Materialize offset and stride as index values, in that order.
+  Value offsetValue =
+      getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].offset);
+  Value strideValue =
+      getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].stride);
+
+  Value inputIndex =
+      makeComposedAffineApply(b, loc, offsetSym + indexDim * strideSym,
+                              {outputIndex, offsetValue, strideValue});
+  return std::make_tuple(dim, inputIndex);
+}
+
+/// Map a linear index of one CollapseShapeOp result dimension back to a
+/// multi-index over the source dimensions that were collapsed into it.
+/// `dimAndIndex` holds the result dimension number and the linear index. The
+/// basis of the delinearization is made of the `reshapeSourceShape` entries
+/// selected by `reassociation` for that result dimension. Returns the results
+/// of the created delinearize op.
+static ValueRange invertCollapseShapeIndexing(
+    OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
+  const unsigned dim = std::get<0>(dimAndIndex);
+  Value linearIndex = std::get<1>(dimAndIndex);
+
+  // Gather the sizes of the source dimensions folded into result dim `dim`.
+  const ReassociationIndices &group = reassociation[dim];
+  SmallVector<OpFoldResult> basis;
+  basis.reserve(group.size());
+  for (int64_t srcDim : group)
+    basis.push_back(reshapeSourceShape[srcDim]);
+
+  auto delinearizeOp =
+      b.create<AffineDelinearizeIndexOp>(loc, linearIndex, basis);
+  return delinearizeOp->getResults();
+}
+
+/// Build the helper from a `collapse_shape -> extract_slice` pair. Fails if
+/// the source of `extractOp` is not defined by `collapseOp`; otherwise packs
+/// the mixed offsets, sizes and strides of `extractOp` into Range values and
+/// forwards to the Range-based `create` overload.
+FailureOr<ExtractSliceFromCollapseHelper>
+tensor::ExtractSliceFromCollapseHelper::create(
+    OpBuilder &b, tensor::CollapseShapeOp collapseOp,
+    tensor::ExtractSliceOp extractOp) {
+  auto producer =
+      extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+  if (producer != collapseOp)
+    return failure();
+
+  auto mixedOffsets = extractOp.getMixedOffsets();
+  auto mixedSizes = extractOp.getMixedSizes();
+  auto mixedStrides = extractOp.getMixedStrides();
+
+  SmallVector<Range> sliceRanges;
+  sliceRanges.reserve(extractOp.getSourceType().getRank());
+  for (const auto &[offset, size, stride] :
+       llvm::zip(mixedOffsets, mixedSizes, mixedStrides))
+    sliceRanges.push_back({offset, size, stride});
+
+  return ExtractSliceFromCollapseHelper::create(b, collapseOp, sliceRanges);
+}
+
+/// Build the helper from a CollapseShapeOp `op` and the ranges `sliceParams`
+/// describing the desired slice of its result (one Range per result
+/// dimension). Emits IR that reifies the result shape and materializes the
+/// source shape of `op`, plus index constants for the sizes of the tiled
+/// dimensions. Returns failure if the number of ranges does not match the
+/// number of reassociation groups, or if the result shape cannot be reified.
+FailureOr<ExtractSliceFromCollapseHelper>
+tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
+                                               tensor::CollapseShapeOp op,
+                                               ArrayRef<Range> sliceParams) {
+  SmallVector<ReassociationIndices> reassociationIndices =
+      op.getReassociationIndices();
+
+  // There must be exactly one range per result dimension (i.e. per
+  // reassociation group). Bail out before emitting any IR; the dimension masks
+  // below are indexed by slice dimension and would otherwise be out of sync.
+  if (sliceParams.size() != reassociationIndices.size())
+    return failure();
+
+  // Materialize the output shape of the collapse_shape operation. This will
+  // create IR describing the output shape in terms of the input shape.
+  // `cast` (rather than an unchecked `dyn_cast`) asserts that the op
+  // implements the interface before it is used.
+  ReifiedRankedShapedTypeDims reifiedShapes;
+  auto reifyShapedTypeInterface =
+      cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
+  if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
+    return failure();
+  SmallVector<OpFoldResult> collapseShapeOutputShape =
+      getAsOpFoldResult(reifiedShapes[0]);
+
+  // Determine which of the CollapseShapeOp's result dimensions are sliced
+  // and/or linearized.
+  llvm::SmallBitVector linearizedDimensions =
+      getLinearizedDimensions(reassociationIndices);
+  llvm::SmallBitVector slicedDimensions =
+      getSlicedDimensions(collapseShapeOutputShape, sliceParams);
+
+  // Materialize the source shape once; it is stored in the helper and later
+  // used as the delinearization basis.
+  SmallVector<OpFoldResult> collapseShapeInputShape =
+      getShapeDimSizes(b, op.getLoc(), op.getSrc());
+
+  // Only dimensions that are both linearized and sliced need to be tiled; the
+  // slice size along each such dimension is the loop upper bound.
+  SmallVector<Value> tileSizes;
+  for (unsigned i = 0; i < sliceParams.size(); i++) {
+    if (slicedDimensions[i] && linearizedDimensions[i])
+      tileSizes.push_back(
+          getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
+  }
+
+  return ExtractSliceFromCollapseHelper(
+      op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
+      linearizedDimensions, slicedDimensions, tileSizes);
+}
+
+/// Emit the loop-nest body for one iteration-space point. For every result
+/// dimension that is both linearized and sliced, the next value of
+/// `tileInductionVars` is mapped back through the slice and then through the
+/// collapse_shape to a multi-index of the source. A slice of the source is
+/// extracted at those multi-indices and collapsed with the original
+/// reassociation. Returns the collapsed tile together with the ranges at
+/// which the caller should insert it into the destination.
+std::pair<Value, SmallVector<Range>>
+tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
+    OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
+  const SmallVector<ReassociationIndices> reassociation =
+      collapseShapeOp.getReassociationIndices();
+
+  // Helper that computes the extract/insert slice parameters.
+  SliceFromCollapseHelper sliceHelper(reassociation, collapseShapeInputShape,
+                                      collapseShapeOutputShape, sliceParams);
+
+  // For each tiled dimension (linearized by the collapse_shape and sliced by
+  // the extract_slice), invert both index space transformations. Induction
+  // variables are consumed in order, one per tiled dimension.
+  SmallVector<ValueRange> multiIndices;
+  unsigned nextInductionVar = 0;
+  const unsigned numResultDims = linearizedDimensions.size();
+  for (unsigned dim = 0; dim < numResultDims; ++dim) {
+    if (!linearizedDimensions[dim] || !slicedDimensions[dim])
+      continue;
+    DimAndIndex sourceLinearIndex = invertSliceIndexing(
+        builder, loc, sliceParams,
+        std::make_tuple(dim, tileInductionVars[nextInductionVar++]));
+    multiIndices.push_back(invertCollapseShapeIndexing(
+        builder, loc, reassociation, collapseShapeInputShape,
+        sourceLinearIndex));
+  }
+
+  // Extract the sub-tile from the collapse_shape source.
+  SmallVector<Range> extractRanges =
+      sliceHelper.getExtractSliceParams(multiIndices);
+  Value sourceTile = builder.create<tensor::ExtractSliceOp>(
+      loc, collapseShapeOp.getSrc(), extractRanges);
+
+  // Compute where the tile goes in the destination.
+  SmallVector<Range> insertRanges =
+      sliceHelper.getInsertSliceParams(tileInductionVars);
+
+  // Collapse the dimensions of the source slice back down.
+  Value collapsedTile = builder.create<tensor::CollapseShapeOp>(
+      loc, sourceTile, reassociation);
+  return std::make_pair(collapsedTile, insertRanges);
+}

diff  --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index f329afa8fa75..578c6a67be58 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -6,4 +6,5 @@ add_mlir_library(MLIRDialectUtils
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRViewLikeInterface
 )

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index b26b3d93541a..adc53dcb9743 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -8,8 +8,11 @@
 
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include <numeric>
 
@@ -270,3 +273,88 @@ bool mlir::hasNonIdentityLayout(Type type) {
     return !memrefType.getLayout().isIdentity();
   return false;
 }
+
+llvm::SmallBitVector
+mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
+                          ArrayRef<Range> sliceParams) {
+  assert(sliceParams.size() == sliceInputShape.size() &&
+         "only supports non rank-reducing case");
+  llvm::SmallBitVector mask(sliceInputShape.size());
+  unsigned idx = 0;
+  for (const auto &[offset, size, stride] : sliceParams) {
+    Optional<int64_t> offsetConst = getConstantIntValue(offset);
+    Optional<int64_t> strideConst = getConstantIntValue(stride);
+    mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
+                (!strideConst || *strideConst != 1) ||
+                (!offsetConst || *offsetConst != 0);
+    idx++;
+  }
+  return mask;
+}
+
+llvm::SmallBitVector mlir::getLinearizedDimensions(
+    ArrayRef<ReassociationIndices> reassociationIndices) {
+  llvm::SmallBitVector result(reassociationIndices.size());
+  for (const auto &it : llvm::enumerate(reassociationIndices))
+    result[it.index()] = it.value().size() > 1;
+  return result;
+}
+
+SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
+    ArrayRef<ValueRange> multiIndices) {
+  assert(!multiIndices.empty() && !multiIndices[0].empty() &&
+         "multiIndices should not be empty");
+  unsigned loopIdx = 0;
+  MLIRContext *ctx = multiIndices[0][0].getContext();
+  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
+  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
+  SmallVector<Range> offsetsSizesAndStrides;
+  offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
+  for (const auto &it : llvm::enumerate(reassociationIndices)) {
+    // Case 1: Linearized dimensions that have also been sliced. These
+    // have size 1 because we are iterating over these dimensions. The
+    // offsets are exactly the de-linearized multi-indices.
+    if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
+      llvm::append_range(
+          offsetsSizesAndStrides,
+          llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
+            return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
+          }));
+      continue;
+    }
+
+    // Case 2: One or possibly multiple combined input dimensions, but we
+    // have proven that these are not sliced. In this case we just take
+    // the full extent of each dimension in the reassociation list.
+    if (linearizedDimensions[it.index()]) {
+      llvm::append_range(
+          offsetsSizesAndStrides,
+          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
+          }));
+      continue;
+    }
+
+    // Case 3: A single index, but it may be sliced.
+    offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
+  }
+  return offsetsSizesAndStrides;
+}
+
+SmallVector<Range>
+SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) {
+  MLIRContext *ctx = tileIndices[0].getContext();
+  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
+  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
+  SmallVector<Range> insertParams;
+  insertParams.reserve(linearizedDimensions.size());
+  unsigned loopIdx = 0;
+  for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
+    if (linearizedDimensions[i] && slicedDimensions[i]) {
+      insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
+      continue;
+    }
+    insertParams.push_back(Range{zero, sliceParams[i].size, one});
+  }
+  return insertParams;
+}

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 89ebd8127172..6b5d579a339a 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -17,6 +17,21 @@ using namespace mlir;
 /// Include the definitions of the loop-like interfaces.
 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
 
+std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+           SmallVector<OpFoldResult>>
+mlir::getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
+  SmallVector<OpFoldResult> offsets, sizes, strides;
+  offsets.reserve(ranges.size());
+  sizes.reserve(ranges.size());
+  strides.reserve(ranges.size());
+  for (const auto &[offset, size, stride] : ranges) {
+    offsets.push_back(offset);
+    sizes.push_back(size);
+    strides.push_back(stride);
+  }
+  return std::make_tuple(offsets, sizes, strides);
+}
+
 LogicalResult mlir::verifyListOfOperandsOrIntegers(
     Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {

diff  --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
new file mode 100644
index 000000000000..d8ca129bf59a
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH
+
+func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
+  %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32>
+  return %slice : tensor<20x11xf32>
+}
+
+//     CHECK: func.func @extract_slice_static(%[[arg0:.+]]:
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
+// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
+//     CHECK:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : 
+//     CHECK:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : 
+//     CHECK:   %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] : 
+//     CHECK:   scf.yield %[[update]] :
+//     CHECK: return %[[tile]]
+
+//     FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
+// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index
+// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
+// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
+// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
+// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
+//     FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]])
+//     FOREACH:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
+//     FOREACH:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : 
+//     FOREACH:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : 
+//     FOREACH:   perform_concurrently
+// FOREACH-NEXT:   tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[iv]], 0] [1, 11] [1, 1] :
+//     FOREACH: return %[[tile]]
+
+// -----
+
+
+func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
+  %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32>
+  return %slice : tensor<10x5xf32>
+}
+
+//     CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 13)>
+//     CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]:
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
+//     CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] :
+//     CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+//     CHECK:   %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]])
+//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]]
+//     CHECK:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : 
+//     CHECK:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : 
+//     CHECK:   %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : 
+//     CHECK:   scf.yield %[[update]] :
+//     CHECK: return %[[tile]]
+
+
+// -----
+
+
+func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor<?x5xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor<?x11xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor<?x11xf32> to tensor<?x5xf32>
+  return %slice : tensor<?x5xf32>
+}
+
+//     CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+//     CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index)
+// CHECK-DAG:   %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[c2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[c3:.+]] = arith.constant 3 : index
+//     CHECK:   %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor<?x5xf32>
+// CHECK-DAG:   %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32>
+// CHECK-DAG:   %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32>
+//     CHECK:   %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+//     CHECK:     %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]]
+//     CHECK:     %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) :
+//     CHECK:     %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
+//     CHECK:     %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+//     CHECK:     %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
+//     CHECK:     scf.yield %[[update]] :
+//     CHECK:   return %[[tile]] :
+
+// -----
+
+
+func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor<?x?xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
+  %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %slice : tensor<?x?xf32>
+}
+
+//     CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//     CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
+//     CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
+// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : 
+// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : 
+// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
+//     CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
+//     CHECK:   %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
+//     CHECK:       %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]]
+//     CHECK:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
+//     CHECK:       %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]]
+//     CHECK:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) :
+//     CHECK:       %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : 
+//     CHECK:       %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : 
+//     CHECK:       %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : 
+//     CHECK:       scf.yield %[[update]] :
+//     CHECK:     scf.yield %[[tile2]] :
+//     CHECK:   return %[[tile1]] : 
+
+//     FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//     FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
+// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index
+// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index
+// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
+// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index
+// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index
+//     FOREACH:     %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
+// FOREACH-DAG:     %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : 
+// FOREACH-DAG:     %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : 
+// FOREACH-DAG:     %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
+//     FOREACH:     %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]])
+// FOREACH-DAG:       %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]]
+//     FOREACH:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
+// FOREACH-DAG:       %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]]
+//     FOREACH:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) :
+//     FOREACH:       %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : 
+//     FOREACH:       %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : 
+//     FOREACH:       perform_concurrently
+//FOREACH-NEXT:         tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] :
+
+// -----
+
+// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this
+// only works for static shapes.
+
+// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>,
+func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor<?x22xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor<?x22xf32>  
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor<?x22xf32> to tensor<?x22xf32>
+  // CHECK: scf.for
+  // CHECK-NOT: scf.for
+  // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index
+  // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
+  return %slice : tensor<?x22xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
index 89c996da7b2a..56e59820677e 100644
--- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTensorTestPasses
 
   LINK_LIBS PUBLIC
   MLIRArithmeticDialect
+  MLIRLinalgDialect
   MLIRPass
   MLIRSCFDialect
   MLIRTensorDialect

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 4c38ad1d2dda..f5a7f984ab0a 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -11,8 +11,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -28,7 +30,8 @@ struct TestTensorTransforms
   TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+    registry.insert<arith::ArithmeticDialect, scf::SCFDialect,
+                    linalg::LinalgDialect>();
   }
 
   StringRef getArgument() const final {
@@ -49,6 +52,19 @@ struct TestTensorTransforms
       *this, "test-fold-constant-extract-slice",
       llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
       llvm::cl::init(false)};
+
+  Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
+      *this, "test-rewrite-extract-slice-from-collapse-shape",
+      llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
+                     "with loop nest"),
+      llvm::cl::init(false)};
+
+  Option<bool> useForeach{
+      *this, "use-foreach",
+      llvm::cl::desc(
+          "Use the scf.foreach_thread operation when generating loop nests for "
+          "the extract_slice of collapse_shape pattern"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -74,12 +90,142 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+namespace {
+/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`
+/// chain. The `tensor.extract_slice` is replaced by a loop or gather operation
+/// that stitches together the desired tile from slices of the source of the
+/// collapse shape op.
+struct RewriteExtractSliceFromCollapseShapeBase
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
+      : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}
+
+  /// Emit a loop or gather operation that uses `helper` to take each point in
+  /// the parallel iteration space bounds, extract a slice from the source
+  /// tensor and insert it into `dest`. For examples, see the `scf.for` and
+  /// `scf.foreach_thread` implementations below.
+  virtual LogicalResult
+  emitReplacement(tensor::ExtractSliceOp op, Value dest,
+                  tensor::ExtractSliceFromCollapseHelper &helper,
+                  PatternRewriter &rewriter) const = 0;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp)
+      return rewriter.notifyMatchFailure(
+          op, "producer is not a tensor.collapse_shape op");
+
+    // Materialize the output shape values of the slice operation.
+    ReifiedRankedShapedTypeDims reifiedShapes;
+    if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
+      return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
+
+    // Create the destination tensor using the above values.
+    Type elementType = op.getSourceType().getElementType();
+    SmallVector<OpFoldResult> outputShape = getAsOpFoldResult(reifiedShapes[0]);
+    Value dest = rewriter.create<linalg::InitTensorOp>(
+        op->getLoc(), outputShape, elementType);
+
+    // Calculate the parameters for the tile loop nest.
+    FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
+        tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
+                                                       op);
+    if (failed(params))
+      return rewriter.notifyMatchFailure(
+          op, "could not calculate tiling parameters");
+    return emitReplacement(op, dest, *params, rewriter);
+  }
+};
+
+struct RewriteExtractSliceFromCollapseShapeUsingScfFor
+    : public RewriteExtractSliceFromCollapseShapeBase {
+  RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
+      : RewriteExtractSliceFromCollapseShapeBase(context) {}
+  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
+                                tensor::ExtractSliceFromCollapseHelper &helper,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    SmallVector<Value> lbs(numTiledDims, zero);
+    SmallVector<Value> steps(numTiledDims, one);
+    scf::LoopNest nest = scf::buildLoopNest(
+        rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
+        [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+            ValueRange iterArgs) -> scf::ValueVector {
+          auto [tile, insertParams] =
+              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
+
+          // Insert the slice into the destination.
+          Value result = nestedBuilder.create<tensor::InsertSliceOp>(
+              loc, tile, iterArgs[0], insertParams);
+          return {result};
+        });
+    rewriter.replaceOp(op, nest.getResults()[0]);
+    return success();
+  }
+};
+
+struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
+    : public RewriteExtractSliceFromCollapseShapeBase {
+  RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
+      : RewriteExtractSliceFromCollapseShapeBase(context) {}
+  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
+                                tensor::ExtractSliceFromCollapseHelper &helper,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
+        loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
+        /*threadDimMapping=*/ArrayRef<int64_t>{},
+        [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
+          unsigned numThreadIdRegionArgs =
+              helper.getIterationSpaceSizes().size();
+          unsigned numOutputRegionArgs =
+              regionArgs.size() - numThreadIdRegionArgs;
+          ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
+          ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
+          assert(outputArgs.size() == 1 &&
+                 "there should only be one output region argument");
+          auto [tile, insertParams] =
+              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
+          // Insert the slice into the destination.
+          auto term = nestedBuilder.create<scf::PerformConcurrentlyOp>(loc);
+          nestedBuilder.setInsertionPointToStart(term.getBody());
+          nestedBuilder.create<tensor::ParallelInsertSliceOp>(
+              loc, tile, outputArgs[0], insertParams);
+        });
+    rewriter.replaceOp(op, foreachOp->getResult(0));
+    return success();
+  }
+};
+} // namespace
+
+static LogicalResult
+applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
+                                             bool useForeach) {
+  RewritePatternSet patterns(rootOp->getContext());
+  if (useForeach)
+    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
+        rootOp->getContext());
+  else
+    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
+        rootOp->getContext());
+  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 void TestTensorTransforms::runOnOperation() {
   Operation *rootOp = getOperation();
   if (testSplitPaddingPatterns)
     applySplitPaddingPatterns(rootOp);
   if (testFoldConstantExtractSlice)
     applyFoldConstantExtractSlicePatterns(rootOp);
+  if (testRewriteExtractSliceWithTiledCollapseShape) {
+    if (failed(
+            applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
+      return signalPassFailure();
+  }
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list