[Mlir-commits] [mlir] f4a478c - [mlir][Tensor] Add rewrites to extract slices through `tensor.collapse_shape`
Christopher Bate
llvmlistbot at llvm.org
Thu Sep 8 20:58:31 PDT 2022
Author: Christopher Bate
Date: 2022-09-08T21:58:21-06:00
New Revision: f4a478cd017818ad6381a8aa1a7e3d29fd263ef9
URL: https://github.com/llvm/llvm-project/commit/f4a478cd017818ad6381a8aa1a7e3d29fd263ef9
DIFF: https://github.com/llvm/llvm-project/commit/f4a478cd017818ad6381a8aa1a7e3d29fd263ef9.diff
LOG: [mlir][Tensor] Add rewrites to extract slices through `tensor.collapse_shape`
This change adds a set of utilities to replace the result of a
`tensor.collapse_shape -> tensor.extract_slice` chain with the
equivalent result formed by aggregating slices of the
`tensor.collapse_shape` source. In general, it is not possible to
commute `extract_slice` and `collapse_shape` if linearized dimensions
are sliced. The i-th dimension of the `tensor.collapse_shape`
result is a "linearized sliced dimension" if:
1) the reassociation indices of `tensor.collapse_shape` in the i-th position
have size greater than 1 (multiple dimensions of the input are collapsed), and
2) the i-th dimension is sliced by `tensor.extract_slice`.
We can work around this by stitching together the result of
`tensor.extract_slice` by iterating over any linearized sliced dimensions.
This is equivalent to "tiling" the linearized-and-sliced dimensions of
the `tensor.collapse_shape` operation in order to manifest the result
tile (the result of the `tensor.extract_slice`). The user of the
utilities must provide the mechanism to create the tiling (e.g. a loop).
In the tests, it is demonstrated how to apply the utilities using either
`scf.for` or `scf.foreach_thread`.
The below example illustrates the pattern using `scf.for`:
```
%0 = linalg.generic ... -> tensor<3x7x11x10xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... into tensor<231x10xf32>
%2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
```
We can construct %2 by generating the following IR:
```
%dest = linalg.init_tensor [10, 10] : tensor<10x10xf32>
%2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %dest) -> (tensor<10x10xf32>) {
// Step 1: Map this output index (%iv) to a multi-index (%3) into the input (%0):
%linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv)
%3:3 = affine.delinearize_index %linear_index into (3, 7, 11)
// Step 2: Extract the slice from the input
%4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
%5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
tensor<1x1x1x10xf32> into tensor<1x10xf32>
// Step 3: Insert the slice into the destination
%6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
tensor<1x10xf32> into tensor<10x10xf32>
scf.yield %6 : tensor<10x10xf32>
}
```
The pattern was discussed in the RFC here: https://discourse.llvm.org/t/rfc-tensor-extracting-slices-from-tensor-collapse-shape/64034
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D129699
Added:
mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/lib/Dialect/Tensor/CMakeLists.txt
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 51ab419b5802..a55e6a45769b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -283,7 +283,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
// Build an ExtractSliceOp with dynamic entries and inferred result type.
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build an ExtractSliceOp with mixed static and dynamic entries packed in
+ // a Range vector.
+ OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -739,6 +743,11 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
// Build a InsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build an InsertSliceOp with mixed static and dynamic entries packed in
+ // a Range vector.
+ OpBuilder<(ins "Value":$source, "Value":$dest,
+ "ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
@@ -1337,7 +1346,11 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-
+ // Build a ParallelInsertSliceOp with mixed static and dynamic entries
+ // packed into a Range vector.
+ OpBuilder<(ins "Value":$source, "Value":$dest,
+ "ArrayRef<Range>":$ranges,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ParallelInsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
new file mode 100644
index 000000000000..2ca556275af1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
@@ -0,0 +1,210 @@
+//===- TransformUtils.h - Tensor Transformation Utilities --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
+#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tensor {
+
+//===----------------------------------------------------------------------===//
+// Extract slice from `tensor.collapse_shape`
+//===----------------------------------------------------------------------===//
+
+/// This class assists with generating IR required to materialize an
+/// arbitrary-sized slice from the result of a CollapseShapeOp. In order to
+/// accomplish this, a loop nest or similar operation must be created by the
+/// caller. The purpose of the loop nest is to generate a "tiling by 1" of all
+/// sliced dimensions. The "tiling by 1" assembles all elements of the result
+/// tile over dimensions that would have been impossible to directly slice.
+///
+/// The class provides three methods:
+/// 1. `ExtractSliceFromCollapseHelper::create`: emits IR that should
+///    appear before the loop nest and populates the internal state.
+/// 2. `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`: returns
+///    parameters used by the caller to construct the loop nest.
+/// 3. `ExtractSliceFromCollapseHelper::emitLoopNestBody`:
+///    emits IR to construct a "size-1 tile" of the desired result and returns a
+///    set of ranges where the tile should be inserted into the destination
+///    tensor.
+///
+/// ### Intended usage:
+///
+/// The caller should first call `ExtractSliceFromCollapseHelper::create` and
+/// then create a destination tensor that is the same size as the desired slice.
+/// The caller then creates a loop nest that iterates over the multi-dimensional
+/// iteration space defined by `[0, ub[0]) x [0, ub[1]) x ... x [0, ub[N-1])`
+/// where `ub` is the upper bound given by
+/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. Inside the body of
+/// the loop nest, the caller should call
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody` and provide the induction
+/// variables. This returns a sub-tile and a set of ranges that describe where
+/// this tile should be inserted into the result by the caller. For a complete
+/// example of usage, see the examples in the TestTensorTransforms pass.
+///
+/// ### Example:
+/// Consider the following IR:
+/// ```
+/// %0 = linalg.generic ... -> tensor<3x?x?x11x?xf32>
+/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]]
+///          : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
+/// %2 = tensor.extract_slice %1 [%offt0, %offt1][%size0, %size1][1, 1]
+///          : tensor<?x?xf32> to tensor<?x?xf32>
+/// ```
+///
+/// We can construct %2 by generating the following, which only uses `%0`:
+///
+/// ```
+/// %dest = linalg.init_tensor [%size0, %size1] : tensor<?x?xf32>
+/// %1 = tensor.dim %0, %c1 : tensor<3x?x?x11x?xf32>
+/// %2 = tensor.dim %0, %c2 : tensor<3x?x?x11x?xf32>
+/// %3 = tensor.dim %0, %c4 : tensor<3x?x?x11x?xf32>
+///
+/// %result = scf.for %iv0 = %c0 to %size0 step %c1 iter_args(%arg6 = %dest) ->
+///     (tensor<?x?xf32>) {
+///   %5 = scf.for %iv1 = %c0 to %size1 step %c1 iter_args(%arg8 = %arg6)
+///       -> (tensor<?x?xf32>) {
+///     %lin0 = (affine.apply) %iv0 + %offt0
+///     %lin1 = (affine.apply) %iv1 + %offt1
+///
+///     %mi0:3 = affine.delinearize_index %lin0 into (%c3, %1, %2)
+///     %mi1:2 = affine.delinearize_index %lin1 into (%c11, %3)
+///
+///     %sub_tile = tensor.extract_slice %0
+///                   [%mi0#0, %mi0#1, %mi0#2, %mi1#0, %mi1#1]
+///                   [1, 1, 1, 1, 1]
+///                   [1, 1, 1, 1, 1]
+///           : tensor<3x?x?x11x?xf32> to tensor<1x1x1x1x1xf32>
+///     %sub_tile_collapsed = tensor.collapse_shape %sub_tile
+///           [[0, 1, 2], [3, 4]]
+///           : tensor<1x1x1x1x1xf32> into tensor<1x1xf32>
+///
+///     %12 = tensor.insert_slice %sub_tile_collapsed into
+///           %arg8[%iv0, %iv1] [1, 1] [1, 1]
+///           : tensor<1x1xf32> into tensor<?x?xf32>
+///     scf.yield %12 : tensor<?x?xf32>
+///   }
+///   scf.yield %5 : tensor<?x?xf32>
+/// }
+/// ```
+///
+/// ### Explanation of example:
+///
+/// Each step above is explained below.
+///
+/// #### Step 0: Create %dest and materialization of shapes.
+/// This step is self-explanatory and performed by the caller. It can be
+/// done before or after calling `ExtractSliceFromCollapseHelper::create`,
+/// which materializes the source shape (`%1`, `%2`, `%3` in the second
+/// listing).
+///
+/// #### Step 1: Create loop nest.
+///
+/// The caller creates the loop nest (depicted here is `scf.for`, but any other
+/// similar op can be used). The iteration should start at zero and proceed with
+/// step size 1 to the upper bounds given by
+/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. This forms the
+/// basis for the "tiling by 1".
+///
+/// #### Step 2: Transform (%iv0, %iv1) from the index space of the slice result
+/// (`%2` in the first listing) to the index space of %0.
+///
+/// This step is performed by
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
+///
+/// The induction variables `%iv0` and `%iv1` live in the index space of the
+/// slice result (for dimensions 0 and 1, respectively). `%lin0` and `%lin1`
+/// are the result of inverting (resolving) the index space transformation
+/// represented by the slice operation, accounting for offset and stride.
+/// Subsequently, `%mi0` and `%mi1` are the result of applying the inverse
+/// index space transformation represented by `tensor.collapse_shape`.
+/// This is accomplished using `affine.delinearize_index`. Note that %iv0
+/// and %iv1 now correspond to multi-indices `%mi0:3` and `%mi1:2`.
+///
+/// #### Step 3: Extract a sub-tile slice from the source.
+///
+/// This step is also performed by
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
+///
+/// The indices `%mi0` and `%mi1` are used to extract a slice from %0. This
+/// slice is then collapsed down to match the result rank.
+///
+/// #### Step 4: Insert sub-tile into the destination
+///
+/// This step is performed by the caller using the results of
+/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
+///
+/// In the above example, the slice insertion parameters are straightforward,
+/// but in other possible situations, the slice parameters are more complicated,
+/// which is why this helper calculates them for the caller. These other
+/// situations correspond to:
+/// 1. The presence of linearized dimensions that are not sliced
+/// 2. The presence of non-linearized dimensions that are sliced.
+class ExtractSliceFromCollapseHelper {
+public:
+  /// Given a CollapseShapeOp and a set of ranges describing the desired slice
+  /// of its result, emits IR to materialize the shapes of the input and output
+  /// tensors, and returns an instance of the initialized class. `sliceParams`
+  /// must describe a non rank-reducing slice, i.e. contain exactly one range
+  /// per result dimension of `op`. Returns failure if the result shape of `op`
+  /// cannot be reified.
+  static FailureOr<ExtractSliceFromCollapseHelper>
+  create(OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef<Range> sliceParams);
+
+  /// Given a CollapseShapeOp and an ExtractSliceOp acting on its result, emits
+  /// IR to materialize the shapes of the input and output tensors of the
+  /// CollapseShapeOp, and returns an instance of the initialized class. Returns
+  /// failure if `extractOp` does not slice the result of `collapseOp`, or if
+  /// the slice is rank-reducing.
+  static FailureOr<ExtractSliceFromCollapseHelper>
+  create(OpBuilder &b, tensor::CollapseShapeOp collapseOp,
+         tensor::ExtractSliceOp extractOp);
+
+  ExtractSliceFromCollapseHelper(
+      tensor::CollapseShapeOp collapseShapeOp,
+      ArrayRef<OpFoldResult> collapseShapeInputShape,
+      ArrayRef<OpFoldResult> collapseShapeOutputShape,
+      ArrayRef<Range> extractSliceParams,
+      const llvm::SmallBitVector &linearizedDimensions,
+      const llvm::SmallBitVector &slicedDimensions, ArrayRef<Value> tiledSizes)
+      : collapseShapeOp(collapseShapeOp),
+        collapseShapeInputShape(collapseShapeInputShape),
+        collapseShapeOutputShape(collapseShapeOutputShape),
+        sliceParams(extractSliceParams),
+        linearizedDimensions(linearizedDimensions),
+        slicedDimensions(slicedDimensions), tiledSizes(tiledSizes) {}
+
+  /// Return the upper bounds of the iteration space (with 0 offset and stride
+  /// 1) required to create the desired slice. Note that this is not the same
+  /// as the `sizes` parameters of the ExtractSliceOp because not all dimensions
+  /// of the slice are required to be tiled to form the result.
+  const SmallVector<Value> &getIterationSpaceSizes() const { return tiledSizes; }
+
+  /// Generates the IR inside of the caller's loop nest for 1) inverting the
+  /// index mappings of the ExtractSliceOp->CollapseShapeOp chain and 2)
+  /// extracting the CollapseShapeOp source tensor tile for this specified
+  /// iteration space point `tileInductionVars` and 3) calculating where to
+  /// insert the extracted tile. The returned pair consists of the results of
+  /// (2) and (3) and should be used by the caller to insert into the
+  /// destination tensor.
+  std::pair<Value, SmallVector<Range>>
+  emitLoopNestBody(OpBuilder &builder, Location loc,
+                   ValueRange tileInductionVars);
+
+private:
+  /// The collapse_shape whose result is being sliced.
+  tensor::CollapseShapeOp collapseShapeOp;
+  /// Materialized sizes of the collapse_shape source and result.
+  SmallVector<OpFoldResult> collapseShapeInputShape;
+  SmallVector<OpFoldResult> collapseShapeOutputShape;
+  /// One range per result dimension describing the requested slice.
+  SmallVector<Range> sliceParams;
+  /// Per result dimension: collapsed from more than one source dimension.
+  llvm::SmallBitVector linearizedDimensions;
+  /// Per result dimension: not known to be taken in full by the slice.
+  llvm::SmallBitVector slicedDimensions;
+  /// Loop-nest upper bounds returned by `getIterationSpaceSizes`.
+  SmallVector<Value> tiledSizes;
+};
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61f7aea9f526..e6b6048f8180 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
@@ -373,6 +374,90 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
}
};
+/// `sliceParams` specifies a rectangular, non rank-reducing slice (one Range
+/// per dimension) of a tensor whose shape is `sliceInputShape`, e.g. the
+/// collapse_shape output. Returns a bit vector with one bit per dimension that
+/// is set when the dimension is sliced, i.e. when it is not known to be taken
+/// in full (offset = 0, size = dim, stride = 1). Note that this is
+/// conservative, as it cannot detect if a dynamic size corresponds to the full
+/// tensor dimension or not.
+llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
+                                         ArrayRef<Range> sliceParams);
+
+/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
+/// inspecting its reassociation indices. Returns one bit per result dimension,
+/// set when that dimension's reassociation group contains more than one source
+/// dimension.
+llvm::SmallBitVector
+getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
+
+/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
+/// chain and reified source and result shapes of the CollapseShapeOp, this
+/// class provides two functions that assist with directly forming the result
+/// of the extract slice by "tiling the CollapseShapeOp by 1".
+/// Example:
+// clang-format off
+/// ```
+/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
+/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... into tensor<231x10xf32>
+/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
+/// ```
+/// This class helps build the below IR to replace %2:
+/// ```
+/// %dest = linalg.init_tensor [10, 10] : tensor<10x10xf32>
+/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %dest) -> (tensor<10x10xf32>) {
+///    %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv)
+///    %3:3 = affine.delinearize_index %linear_index into (3, 7, 11)
+///
+///    // This function takes %3 (multiIndices) and the parameters for the slice below.
+///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
+///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
+///
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
+///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
+///          tensor<1x10xf32> into tensor<10x10xf32>
+///    scf.yield %6 : tensor<10x10xf32>
+/// }
+/// ```
+// clang-format on
+class SliceFromCollapseHelper {
+public:
+  SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
+                          ArrayRef<OpFoldResult> collapseShapeInputShape,
+                          ArrayRef<OpFoldResult> collapseShapeOutputShape,
+                          ArrayRef<Range> extractSliceParams)
+      : reassociationIndices(reassociationIndices),
+        collapseShapeInputShape(collapseShapeInputShape),
+        collapseShapeOutputShape(collapseShapeOutputShape),
+        sliceParams(extractSliceParams),
+        linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
+        slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
+                                             extractSliceParams)) {}
+
+  /// This function takes multi-indices and maps them to ExtractSlice parameters
+  /// in the index space of the CollapseShape's source tensor. This function's
+  /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
+  /// strides)` where `n` is the number of "tiled dimensions", which are the
+  /// dimensions of the output that are linearized by the collapse shape op and
+  /// are also sliced. Each `D_i` is a tuple that must represent a valid
+  /// multi-index for the `i-th` tiled dimension. In the example above, there is
+  /// only one tiled dimension (D_0) and `affine.delinearize_index` produces the
+  /// multi-index (%3) that would be passed to this function to generate the
+  /// parameters for the `tensor.extract_slice` op (%4).
+  SmallVector<Range> getExtractSliceParams(ArrayRef<ValueRange> multiIndices);
+
+  /// This function takes indices in the index space of the "tiled dimensions"
+  /// described above and returns a set of Range variables that describe how the
+  /// slice should be inserted into the destination. In the example above, `%iv`
+  /// would be passed to this function to generate the parameters for the
+  /// `tensor.insert_slice` op producing %6.
+  SmallVector<Range> getInsertSliceParams(ValueRange tileIndices);
+
+private:
+  SmallVector<ReassociationIndices> reassociationIndices;
+  SmallVector<OpFoldResult> collapseShapeInputShape;
+  SmallVector<OpFoldResult> collapseShapeOutputShape;
+  SmallVector<Range> sliceParams;
+  llvm::SmallBitVector linearizedDimensions;
+  llvm::SmallBitVector slicedDimensions;
+};
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 096ee94c0ba4..f290b1e8e8b3 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -29,6 +29,13 @@ struct Range {
OpFoldResult stride;
};
+/// Given an array of Range values, return a tuple of three vectors (offsets,
+/// sizes, strides) formed by separating out the `offset`, `size` and `stride`
+/// fields of each range. Entries may be static (attributes) or dynamic
+/// (values), as each field is an OpFoldResult.
+std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+ SmallVector<OpFoldResult>>
+getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
+
/// Helper function to dispatch an OpFoldResult into `staticVec` if:
/// a) it is an IntegerAttr
/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7b3bc79451a4..1960232b5f4e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1210,6 +1210,15 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
+/// Build an ExtractSliceOp from a vector of Range values, each of which may mix
+/// static (attribute) and dynamic (value) offset/size/stride entries. The
+/// result type is left null so that it is inferred.
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+                           ArrayRef<Range> ranges,
+                           ArrayRef<NamedAttribute> attrs) {
+  const auto unpacked = getOffsetsSizesAndStrides(ranges);
+  build(b, result, RankedTensorType(), source, std::get<0>(unpacked),
+        std::get<1>(unpacked), std::get<2>(unpacked), attrs);
+}
+
/// Build an ExtractSliceOp with dynamic entries and custom result type. If the
/// type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
@@ -1597,6 +1606,15 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
result.addAttributes(attrs);
}
+/// Build an InsertSliceOp that inserts `source` into `dest` at the position
+/// described by a vector of Range values; each range may mix static
+/// (attribute) and dynamic (value) offset/size/stride entries.
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+                          Value dest, ArrayRef<Range> ranges,
+                          ArrayRef<NamedAttribute> attrs) {
+  const auto unpacked = getOffsetsSizesAndStrides(ranges);
+  build(b, result, source, dest, std::get<0>(unpacked), std::get<1>(unpacked),
+        std::get<2>(unpacked), attrs);
+}
+
// Build a InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ValueRange offsets, ValueRange sizes,
@@ -2359,6 +2377,16 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
result.addAttributes(attrs);
}
+/// Build a ParallelInsertSliceOp that inserts `source` into `dest` at the
+/// position described by a vector of Range values; each range may mix static
+/// (attribute) and dynamic (value) offset/size/stride entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest,
+                                  ArrayRef<Range> ranges,
+                                  ArrayRef<NamedAttribute> attrs) {
+  const auto unpacked = getOffsetsSizesAndStrides(ranges);
+  build(b, result, source, dest, std::get<0>(unpacked), std::get<1>(unpacked),
+        std::get<2>(unpacked), attrs);
+}
+
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest, ValueRange offsets,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 66e4cc906f23..0b200e03226d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ ExtractSliceFromReshape.cpp
SplitPadding.cpp
SwapExtractSliceWithProducer.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
new file mode 100644
index 000000000000..4acd5482e823
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
@@ -0,0 +1,179 @@
+//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that replace slices of reshape results with
+// aggregated slices of the reshape source.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Get the size of dimension `dimIdx` of `rankedTensor`, whose type must be a
+/// RankedTensorType. A static size is returned as an index attribute; a
+/// dynamic size is materialized with a `tensor.dim` op (folded when possible).
+/// File-local helper: given internal linkage so it cannot clash with other
+/// definitions of the same name in the global namespace.
+static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
+                                    Value rankedTensor, int64_t dimIdx) {
+  auto tensorType = rankedTensor.getType().cast<RankedTensorType>();
+  if (!tensorType.isDynamicDim(dimIdx))
+    return b.getIndexAttr(tensorType.getDimSize(dimIdx));
+  Value idxValue = b.create<arith::ConstantIndexOp>(loc, dimIdx);
+  return b.createOrFold<tensor::DimOp>(loc, rankedTensor, idxValue);
+}
+
+/// Collect the size of every dimension of `rankedTensor` (whose type must be a
+/// RankedTensorType), in dimension order, via `getShapeDimSize`.
+static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
+                                                  Value rankedTensor) {
+  const int64_t rank =
+      rankedTensor.getType().cast<RankedTensorType>().getRank();
+  SmallVector<OpFoldResult> sizes;
+  for (int64_t dim = 0; dim < rank; ++dim)
+    sizes.push_back(getShapeDimSize(b, loc, rankedTensor, dim));
+  return sizes;
+}
+
+/// A (dimension number, index value) tuple: the value is an index along the
+/// numbered dimension.
+using DimAndIndex = std::tuple<unsigned, Value>;
+
+/// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
+/// slice described by `sliceParams` into the input index space by emitting
+/// `offset + index * stride` for that dimension. The dimension number is passed
+/// through unchanged.
+static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
+                                       ArrayRef<Range> sliceParams,
+                                       const DimAndIndex &dimAndIndex) {
+  MLIRContext *ctx = b.getContext();
+  AffineExpr iv, offset, stride;
+  bindDims(ctx, iv);
+  bindSymbols(ctx, offset, stride);
+
+  unsigned dim = std::get<0>(dimAndIndex);
+  Value indexValue = std::get<1>(dimAndIndex);
+  assert(dim < sliceParams.size() && "slice should be non rank-reducing");
+
+  // Offset is materialized before stride, matching operand order below.
+  const Range &range = sliceParams[dim];
+  Value offsetValue = getValueOrCreateConstantIndexOp(b, loc, range.offset);
+  Value strideValue = getValueOrCreateConstantIndexOp(b, loc, range.stride);
+  Value sourceIndex = makeComposedAffineApply(
+      b, loc, offset + iv * stride, {indexValue, offsetValue, strideValue});
+  return std::make_tuple(dim, sourceIndex);
+}
+
+/// Transform `dimAndIndex` from the result tensor index space of a
+/// CollapseShapeOp to the source tensor index space. An
+/// `affine.delinearize_index` op is emitted whose basis is the list of source
+/// dimension sizes collapsed into dimension `dim`; its results (one per source
+/// dimension in that reassociation group) are returned.
+static ValueRange invertCollapseShapeIndexing(
+    OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
+  unsigned dim = std::get<0>(dimAndIndex);
+  Value linearIndex = std::get<1>(dimAndIndex);
+  SmallVector<OpFoldResult> basis = llvm::to_vector(llvm::map_range(
+      reassociation[dim],
+      [&](int64_t srcDim) { return reshapeSourceShape[srcDim]; }));
+  return b.create<AffineDelinearizeIndexOp>(loc, linearIndex, basis)
+      ->getResults();
+}
+
+/// Convenience overload: builds one Range per source dimension from the mixed
+/// offsets, sizes and strides of `extractOp` and forwards to the Range-based
+/// `create`. Returns failure if `extractOp` does not slice the result of
+/// `collapseOp`, or if `extractOp` is rank-reducing.
+FailureOr<ExtractSliceFromCollapseHelper>
+tensor::ExtractSliceFromCollapseHelper::create(
+    OpBuilder &b, tensor::CollapseShapeOp collapseOp,
+    tensor::ExtractSliceOp extractOp) {
+  if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
+      collapseOp)
+    return failure();
+  // The helper only produces non-rank-reduced tiles; reject rank-reducing
+  // slices up front instead of producing tiles of the wrong rank.
+  if (extractOp.getType().getRank() != extractOp.getSourceType().getRank())
+    return failure();
+  SmallVector<Range> ranges;
+  ranges.reserve(extractOp.getSourceType().getRank());
+  for (const auto &[o, s, st] :
+       llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
+                 extractOp.getMixedStrides()))
+    ranges.push_back({o, s, st});
+  return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
+}
+
+/// Reifies the result shape of `op`, materializes its source shape, classifies
+/// each result dimension as sliced and/or linearized, and records the slice
+/// sizes of the dimensions that are both (the loop-nest upper bounds).
+/// Returns failure before emitting any IR if `sliceParams` does not contain
+/// exactly one range per result dimension of `op`. Also returns failure if `op`
+/// does not implement ReifyRankedShapedTypeOpInterface or its result shape
+/// cannot be reified.
+FailureOr<ExtractSliceFromCollapseHelper>
+tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
+                                               tensor::CollapseShapeOp op,
+                                               ArrayRef<Range> sliceParams) {
+  SmallVector<ReassociationIndices> reassociationIndices =
+      op.getReassociationIndices();
+  // Only non rank-reducing slices are supported: one range per result dim.
+  // Checking here avoids the assertion in getSlicedDimensions.
+  if (sliceParams.size() != reassociationIndices.size())
+    return failure();
+
+  // Materialize the output shape of the collapse_shape operation. This will
+  // create IR describing the output shape in terms of the input shape.
+  ReifiedRankedShapedTypeDims reifiedShapes;
+  auto reifyShapedTypeInterface =
+      dyn_cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
+  if (!reifyShapedTypeInterface ||
+      failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
+    return failure();
+  SmallVector<OpFoldResult> collapseShapeOutputShape =
+      getAsOpFoldResult(reifiedShapes[0]);
+
+  // Determine which of the CollapseShapeOp's result dimensions are sliced
+  // and/or linearized.
+  llvm::SmallBitVector linearizedDimensions =
+      getLinearizedDimensions(reassociationIndices);
+  llvm::SmallBitVector slicedDimensions =
+      getSlicedDimensions(collapseShapeOutputShape, sliceParams);
+
+  // Materialize the source shape exactly once; it is later used as the basis
+  // for delinearizing indices. (Materializing it twice would emit dead IR.)
+  SmallVector<OpFoldResult> collapseShapeInputShape =
+      getShapeDimSizes(b, op.getLoc(), op.getSrc());
+
+  // Only dimensions that are both linearized and sliced need to be iterated;
+  // all other dimensions are covered directly by each sub-tile slice.
+  SmallVector<Value> tileSizes;
+  for (unsigned i = 0, e = sliceParams.size(); i < e; ++i) {
+    if (slicedDimensions[i] && linearizedDimensions[i])
+      tileSizes.push_back(
+          getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
+  }
+
+  return ExtractSliceFromCollapseHelper(
+      op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
+      linearizedDimensions, slicedDimensions, tileSizes);
+}
+
+/// Emits, at the builder's insertion point, the IR for one point of the
+/// caller's loop nest. Each induction variable is mapped back through the
+/// extract_slice (offset/stride) and the collapse_shape (delinearization) to a
+/// multi-index into the collapse_shape source. A sub-tile is extracted there
+/// and collapsed back to the result rank. Returns the collapsed sub-tile and
+/// the ranges at which the caller should insert it into the destination.
+/// `tileInductionVars` must hold one value per entry of
+/// `getIterationSpaceSizes()`, in the same order.
+std::pair<Value, SmallVector<Range>>
+tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
+    OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
+  assert(tileInductionVars.size() == tiledSizes.size() &&
+         "expected one induction variable per tiled dimension");
+
+  // Create the helper class for forming the slice parameters.
+  const SmallVector<ReassociationIndices> reassociationIndices =
+      collapseShapeOp.getReassociationIndices();
+  SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
+                                 collapseShapeOutputShape, sliceParams);
+
+  // For each tiled dim (linearized by the collapse_shape and sliced by the
+  // extract_slice), invert the slice and collapse index space transformations.
+  SmallVector<ValueRange> multiIndices;
+  unsigned loopIdx = 0;
+  for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
+    if (linearizedDimensions[i] && slicedDimensions[i]) {
+      DimAndIndex tb =
+          invertSliceIndexing(builder, loc, sliceParams,
+                              std::make_tuple(i, tileInductionVars[loopIdx++]));
+      multiIndices.push_back(invertCollapseShapeIndexing(
+          builder, loc, reassociationIndices, collapseShapeInputShape, tb));
+    }
+  }
+
+  // NOTE(review): SliceFromCollapseHelper::getExtractSliceParams asserts that
+  // `multiIndices` is non-empty, so this path requires at least one dimension
+  // that is both linearized and sliced.
+  SmallVector<Range> extractParams = helper.getExtractSliceParams(multiIndices);
+
+  Value subTileResult = builder.create<tensor::ExtractSliceOp>(
+      loc, collapseShapeOp.getSrc(), extractParams);
+
+  SmallVector<Range> insertParams =
+      helper.getInsertSliceParams(tileInductionVars);
+
+  // Collapse the dimensions of the source slice back down.
+  Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
+      loc, subTileResult, reassociationIndices);
+  return std::make_pair(collapsedResult, insertParams);
+}
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index b26b3d93541a..7f5b63814e69 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -270,3 +270,88 @@ bool mlir::hasNonIdentityLayout(Type type) {
return !memrefType.getLayout().isIdentity();
return false;
}
+
+llvm::SmallBitVector
+mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
+ ArrayRef<Range> sliceParams) {
+ assert(sliceParams.size() == sliceInputShape.size() &&
+ "only supports non rank-reducing case");
+ // A dimension counts as sliced unless the slice provably covers all of it:
+ // offset 0, unit stride, and a size equal to the input extent. Anything that
+ // is not a known constant is conservatively treated as sliced.
+ llvm::SmallBitVector sliced(sliceInputShape.size());
+ for (unsigned dim = 0, rank = sliceParams.size(); dim < rank; ++dim) {
+ const Range &range = sliceParams[dim];
+ Optional<int64_t> offsetConst = getConstantIntValue(range.offset);
+ Optional<int64_t> strideConst = getConstantIntValue(range.stride);
+ const bool coversWholeDim =
+ isEqualConstantIntOrValue(range.size, sliceInputShape[dim]) &&
+ strideConst && *strideConst == 1 && offsetConst && *offsetConst == 0;
+ sliced[dim] = !coversWholeDim;
+ }
+ return sliced;
+}
+
+llvm::SmallBitVector mlir::getLinearizedDimensions(
+ ArrayRef<ReassociationIndices> reassociationIndices) {
+ // A result dimension is "linearized" when its reassociation group folds
+ // together more than one source dimension.
+ const unsigned numResultDims = reassociationIndices.size();
+ llvm::SmallBitVector linearized(numResultDims);
+ for (unsigned dim = 0; dim < numResultDims; ++dim)
+ if (reassociationIndices[dim].size() > 1)
+ linearized.set(dim);
+ return linearized;
+}
+
+/// Returns the context owning `ofr`, whether it holds an Attribute or a Value.
+static MLIRContext *getOpFoldResultContext(OpFoldResult ofr) {
+ if (ofr.is<Attribute>())
+ return ofr.get<Attribute>().getContext();
+ return ofr.get<Value>().getContext();
+}
+
+/// Builds the offsets/sizes/strides of the slice of the collapse_shape source
+/// that produces one tile. `multiIndices` holds one de-linearized multi-index
+/// per result dimension that is both linearized and sliced, in dimension
+/// order; it is empty when no such dimension exists.
+SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
+ ArrayRef<ValueRange> multiIndices) {
+ SmallVector<Range> offsetsSizesAndStrides;
+ // There is one slice parameter per collapse_shape result dimension; with
+ // none there is nothing to build.
+ if (sliceParams.empty())
+ return offsetsSizesAndStrides;
+
+ // Take the context from the slice parameters. `multiIndices` cannot be used
+ // for this: it is empty when the extract_slice only slices dimensions that
+ // the collapse_shape leaves untouched (the loop nest then has no induction
+ // variables at all).
+ MLIRContext *ctx = getOpFoldResultContext(sliceParams.front().offset);
+ auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
+ auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
+ offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
+ unsigned loopIdx = 0;
+ for (const auto &it : llvm::enumerate(reassociationIndices)) {
+ // Case 1: Linearized dimensions that have also been sliced. These
+ // are size of 1 because we are iterating over these dimensions. The
+ // offsets are exactly the de-linearized multi-indices.
+ if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
+ assert(loopIdx < multiIndices.size() &&
+ "expected one multi-index per linearized and sliced dimension");
+ llvm::append_range(
+ offsetsSizesAndStrides,
+ llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
+ return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
+ }));
+ continue;
+ }
+
+ // Case 2: Multiple combined input dimensions that are not sliced (case 1
+ // did not apply). Take the full extent of each input dimension in the
+ // reassociation group.
+ if (linearizedDimensions[it.index()]) {
+ llvm::append_range(
+ offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
+ }));
+ continue;
+ }
+
+ // Case 3: A single input dimension; it may be sliced, so forward the
+ // original slice parameters unchanged.
+ offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
+ }
+ return offsetsSizesAndStrides;
+}
+
+/// Builds the offsets/sizes/strides at which a collapsed tile is inserted into
+/// the result of the original extract_slice. `tileIndices` holds one induction
+/// variable per result dimension that is both linearized and sliced; it is
+/// empty when no such dimension exists.
+SmallVector<Range>
+SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) {
+ SmallVector<Range> insertParams;
+ if (sliceParams.empty())
+ return insertParams;
+
+ // `tileIndices` may be empty, so take the context from the slice parameters.
+ MLIRContext *ctx = getOpFoldResultContext(sliceParams.front().size);
+ auto one = IntegerAttr::get(IndexType::get(ctx), 1);
+ auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
+ insertParams.reserve(linearizedDimensions.size());
+ unsigned loopIdx = 0;
+ for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
+ // Tiled dimensions: a unit-size slot at the current induction variable.
+ if (linearizedDimensions[i] && slicedDimensions[i]) {
+ assert(loopIdx < tileIndices.size() &&
+ "expected one tile index per linearized and sliced dimension");
+ insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
+ continue;
+ }
+ // Untiled dimensions: the whole extent of the slice result.
+ insertParams.push_back(Range{zero, sliceParams[i].size, one});
+ }
+ return insertParams;
+}
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 80e93553858f..6212df931144 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -13,6 +13,21 @@
namespace mlir {
+std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+ SmallVector<OpFoldResult>>
+getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
+ // Unzip the ranges into three parallel vectors, preserving their order.
+ const size_t numRanges = ranges.size();
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ offsets.reserve(numRanges);
+ sizes.reserve(numRanges);
+ strides.reserve(numRanges);
+ for (const Range &range : ranges) {
+ offsets.push_back(range.offset);
+ sizes.push_back(range.size);
+ strides.push_back(range.stride);
+ }
+ return {offsets, sizes, strides};
+}
+
/// Helper function to dispatch an OpFoldResult into `staticVec` if:
/// a) it is an IntegerAttr
/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
new file mode 100644
index 000000000000..d8ca129bf59a
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH
+
+// The slice takes rows [0, 20) of the collapsed (3*5*7 = 105)-row dimension.
+func.func @extract_slice_static(%src: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> {
+ %flat = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
+ %rows = tensor.extract_slice %flat [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32>
+ return %rows : tensor<20x11xf32>
+}
+
+// CHECK: func.func @extract_slice_static(%[[arg0:.+]]:
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
+// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: return %[[tile]]
+
+// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
+// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index
+// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
+// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
+// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
+// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
+// FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]])
+// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
+// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
+// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+// FOREACH: perform_concurrently
+// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[iv]], 0] [1, 11] [1, 1] :
+// FOREACH: return %[[tile]]
+
+// -----
+
+
+// Strided slice: rows 13, 15, ..., 31 of the collapsed dimension and columns 0, 2, ..., 8.
+func.func @extract_slice_static_strided(%src: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> {
+ %flat = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
+ %tile = tensor.extract_slice %flat [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32>
+ return %tile : tensor<10x5xf32>
+}
+
+// CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 13)>
+// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]:
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
+// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] :
+// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+// CHECK: %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]])
+// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]]
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: return %[[tile]]
+
+
+// -----
+
+
+// Dynamic offset and size along a collapsed dimension that has dynamic source extents.
+func.func @extract_slice_dynamic(%src: tensor<3x?x?x11xf32>, %offset: index, %len: index) -> tensor<?x5xf32> {
+ %flat = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor<?x11xf32>
+ %tile = tensor.extract_slice %flat [%offset, 0] [%len, 5] [2, 2] : tensor<?x11xf32> to tensor<?x5xf32>
+ return %tile : tensor<?x5xf32>
+}
+
+// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor<?x5xf32>
+// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32>
+// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32>
+// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+// CHECK: %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]]
+// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) :
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: return %[[tile]] :
+
+// -----
+
+
+// Both result dimensions are linearized and sliced, giving a 2-d iteration space.
+func.func @extract_slice_dynamic_multidim(%src: tensor<3x?x?x11x?xf32>, %offset0: index, %len0: index, %offset1: index, %len1: index) -> tensor<?x?xf32> {
+ %flat = tensor.collapse_shape %src [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
+ %tile = tensor.extract_slice %flat [%offset0, %offset1] [%len0, %len1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %tile : tensor<?x?xf32>
+}
+
+// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
+// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
+// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
+// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
+// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
+// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
+// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
+// CHECK: %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]]
+// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
+// CHECK: %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]]
+// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) :
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: scf.yield %[[tile2]] :
+// CHECK: return %[[tile1]] :
+
+// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
+// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index
+// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index
+// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
+// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index
+// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index
+// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
+// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
+// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
+// FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
+// FOREACH: %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]])
+// FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]]
+// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
+// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]]
+// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) :
+// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
+// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
+// FOREACH: perform_concurrently
+//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] :
+
+// -----
+
+// Verifies that a linearized dimension that is not sliced does not generate a loop. This requires
+// the extent of that dimension to be static (here 11x2 = 22), so that the slice can be proven to
+// cover it entirely.
+
+// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>,
+func.func @extract_slice_non_sliced_linearized_dim(%src: tensor<3x?x?x11x2xf32>, %offset: index, %len: index) -> tensor<?x22xf32> {
+ %flat = tensor.collapse_shape %src [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor<?x22xf32>
+ %rows = tensor.extract_slice %flat [%offset, 0] [%len, 22] [1, 1] : tensor<?x22xf32> to tensor<?x22xf32>
+ // CHECK: scf.for
+ // CHECK-NOT: scf.for
+ // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index
+ // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
+ return %rows : tensor<?x22xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
index 89c996da7b2a..56e59820677e 100644
--- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTensorTestPasses
LINK_LIBS PUBLIC
MLIRArithmeticDialect
+ MLIRLinalgDialect
MLIRPass
MLIRSCFDialect
MLIRTensorDialect
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 4c38ad1d2dda..f5a7f984ab0a 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -11,8 +11,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -28,7 +30,8 @@ struct TestTensorTransforms
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+ registry.insert<arith::ArithmeticDialect, scf::SCFDialect,
+ linalg::LinalgDialect>();
}
StringRef getArgument() const final {
@@ -49,6 +52,19 @@ struct TestTensorTransforms
*this, "test-fold-constant-extract-slice",
llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
llvm::cl::init(false)};
+
+ Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
+ *this, "test-rewrite-extract-slice-from-collapse-shape",
+ llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
+ "with loop nest"),
+ llvm::cl::init(false)};
+
+ Option<bool> useForeach{
+ *this, "use-foreach",
+ llvm::cl::desc(
+ "Use the scf.foreach_thread operation when generating loop nests for "
+ "the extract_slice of collapse_shape pattern"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -74,12 +90,142 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+namespace {
+/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
+/// The `tensor.extract_slice` is replaced by a loop or gather operation that
+/// stitches together the desired tile from slices of the source of the collapse
+/// shape op. Subclasses only choose how the iteration space is materialized
+/// (see `emitReplacement`).
+struct RewriteExtractSliceFromCollapseShapeBase
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
+ : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}
+
+ /// Emit a loop or gather operation that uses `helper` to take each point in
+ /// the parallel iteration space bounds, extract a slice from the source
+ /// tensor and insert it into `dest`. For examples, see the `scf.for` and
+ /// `scf.foreach_thread` implementations below.
+ virtual LogicalResult
+ emitReplacement(tensor::ExtractSliceOp op, Value dest,
+ tensor::ExtractSliceFromCollapseHelper &helper,
+ PatternRewriter &rewriter) const = 0;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
+ PatternRewriter &rewriter) const override {
+ // Only fire when the sliced tensor is produced by a collapse_shape.
+ auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!collapseOp)
+ return rewriter.notifyMatchFailure(
+ op, "producer is not a tensor.collapse_shape op");
+
+ // Materialize the output shape values of the slice operation.
+ ReifiedRankedShapedTypeDims reifiedShapes;
+ if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
+ return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
+
+ // Create the destination tensor using the above values.
+ Type elementType = op.getSourceType().getElementType();
+ SmallVector<OpFoldResult> outputShape = getAsOpFoldResult(reifiedShapes[0]);
+ Value dest = rewriter.create<linalg::InitTensorOp>(
+ op->getLoc(), outputShape, elementType);
+
+ // Calculate the parameters for the tile loop nest.
+ // NOTE(review): `dest` has already been created when `create` can still
+ // fail, so on that path the pattern reports failure after modifying the IR.
+ // Calling `create` first looks safer, but it would likely change the order
+ // of emitted ops that the FileCheck tests match (they expect
+ // `linalg.init_tensor` before the `tensor.dim` ops) -- confirm before
+ // reordering.
+ FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
+ tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
+ op);
+ if (failed(params))
+ return rewriter.notifyMatchFailure(
+ op, "could not calculate tiling parameters");
+ return emitReplacement(op, dest, *params, rewriter);
+ }
+};
+
+/// Materializes the tile iteration space as an `scf.for` nest (one loop per
+/// entry of `helper.getIterationSpaceSizes()`, lower bound 0, step 1) that
+/// threads `dest` through `iter_args` and fills it with `tensor.insert_slice`.
+struct RewriteExtractSliceFromCollapseShapeUsingScfFor
+ : public RewriteExtractSliceFromCollapseShapeBase {
+ RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
+ : RewriteExtractSliceFromCollapseShapeBase(context) {}
+ LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
+ tensor::ExtractSliceFromCollapseHelper &helper,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ SmallVector<Value> lbs(numTiledDims, zero);
+ SmallVector<Value> steps(numTiledDims, one);
+ scf::LoopNest nest = scf::buildLoopNest(
+ rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
+ [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+ ValueRange iterArgs) -> scf::ValueVector {
+ auto [tile, insertParams] =
+ helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
+
+ // Insert the slice into the destination.
+ Value result = nestedBuilder.create<tensor::InsertSliceOp>(
+ loc, tile, iterArgs[0], insertParams);
+ return {result};
+ });
+ // The outermost loop's result is the fully assembled tile.
+ rewriter.replaceOp(op, nest.getResults()[0]);
+ return success();
+ }
+};
+
+/// Materializes the tile iteration space as a single `scf.foreach_thread`
+/// whose `shared_outs` is `dest`; each thread writes its piece with
+/// `tensor.parallel_insert_slice` inside the `perform_concurrently` terminator.
+struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
+ : public RewriteExtractSliceFromCollapseShapeBase {
+ RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
+ : RewriteExtractSliceFromCollapseShapeBase(context) {}
+ LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
+ tensor::ExtractSliceFromCollapseHelper &helper,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
+ loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
+ /*threadDimMapping=*/ArrayRef<int64_t>{},
+ [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
+ // The region arguments are the thread ids (one per tiled dimension)
+ // followed by the shared output arguments.
+ unsigned numThreadIdRegionArgs =
+ helper.getIterationSpaceSizes().size();
+ unsigned numOutputRegionArgs =
+ regionArgs.size() - numThreadIdRegionArgs;
+ ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
+ ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
+ assert(outputArgs.size() == 1 &&
+ "there should only be one output region argument");
+ auto [tile, insertParams] =
+ helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
+ // Insert the slice into the destination.
+ auto term = nestedBuilder.create<scf::PerformConcurrentlyOp>(loc);
+ nestedBuilder.setInsertionPointToStart(term.getBody());
+ nestedBuilder.create<tensor::ParallelInsertSliceOp>(
+ loc, tile, outputArgs[0], insertParams);
+ });
+ rewriter.replaceOp(op, foreachOp->getResult(0));
+ return success();
+ }
+};
+} // namespace
+
+/// Registers the extract_slice-of-collapse_shape rewrite, materialized with
+/// `scf.foreach_thread` when `useForeach` is set and with `scf.for` otherwise,
+/// and applies it greedily to `rootOp`. Returns the greedy driver's result.
+static LogicalResult
+applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
+ bool useForeach) {
+ MLIRContext *ctx = rootOp->getContext();
+ RewritePatternSet patterns(ctx);
+ if (!useForeach) {
+ patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(ctx);
+ } else {
+ patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(ctx);
+ }
+ return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
void TestTensorTransforms::runOnOperation() {
Operation *rootOp = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(rootOp);
if (testFoldConstantExtractSlice)
applyFoldConstantExtractSlicePatterns(rootOp);
+ // A failure reported by the greedy rewrite driver fails the pass.
+ if (testRewriteExtractSliceWithTiledCollapseShape &&
+ failed(applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
+ return signalPassFailure();
}
namespace mlir {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 12e0650d0a1b..9bc56617e18e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5061,11 +5061,13 @@ cc_library(
"include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/Tensor/Transforms/Passes.h",
"include/mlir/Dialect/Tensor/Transforms/Transforms.h",
+ "include/mlir/Dialect/Tensor/Transforms/TransformUtils.h"
],
includes = ["include"],
deps = [
":AffineDialect",
":ArithmeticDialect",
+ ":ArithmeticUtils",
":BufferizationDialect",
":BufferizationTransforms",
":DialectUtils",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index e3429a493c6f..11ae74d3d224 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -620,6 +620,7 @@ cc_library(
includes = ["lib/Dialect/Test"],
deps = [
"//mlir:ArithmeticDialect",
+ "//mlir:LinalgDialect",
"//mlir:Pass",
"//mlir:SCFDialect",
"//mlir:TensorDialect",
More information about the Mlir-commits
mailing list