[Mlir-commits] [mlir] 0b1aee3 - Revert "[mlir][Tensor] Add rewrites to extract slices through `tensor.collape_shape`"
Mehdi Amini
llvmlistbot at llvm.org
Fri Sep 2 16:36:54 PDT 2022
Author: Mehdi Amini
Date: 2022-09-02T23:34:52Z
New Revision: 0b1aee38bd2ffbf9f4f70e95c7cb18031d282544
URL: https://github.com/llvm/llvm-project/commit/0b1aee38bd2ffbf9f4f70e95c7cb18031d282544
DIFF: https://github.com/llvm/llvm-project/commit/0b1aee38bd2ffbf9f4f70e95c7cb18031d282544.diff
LOG: Revert "[mlir][Tensor] Add rewrites to extract slices through `tensor.collape_shape`"
This reverts commit 5711957875738c1318f89afd7bf4be388f85a087.
This introduces a circular dependency: Dialect/Utils/ now depends on
ViewLikeInterface, but ViewLikeInterface already depends on Dialect/Utils.
It also introduces a dependency from lib/Dialect/Tensor to Linalg,
which isn't obviously correct from a layering point of view.
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Utils/CMakeLists.txt
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/lib/Dialect/Tensor/CMakeLists.txt
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2a47f9bdbb116..9a0bbf690cf02 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -283,11 +283,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
// Build an ExtractSliceOp with dynamic entries and inferred result type.
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
- // Build an ExtractSliceOp with mixed static and dynamic entries packed in
- // a Range vector.
- OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -605,11 +601,6 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
// Build a InsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
- // Build an InsertSliceOp with mixed static and dynamic entries packed in
- // a Range vector.
- OpBuilder<(ins "Value":$source, "Value":$dest,
- "ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
@@ -1208,11 +1199,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
- // Build a ParallelInsertSliceOp with mixed static and dynamic entries
- // packed into a Range vector.
- OpBuilder<(ins "Value":$source, "Value":$dest,
- "ArrayRef<Range>":$ranges,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+
// Build a ParallelInsertSliceOp with dynamic entries.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
deleted file mode 100644
index 2ca556275af12..0000000000000
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
+++ /dev/null
@@ -1,210 +0,0 @@
-//===- TransformsUtils.h - Tensor Transformation Utilities-------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
-#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
-
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-namespace tensor {
-
-//===----------------------------------------------------------------------===//
-// Extract slice from `tensor.collapse_shape`
-//===----------------------------------------------------------------------===//
-
-/// This class assists with generating IR required to materialize an
-/// arbitrary-sized slice from the result of a CollapseShapeOp. In order to
-/// accomplish this, a loop nest or similar operation must be created by the
-/// caller. The purpose of the loop nest is to generate a "tiling by 1" of all
-/// sliced dimensions. The "tiling by 1" assembles all elements of the result
-/// tile over dimensions that would have been impossible to directly slice.
-///
-/// The class provides three methods:
-/// 1. `ExtractSliceFromCollapseHelper::create`: emits IR that should
-/// appear before the loop nest and populates the internal state.
-/// 2. `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`: returns
-/// parameters used by the caller to construct the loop nest.
-/// 3. `ExtractSliceFromCollapseHelper::emitLoopNestBody`:
-/// emits IR to construct a "size-1 tile" of the desired result and returns a
-/// set of ranges where the tile should be inserted into the destination
-/// tensor.
-///
-/// ### Intended usage:
-///
-/// The caller should first call `ExtractSliceFromCollapseHelper::create` and
-/// then create a destination tensor that is the same size as the desired slice.
-/// The caller then creates a loop nest that iterates over the multi-dimensional
-/// iteration space defined by `[0, ub[0]) x [0, ub[1]] x ... x [0, ub[N-1]]`
-/// where `ub` is the upper bound given by
-/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. Inside the body of
-/// the loop nest, the caller should call
-/// `ExtractSliceFromCollapseHelper::emitLoopNestBody` and provide the induction
-/// variables. This returns a sub-tile and a set of ranges that describe where
-/// this tile should be inserted into the result by the caller. For a complete
-/// example of usage, see the examples in the TestTensorTransforms pass.
-///
-/// ### Example:
-/// Consider the following IR:
-/// ```
-/// %0 = linalg.generic ... -> tensor<3x?x?x11x?xf32>
-/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]]
-/// : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
-/// %2 = tensor.extract_slice %1 [%offt0, %offt1][%size0, %size1][1, 1]
-/// : tensor<?x?xf32> to tensor<?x?xf32>
-/// ```
-///
-/// We can construct %2 by generating the following, which only uses `%0`:
-///
-/// ```
-/// %dest = linalg.init_tensor [%size0, %size1] : tensor<?x?xf32>
-/// %1 = tensor.dim %0, %c1 : tensor<3x?x?x11x?xf32>
-/// %2 = tensor.dim %0, %c2 : tensor<3x?x?x11x?xf32>
-/// %3 = tensor.dim %0, %c4 : tensor<3x?x?x11x?xf32>
-///
-/// %result = scf.for %iv0 = %c0 to %arg2 step %c1 iter_args(%arg6 = %dest) ->
-/// (tensor<?x?xf32>) {
-/// %5 = scf.for %iv1 = %c0 to %arg4 step %c1 iter_args(%arg8 = %arg6)
-/// -> (tensor<?x?xf32>) {
-/// %lin0 = (affine.apply) %iv0 + %offt0
-/// %lin1 = (affine.apply) %iv1 + %offt1
-///
-/// %mi0:3 = affine.delinearize_index %lin0 into (%c3, %1, %2)
-/// %mi1:2 = affine.delinearize_index %lin1 into (%c11, %3)
-///
-/// %sub_tile = tensor.extract_slice %0
-/// [%mi0#0, %mi0#1, %mi0#2, %mi1#0, %mi1#1]
-/// [1, 1, 1, 1, 1]
-/// [1, 1, 1, 1, 1]
-/// : tensor<3x?x?x11x?xf32> to tensor<1x1x1x1x1xf32>
-/// %sub_tile_collapsed = tensor.collapse_shape %sub_tile
-/// [[0, 1, 2], [3, 4]]
-/// : tensor<1x1x1x1x1xf32> into tensor<1x1xf3
-///
-/// %12 = tensor.insert_slice %sub_tile_collapsed into
-/// %arg8[%iv0, %iv1] [1, 1] [1, 1]
-/// : tensor<1x1xf32> into tensor<?x?xf32>
-/// scf.yield %12 : tensor<?x?xf32>
-/// }
-/// scf.yield %5 : tensor<?x?xf32>
-/// }
-/// ```
-///
-/// ### Explanation of example:
-///
-/// Each step above is explained below.
-///
-/// #### Step 0: Create %dest and materialization of shapes.
-/// This step is self-explanatory and performed by the caller. It can be
-/// done before or after calling `ExtractSliceFromCollapseHelper::create`,
-/// which materializes the source shape (`%0, %1, %2`).
-///
-/// #### Step 1: Create loop nest.
-///
-/// The caller creates the loop nest (depicted here is `scf.for`, but any other
-/// similar op can be used). The iteration should start at zero and proceed with
-/// step size 1 to the upper bounds given by
-/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. This forms the
-/// basis for the "tiling by 1".
-///
-/// #### Step 2: Transform (%iv0, %iv1) from the index space of %3 to the index
-/// space of %0.
-///
-/// This step is performed by
-/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
-///
-/// The induction variables `%iv0` and `%iv1` live in the
-/// index space of %2 (for dimensions 0 and 1, respectively). `%lin0` and
-/// `%lin1` are the result of inverting or resolve the index space
-/// transformation represented by the slice operation, accounting for offset and
-/// stride. Subsequently, `%mi0` and `%mi1` are the result of applying the
-/// inverse index space transformation represented by `tensor.collapse_shape`.
-/// This is accomplished using `affine.delinearize_index`. Note that %iv0
-/// and %iv1 now correspond to multi-indices `%mi0:3` and `%mi1:2`.
-///
-/// #### Step 3: Extract a sub-tile slice from the source.
-///
-/// This step is also performed by
-/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
-///
-/// The indices `%mi0` and `%mi1` are used to extract a slice from %0. This
-/// slice is then collapsed down to match the result rank.
-///
-/// #### Step 4: Insert sub-tile into the destination
-///
-/// This step is performed by the caller using the results of
-/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`.
-///
-/// In the above example, the slice insertion parameters are straightforward,
-/// but in other possible situations, the slice parameters are more complicated,
-/// which is why this helper calculates them for the caller. These other
-/// situations correspond to:
-/// 1. The presence of linearized dimensions that are not sliced
-/// 2. The presence of non-linearized dimensions that are sliced.
-class ExtractSliceFromCollapseHelper {
-public:
- /// Given a CollapseShapeOp and a set of ranges describing the desired slice
- /// of its result, emits IR to materialize the shapes of the input and output
- /// tensors, and returns an instance of the initialized class. Returns failure
- /// if the slice is rank-reducing.
- static FailureOr<ExtractSliceFromCollapseHelper>
- create(OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef<Range> sliceParams);
-
- /// Given a CollapseShapeOp and an ExtractSliceOp acting on its result, emits
- /// IR to materialize the shapes of the input and output tensors of the
- /// CollapseShapeOp, and returns an instance of the initialized class. Returns
- /// failure if the slice is rank-reducing.
- static FailureOr<ExtractSliceFromCollapseHelper>
- create(OpBuilder &b, tensor::CollapseShapeOp collapseOp,
- tensor::ExtractSliceOp extractOp);
-
- ExtractSliceFromCollapseHelper(
- tensor::CollapseShapeOp collapseShapeOp,
- ArrayRef<OpFoldResult> collapseShapeInputShape,
- ArrayRef<OpFoldResult> collapseShapeOutputShape,
- ArrayRef<Range> extractSliceParams,
- const llvm::SmallBitVector &linearizedDimensions,
- const llvm::SmallBitVector &slicedDimensions, ArrayRef<Value> tiledSizes)
- : collapseShapeOp(collapseShapeOp),
- collapseShapeInputShape(collapseShapeInputShape),
- collapseShapeOutputShape(collapseShapeOutputShape),
- sliceParams(extractSliceParams),
- linearizedDimensions(linearizedDimensions),
- slicedDimensions(slicedDimensions), tiledSizes(tiledSizes) {}
-
- /// Return the upper bounds of the iteration space (with 0 offset and stride
- /// 1) required to create the desired slice. Note that this is not the same
- /// as the `sizes` parameters of the ExtractSliceOp because not all dimensions
- /// of the slice are required to be tiled to form the result.
- const SmallVector<Value> &getIterationSpaceSizes() { return tiledSizes; }
-
- /// Generates the IR inside of the caller's loop nest for 1) inverting the
- /// index mappings of the ExtractSliceOp->CollapseShapeOp chain and 2)
- /// extracting the CollapseShapeOp source tensor tile for this specified
- /// iteration space point `tileInductionVars` and 3) calculating where to
- /// insert the extracted tile. The returned pair consists of the results of
- /// (2) and (3) and should be used by the caller to insert into the
- /// destination tensor.
- std::pair<Value, SmallVector<Range>>
- emitLoopNestBody(OpBuilder &builder, Location loc,
- ValueRange tileInductionVars);
-
-private:
- tensor::CollapseShapeOp collapseShapeOp;
- SmallVector<OpFoldResult> collapseShapeInputShape;
- SmallVector<OpFoldResult> collapseShapeOutputShape;
- SmallVector<Range> sliceParams;
- llvm::SmallBitVector linearizedDimensions;
- llvm::SmallBitVector slicedDimensions;
- SmallVector<Value> tiledSizes;
-};
-
-} // namespace tensor
-} // namespace mlir
-
-#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 2196a64ac49e4..61f7aea9f5267 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -16,7 +16,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
@@ -374,90 +373,6 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
}
};
-/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
-/// non rank-reducing slice of the collapse_shape output. Try to find which
-/// dimensions have been sliced and which dimensions are not sliced (offset = 0,
-/// size = dim, size = 1). Note that this conservative as it cannot detect if a
-/// dynamic size corresponds to the full tensor dimension or not.
-llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
- ArrayRef<Range> sliceParams);
-
-/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
-/// inspecting its reassociation indices.
-llvm::SmallBitVector
-getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
-
-/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
-/// chain and reified source and result shapes of the CollapseShapeOp, this
-/// class provides two functions that assist with directly forming the result
-/// of the extract slice by "tiling the CollapseShapeOp by 1".
-//// Example:
-// clang-format off
-/// ```
-/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
-/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
-/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
-/// ```
-/// This class helps build the below IR to replace %2:
-/// ```
-/// %dest = linalg.init_tensor() : tensor<10x10xf32>
-/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
-/// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
-/// %3:3 = arith.delinearize_index %iv into (3, 7, 11)
-///
-/// // This function takes %3 (multiIndices) and the parameters for the slice below.
-/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
-/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
-///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
-/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
-/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
-/// tensor<1x10xf32> into tensor<10x10xf32>
-/// scf.yield %6 : tensor<10x10xf32>
-/// }
-/// ```
-// clang-format on
-class SliceFromCollapseHelper {
-public:
- SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
- ArrayRef<OpFoldResult> collapseShapeInputShape,
- ArrayRef<OpFoldResult> collapseShapeOutputShape,
- ArrayRef<Range> extractSliceParams)
- : reassociationIndices(reassociationIndices),
- collapseShapeInputShape(collapseShapeInputShape),
- collapseShapeOutputShape(collapseShapeOutputShape),
- sliceParams(extractSliceParams),
- linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
- slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
- extractSliceParams)) {}
-
- /// This function takes multi-indices and maps them to ExtractSlice parameters
- /// in the index space of the CollapseShape's source tensor. This function's
- /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
- /// strides)` where `n` the number of "tiled dimensions", which are the
- /// dimensions of the output that are linearized by the collapse shape op and
- /// are also sliced. Each `D_i` is a tuple that must represent a valid
- /// multi-index for the `i-th` tiled dimension. In the example above, there is
- /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
- /// multi-index (%3) that would be passed to this function to generate the
- /// parameters for the `tensor.extract_slice` op (%4).
- SmallVector<Range> getExtractSliceParams(ArrayRef<ValueRange> multiIndices);
-
- /// This function takes indices in the index space of the "tiled dimensions"
- /// described above and returns a set of Range variables that describe how the
- /// slice should be inserted into the destination. In the example above, `%iv`
- /// would be passed to this function to generate the parameters for the
- /// `tensor.insert_slice` op producing %6.
- SmallVector<Range> getInsertSliceParams(ValueRange tileIndices);
-
-private:
- SmallVector<ReassociationIndices> reassociationIndices;
- SmallVector<OpFoldResult> collapseShapeInputShape;
- SmallVector<OpFoldResult> collapseShapeOutputShape;
- SmallVector<Range> sliceParams;
- llvm::SmallBitVector linearizedDimensions;
- llvm::SmallBitVector slicedDimensions;
-};
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index ea50092ea5c9f..4bb13ce750607 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -30,13 +30,6 @@ struct Range {
OpFoldResult stride;
};
-/// Given an array of Range values, return a tuple of (offset vector, sizes
-/// vector, and strides vector) formed by separating out the individual elements
-/// of each range.
-std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
- SmallVector<OpFoldResult>>
-getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
-
/// Return a vector of OpFoldResults given the special value
/// that indicates whether of the value is dynamic or not.
SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fdbbcdabcd831..cd4c4b9f03a59 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1124,15 +1124,6 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
-/// Build an ExtractSliceOp with mixed static and dynamic entries packed into a
-/// Range vector.
-void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
- ArrayRef<Range> ranges,
- ArrayRef<NamedAttribute> attrs) {
- auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
- build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
-}
-
/// Build an ExtractSliceOp with dynamic entries and custom result type. If the
/// type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
@@ -1520,15 +1511,6 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
result.addAttributes(attrs);
}
-/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
-/// Range vector.
-void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
- Value dest, ArrayRef<Range> ranges,
- ArrayRef<NamedAttribute> attrs) {
- auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
- build(b, result, source, dest, offsets, sizes, strides, attrs);
-}
-
// Build a InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ValueRange offsets, ValueRange sizes,
@@ -2291,16 +2273,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
result.addAttributes(attrs);
}
-/// Build an ParallelInsertSliceOp with mixed static and dynamic entries packed
-/// into a Range vector.
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
- Value source, Value dest,
- ArrayRef<Range> ranges,
- ArrayRef<NamedAttribute> attrs) {
- auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
- build(b, result, source, dest, offsets, sizes, strides, attrs);
-}
-
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest, ValueRange offsets,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 0b200e03226da..66e4cc906f238 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
- ExtractSliceFromReshape.cpp
SplitPadding.cpp
SwapExtractSliceWithProducer.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
deleted file mode 100644
index 4809cc40c6610..0000000000000
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
+++ /dev/null
@@ -1,181 +0,0 @@
-//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements rewrites that replace slices of reshape results with
-// aggregated slices of the reshape source.
-//
-//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "llvm/ADT/STLExtras.h"
-
-using namespace mlir;
-using namespace mlir::tensor;
-
-/// Get the dimension size of a value of RankedTensor type at the
-OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor,
- int64_t dimIdx) {
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
- if (!tensorType.isDynamicDim(dimIdx)) {
- return b.getIndexAttr(tensorType.getDimSize(dimIdx));
- }
- Value idxValue = b.create<arith::ConstantIndexOp>(loc, dimIdx);
- return b.createOrFold<tensor::DimOp>(loc, rankedTensor, idxValue);
-}
-
-/// Get all the dimension sizes of a value of RankedTensor type.
-static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
- Value rankedTensor) {
- SmallVector<OpFoldResult> dimSizes;
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
- for (unsigned i = 0; i < tensorType.getRank(); i++)
- dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i));
- return dimSizes;
-}
-
-/// A tuple that represents (dimension number, dimension value).
-using DimAndIndex = std::tuple<unsigned, Value>;
-
-/// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
-/// slice described by `sliceParams` into the input index space.
-static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
- ArrayRef<Range> sliceParams,
- const DimAndIndex &dimAndIndex) {
- AffineExpr d0, s0, s1;
- bindDims(b.getContext(), d0);
- bindSymbols(b.getContext(), s0, s1);
- auto [dim, indexValue] = dimAndIndex;
- assert(dim < sliceParams.size() && "slice should be non rank-reducing");
- return std::make_pair(
- dim,
- makeComposedAffineApply(
- b, loc, s0 + d0 * s1,
- {indexValue,
- getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].offset),
- getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].stride)}));
-}
-
-/// Transform `dimAndIndex` from the result tensor index space of a
-/// CollapseShapeOp to the source tensor index space.
-static ValueRange invertCollapseShapeIndexing(
- OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
- const auto &[dim, indexValue] = dimAndIndex;
- SmallVector<OpFoldResult> basis;
- for (int64_t i : reassociation[dim])
- basis.push_back(reshapeSourceShape[i]);
- auto delinearized =
- b.create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
- return delinearized->getResults();
-}
-
-FailureOr<ExtractSliceFromCollapseHelper>
-tensor::ExtractSliceFromCollapseHelper::create(
- OpBuilder &b, tensor::CollapseShapeOp collapseOp,
- tensor::ExtractSliceOp extractOp) {
- if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
- collapseOp)
- return failure();
- SmallVector<Range> ranges;
- ranges.reserve(extractOp.getSourceType().getRank());
- for (const auto &[o, s, st] :
- llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
- extractOp.getMixedStrides())) {
- ranges.push_back({o, s, st});
- }
- return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
-}
-
-FailureOr<ExtractSliceFromCollapseHelper>
-tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
- tensor::CollapseShapeOp op,
- ArrayRef<Range> sliceParams) {
-
- // Materialize the output shape of the collapse_shape operation. This will
- // create IR describing the output shape in terms of the input shape.
- ReifiedRankedShapedTypeDims reifiedShapes;
- ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
- dyn_cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
- if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
- return failure();
- SmallVector<OpFoldResult> collapseShapeOutputShape =
- getAsOpFoldResult(reifiedShapes[0]);
- SmallVector<ReassociationIndices> reassociationIndices =
- op.getReassociationIndices();
-
- // Determine which of the CollapseShapeOp's result dimensions are sliced
- // and/or linearized.
- llvm::SmallBitVector linearizedDimensions =
- getLinearizedDimensions(reassociationIndices);
- llvm::SmallBitVector slicedDimensions =
- getSlicedDimensions(collapseShapeOutputShape, sliceParams);
-
- auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc());
-
- SmallVector<OpFoldResult> srcShape =
- getShapeDimSizes(b, op->getLoc(), op.getSrc());
-
- SmallVector<Value> tileSizes;
- for (unsigned i = 0; i < sliceParams.size(); i++) {
- if (slicedDimensions[i] && linearizedDimensions[i])
- tileSizes.push_back(
- getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
- }
-
- return ExtractSliceFromCollapseHelper(
- op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
- linearizedDimensions, slicedDimensions, tileSizes);
-}
-
-std::pair<Value, SmallVector<Range>>
-tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
- OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
- // Create the helper class for forming the slice parameters.
- const SmallVector<ReassociationIndices> reassociationIndices =
- collapseShapeOp.getReassociationIndices();
- SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
- collapseShapeOutputShape, sliceParams);
-
- // Get the indices of the tiled dims (linearized by the collapse_shape
- // and sliced by the extract_slice) invert the index spaces
- // transformations.
- SmallVector<ValueRange> multiIndices;
- unsigned loopIdx = 0;
- for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
- if (linearizedDimensions[i] && slicedDimensions[i]) {
- DimAndIndex tb =
- invertSliceIndexing(builder, loc, sliceParams,
- std::make_tuple(i, tileInductionVars[loopIdx++]));
- multiIndices.push_back(invertCollapseShapeIndexing(
- builder, loc, reassociationIndices, collapseShapeInputShape, tb));
- }
- }
-
- auto extractParams = helper.getExtractSliceParams(multiIndices);
-
- Value subTileResult = builder.create<tensor::ExtractSliceOp>(
- loc, collapseShapeOp.getSrc(), extractParams);
-
- SmallVector<Range> insertParams =
- helper.getInsertSliceParams(tileInductionVars);
-
- // Collapse the dimensions of the source slice back down.
- Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
- loc, subTileResult, reassociationIndices);
- return std::make_pair(collapsedResult, insertParams);
-}
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index 578c6a67be581..f329afa8fa755 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -6,5 +6,4 @@ add_mlir_library(MLIRDialectUtils
LINK_LIBS PUBLIC
MLIRIR
- MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index adc53dcb9743f..b26b3d93541a2 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -8,11 +8,8 @@
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Interfaces/ViewLikeInterface.h"
#include <numeric>
@@ -273,88 +270,3 @@ bool mlir::hasNonIdentityLayout(Type type) {
return !memrefType.getLayout().isIdentity();
return false;
}
-
-llvm::SmallBitVector
-mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
- ArrayRef<Range> sliceParams) {
- assert(sliceParams.size() == sliceInputShape.size() &&
- "only supports non rank-reducing case");
- llvm::SmallBitVector mask(sliceInputShape.size());
- unsigned idx = 0;
- for (const auto &[offset, size, stride] : sliceParams) {
- Optional<int64_t> offsetConst = getConstantIntValue(offset);
- Optional<int64_t> strideConst = getConstantIntValue(stride);
- mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
- (!strideConst || *strideConst != 1) ||
- (!offsetConst || *offsetConst != 0);
- idx++;
- }
- return mask;
-}
-
-llvm::SmallBitVector mlir::getLinearizedDimensions(
- ArrayRef<ReassociationIndices> reassociationIndices) {
- llvm::SmallBitVector result(reassociationIndices.size());
- for (const auto &it : llvm::enumerate(reassociationIndices))
- result[it.index()] = it.value().size() > 1;
- return result;
-}
-
-SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
- ArrayRef<ValueRange> multiIndices) {
- assert(!multiIndices.empty() && !multiIndices[0].empty() &&
- "multiIndices should not be empty");
- unsigned loopIdx = 0;
- MLIRContext *ctx = multiIndices[0][0].getContext();
- auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
- auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
- SmallVector<Range> offsetsSizesAndStrides;
- offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
- for (const auto &it : llvm::enumerate(reassociationIndices)) {
- // Case 1: Linearized dimensions that have also been sliced. These
- // are size of 1 because we are iterating over these dimensions. The
- // offsets are exactly the de-linearized multi-indices.
- if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
- return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
- }));
- continue;
- }
-
- // Case 2: One or possibly multiple combined input dimensions, but we
- // have proven that these are not sliced. In this case we just take
- // the full extent of each dimension in the reassociation list.
- if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
- continue;
- }
-
- // Case 3: A single index, but it may be sliced.
- offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
- }
- return offsetsSizesAndStrides;
-}
-
-SmallVector<Range>
-SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) {
- MLIRContext *ctx = tileIndices[0].getContext();
- auto one = IntegerAttr::get(IndexType::get(ctx), 1);
- auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
- SmallVector<Range> insertParams;
- insertParams.reserve(linearizedDimensions.size());
- unsigned loopIdx = 0;
- for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
- if (linearizedDimensions[i] && slicedDimensions[i]) {
- insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
- continue;
- }
- insertParams.push_back(Range{zero, sliceParams[i].size, one});
- }
- return insertParams;
-}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 6b5d579a339a1..89ebd81271721 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -17,21 +17,6 @@ using namespace mlir;
/// Include the definitions of the loop-like interfaces.
#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
-std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
- SmallVector<OpFoldResult>>
-mlir::getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
- SmallVector<OpFoldResult> offsets, sizes, strides;
- offsets.reserve(ranges.size());
- sizes.reserve(ranges.size());
- strides.reserve(ranges.size());
- for (const auto &[offset, size, stride] : ranges) {
- offsets.push_back(offset);
- sizes.push_back(size);
- strides.push_back(stride);
- }
- return std::make_tuple(offsets, sizes, strides);
-}
-
LogicalResult mlir::verifyListOfOperandsOrIntegers(
Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
deleted file mode 100644
index d8ca129bf59a8..0000000000000
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ /dev/null
@@ -1,164 +0,0 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH
-
-func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> {
- %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
- %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32>
- return %slice : tensor<20x11xf32>
-}
-
-// CHECK: func.func @extract_slice_static(%[[arg0:.+]]:
-// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
-// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
-// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
-// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
-// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
-// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
-// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
-// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
-// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] :
-// CHECK: scf.yield %[[update]] :
-// CHECK: return %[[tile]]
-
-// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
-// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index
-// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
-// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
-// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
-// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
-// FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]])
-// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
-// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
-// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
-// FOREACH: perform_concurrently
-// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[iv]], 0] [1, 11] [1, 1] :
-// FOREACH: return %[[tile]]
-
-// -----
-
-
-func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> {
- %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
- %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32>
- return %slice : tensor<10x5xf32>
-}
-
-// CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 13)>
-// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]:
-// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
-// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
-// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] :
-// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
-// CHECK: %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]])
-// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]]
-// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
-// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
-// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
-// CHECK: scf.yield %[[update]] :
-// CHECK: return %[[tile]]
-
-
-// -----
-
-
-func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor<?x5xf32> {
- %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor<?x11xf32>
- %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor<?x11xf32> to tensor<?x5xf32>
- return %slice : tensor<?x5xf32>
-}
-
-// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
-// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index)
-// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
-// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor<?x5xf32>
-// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32>
-// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32>
-// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
-// CHECK: %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]]
-// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) :
-// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
-// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
-// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
-// CHECK: scf.yield %[[update]] :
-// CHECK: return %[[tile]] :
-
-// -----
-
-
-func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor<?x?xf32> {
- %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
- %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- return %slice : tensor<?x?xf32>
-}
-
-// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
-// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
-// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
-// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
-// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
-// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
-// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
-// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
-// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
-// CHECK: %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]]
-// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
-// CHECK: %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]]
-// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) :
-// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
-// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
-// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] :
-// CHECK: scf.yield %[[update]] :
-// CHECK: scf.yield %[[tile2]] :
-// CHECK: return %[[tile1]] :
-
-// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
-// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index
-// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index
-// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
-// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index
-// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index
-// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
-// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
-// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
-// FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
-// FOREACH: %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]])
-// FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]]
-// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
-// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]]
-// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) :
-// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
-// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
-// FOREACH: perform_concurrently
-//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] :
-
-// -----
-
-// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this
-// only works for static shapes.
-
-// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>,
-func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor<?x22xf32> {
- %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor<?x22xf32>
- %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor<?x22xf32> to tensor<?x22xf32>
- // CHECK: scf.for
- // CHECK-NOT: scf.for
- // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index
- // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
- return %slice : tensor<?x22xf32>
-}
diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
index 56e59820677ea..89c996da7b2a9 100644
--- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
@@ -6,7 +6,6 @@ add_mlir_library(MLIRTensorTestPasses
LINK_LIBS PUBLIC
MLIRArithmeticDialect
- MLIRLinalgDialect
MLIRPass
MLIRSCFDialect
MLIRTensorDialect
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index f5a7f984ab0a7..4c38ad1d2dda2 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -11,10 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -30,8 +28,7 @@ struct TestTensorTransforms
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<arith::ArithmeticDialect, scf::SCFDialect,
- linalg::LinalgDialect>();
+ registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
}
StringRef getArgument() const final {
@@ -52,19 +49,6 @@ struct TestTensorTransforms
*this, "test-fold-constant-extract-slice",
llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
llvm::cl::init(false)};
-
- Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
- *this, "test-rewrite-extract-slice-from-collapse-shape",
- llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
- "with loop nest"),
- llvm::cl::init(false)};
-
- Option<bool> useForeach{
- *this, "use-foreach",
- llvm::cl::desc(
- "Use the scf.foreach_thread operation when generating loop nests for "
- "the extract_slice of collapse_shape pattern"),
- llvm::cl::init(false)};
};
} // namespace
@@ -90,142 +74,12 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
-namespace {
-/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
-/// The `tensor.extract_slice` is replaced by a loop or gather operation that
-/// stitches together the desired tile from slices of the source of the collapse
-/// shape op.
-struct RewriteExtractSliceFromCollapseShapeBase
- : public OpRewritePattern<tensor::ExtractSliceOp> {
- RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
- : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}
-
- /// Emit a loop or gather operation that uses `helper` to take each point in
- /// the parallel iteration space bounds, extract a slice from the source
- /// tensor and insert it into `dest`. For examples, see below for `scf.for`
- /// and `scf.foreach`.
- virtual LogicalResult
- emitReplacement(tensor::ExtractSliceOp op, Value dest,
- tensor::ExtractSliceFromCollapseHelper &helper,
- PatternRewriter &rewriter) const = 0;
-
- LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
- PatternRewriter &rewriter) const override {
- auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
- if (!collapseOp)
- return rewriter.notifyMatchFailure(
- op, "producer is not a tensor.collapse_shape op");
-
- // Materialize the output shape values of the slice operation.a
- ReifiedRankedShapedTypeDims reifiedShapes;
- if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
- return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
-
- // Create the destination tensor using the above values.
- Type elementType = op.getSourceType().getElementType();
- SmallVector<OpFoldResult> outputShape = getAsOpFoldResult(reifiedShapes[0]);
- Value dest = rewriter.create<linalg::InitTensorOp>(
- op->getLoc(), outputShape, elementType);
-
- // Calculate the parameters for the tile loop nest.
- FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
- tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
- op);
- if (failed(params))
- return rewriter.notifyMatchFailure(
- op, "could not calculate tiling parameters");
- return emitReplacement(op, dest, *params, rewriter);
- }
-};
-
-struct RewriteExtractSliceFromCollapseShapeUsingScfFor
- : public RewriteExtractSliceFromCollapseShapeBase {
- RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
- : RewriteExtractSliceFromCollapseShapeBase(context) {}
- LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
- tensor::ExtractSliceFromCollapseHelper &helper,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- SmallVector<Value> lbs(numTiledDims, zero);
- SmallVector<Value> steps(numTiledDims, one);
- scf::LoopNest nest = scf::buildLoopNest(
- rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
- [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
- ValueRange iterArgs) -> scf::ValueVector {
- auto [tile, insertParams] =
- helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
-
- // Insert the slice into the destination.
- Value result = nestedBuilder.create<tensor::InsertSliceOp>(
- loc, tile, iterArgs[0], insertParams);
- return {result};
- });
- rewriter.replaceOp(op, nest.getResults()[0]);
- return success();
- }
-};
-
-struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
- : public RewriteExtractSliceFromCollapseShapeBase {
- RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
- : RewriteExtractSliceFromCollapseShapeBase(context) {}
- LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
- tensor::ExtractSliceFromCollapseHelper &helper,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
- loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
- /*threadDimMapping=*/ArrayRef<int64_t>{},
- [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
- unsigned numThreadIdRegionArgs =
- helper.getIterationSpaceSizes().size();
- unsigned numOutputRegionArgs =
- regionArgs.size() - numThreadIdRegionArgs;
- ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
- ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
- assert(outputArgs.size() == 1 &&
- "there should only be one output region argument");
- auto [tile, insertParams] =
- helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
- // Insert the slice into the destination.
- auto term = nestedBuilder.create<scf::PerformConcurrentlyOp>(loc);
- nestedBuilder.setInsertionPointToStart(term.getBody());
- nestedBuilder.create<tensor::ParallelInsertSliceOp>(
- loc, tile, outputArgs[0], insertParams);
- });
- rewriter.replaceOp(op, foreachOp->getResult(0));
- return success();
- }
-};
-} // namespace
-
-static LogicalResult
-applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
- bool useForeach) {
- RewritePatternSet patterns(rootOp->getContext());
- if (useForeach)
- patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
- rootOp->getContext());
- else
- patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
- rootOp->getContext());
- return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
void TestTensorTransforms::runOnOperation() {
Operation *rootOp = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(rootOp);
if (testFoldConstantExtractSlice)
applyFoldConstantExtractSlicePatterns(rootOp);
- if (testRewriteExtractSliceWithTiledCollapseShape) {
- if (failed(
- applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
- return signalPassFailure();
- }
}
namespace mlir {
More information about the Mlir-commits mailing list