[Mlir-commits] [mlir] [mlir][tensor] add gather decompose pattern (PR #119805)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 12 19:16:17 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: zhicong zhong (zhczhong)
<details>
<summary>Changes</summary>
Currently, `tensor.gather` cannot be bufferized and further lowered. This patch adds a decomposition pattern that rewrites `tensor.gather` into a series of bufferizable ops (`tensor.empty`, `linalg.generic`, `tensor.extract_slice`).
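For reference, the rewrite for a single gather dimension looks roughly as follows (a minimal sketch adapted from the pattern's documentation; `%src`, `%indices`, and the other SSA names are illustrative):

```mlir
// Before: gather rows of %src selected by %indices along dim 0.
%gathered = tensor.gather %src[%indices] gather_dims([0])
    : (tensor<7x128xf16>, tensor<1x7x1xindex>) -> tensor<1x7x128xf16>

// After: an empty destination plus a linalg.generic that extracts one
// element of %src per result element.
%empty = tensor.empty() : tensor<1x7x128xf16>
%decomposed = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
     iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%indices : tensor<1x7x1xindex>) outs(%empty : tensor<1x7x128xf16>) {
^bb0(%in: index, %out: f16):
  %col = linalg.index 2 : index
  %elem = tensor.extract %src[%in, %col] : tensor<7x128xf16>
  linalg.yield %elem : f16
} -> tensor<1x7x128xf16>
```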
---
Full diff: https://github.com/llvm/llvm-project/pull/119805.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td (+11)
- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+7)
- (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+5)
- (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp (+166)
- (added) mlir/test/Dialect/Tensor/decompose-gather.mlir (+66)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 81bab1b0c82f7a..2be2d019e11228 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -189,4 +189,15 @@ def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
"(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
}
+def ApplyDecomposeTensorGatherPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.tensor.decompose_gather",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that tensor.gather ops should be decomposed into a chain of
+ tensor.extract_slice and linalg.generic ops that extract the elements from the source.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // TENSOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1a..fa73f74d0be66d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,13 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
const ControlFoldFn &controlFn);
+/// Populates `patterns` with patterns that decompose `tensor.gather` into
+/// `tensor.empty` and `linalg.generic`, preceded by a chain of
+/// `tensor.extract_slice` operations that split the index tensor. This is
+/// intended to be used as a tensor -> linalg lowering that decomposes gather
+/// such that it can be bufferized into a sequence of bufferizable ops.
+void populateDecomposeTensorGatherPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Transform helpers
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 99199252710f99..cb2d01df40b8d8 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -143,6 +143,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
}
+void transform::ApplyDecomposeTensorGatherPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ tensor::populateDecomposeTensorGatherPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// TypeConversionCastTensorShapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index cc6275fee671aa..f1a23e5e3bfbfc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
EmptyOpPatterns.cpp
ExtractSliceFromReshapeUtils.cpp
FoldTensorSubsetOps.cpp
+ GatherOpPatterns.cpp
IndependenceTransforms.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
PackAndUnpackPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
new file mode 100644
index 00000000000000..5905ee049228a5
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
@@ -0,0 +1,166 @@
+//===- GatherOpPatterns.cpp - Patterns related to tensor.gather lowering --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Decompose `tensor.gather` into `linalg.generic`.
+///
+/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
+/// tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<1x7x128xf16>
+/// %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1,
+/// 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
+/// ["parallel", "parallel", "parallel"]} ins(%1 : tensor<1x7x1xindex>)
+/// outs(%empty : tensor<1x7x128xf16>) {
+/// ^bb0(%in: index, %out: f16):
+/// %idx = linalg.index 2 : index
+/// %extracted = tensor.extract %0[%in, %idx] : tensor<7x128xf16>
+/// linalg.yield %extracted : f16
+/// } -> tensor<1x7x128xf16>
+struct DecomposeTensorGatherOp : public OpRewritePattern<tensor::GatherOp> {
+ using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
+
+ SmallVector<OpFoldResult> getDstMixedSizes(PatternRewriter &rewriter,
+ Location loc,
+ tensor::GatherOp gatherOp) const {
+ SmallVector<OpFoldResult> dstSize =
+ tensor::getMixedSizes(rewriter, loc, gatherOp.getResult());
+ SmallVector<OpFoldResult> indexSize =
+ tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices());
+ SmallVector<OpFoldResult> srcSize =
+ tensor::getMixedSizes(rewriter, loc, gatherOp.getSource());
+ SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
+ bool isShrinkDst = (indexSize.size() - 1) + srcSize.size() ==
+ dstSize.size() + gatherDims.size();
+ for (size_t i = 0; i < indexSize.size() - 1; i++) {
+ dstSize[i] = indexSize[i];
+ }
+ auto cnt = 0;
+ for (size_t i = indexSize.size() - 1; i < dstSize.size(); i++) {
+ while (isShrinkDst && llvm::find(gatherDims, cnt) != gatherDims.end()) {
+ cnt++;
+ }
+ dstSize[i] = llvm::find(gatherDims, cnt) == gatherDims.end()
+ ? srcSize[cnt]
+ : getAsIndexOpFoldResult(rewriter.getContext(), 1);
+ cnt++;
+ }
+ return dstSize;
+ }
+
+ LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(gatherOp);
+ Location loc = gatherOp.getLoc();
+ SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
+
+ // Create the destination tensor for the linalg.generic output.
+ RankedTensorType dstType = gatherOp.getResultType();
+ Value dstTensor = rewriter.create<tensor::EmptyOp>(
+ loc, getDstMixedSizes(rewriter, loc, gatherOp),
+ dstType.getElementType());
+
+ // Split the index tensor to create the linalg.generic inputs.
+ SmallVector<Value> indexTensors;
+ Value originIndexTensor = gatherOp.getIndices();
+ SmallVector<OpFoldResult> indexTensorSize =
+ tensor::getMixedSizes(rewriter, loc, originIndexTensor);
+ SmallVector<OpFoldResult> indexTensorStride(
+ indexTensorSize.size(),
+ getAsIndexOpFoldResult(rewriter.getContext(), 1));
+ SmallVector<OpFoldResult> indexTensorOffset(
+ indexTensorSize.size(),
+ getAsIndexOpFoldResult(rewriter.getContext(), 0));
+ indexTensorSize[indexTensorSize.size() - 1] =
+ getAsIndexOpFoldResult(rewriter.getContext(), 1);
+
+ for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+ indexTensorOffset[indexTensorSize.size() - 1] =
+ getAsIndexOpFoldResult(rewriter.getContext(), cnt);
+ Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
+ loc, originIndexTensor, indexTensorOffset, indexTensorSize,
+ indexTensorStride);
+ indexTensors.emplace_back(indexTensor);
+ }
+
+ // Create the affine maps.
+ SmallVector<AffineMap> affineMaps;
+ SmallVector<AffineExpr> dimExprs;
+ size_t dstRank = dstType.getShape().size();
+ for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i)
+ dimExprs.push_back(rewriter.getAffineDimExpr(i));
+ dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
+
+ for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+ AffineMap currentMap =
+ AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs,
+ rewriter.getContext());
+ affineMaps.emplace_back(currentMap);
+ }
+ affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));
+
+ // Create the iterator types array.
+ SmallVector<utils::IteratorType> iteratorTypesArray(
+ dstRank, utils::IteratorType::parallel);
+
+ // Check whether the gather op is valid.
+ size_t srcRank = gatherOp.getSourceType().getShape().size();
+ assert(((indexTensorSize.size() - 1) + srcRank == dstRank ||
+ (indexTensorSize.size() - 1) + srcRank ==
+ dstRank + gatherDims.size()) &&
+ "Expected: index_size - 1 + source_size == dst_size or dst_szie - "
+ "gather_size. \n");
+ rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+ gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
+ affineMaps, iteratorTypesArray,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ SmallVector<Value> indexValues(srcRank);
+ bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank ==
+ dstRank + gatherDims.size();
+ int cnt = 0;
+ for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
+ while (isShrinkDst &&
+ llvm::find(gatherDims, cnt) != gatherDims.end()) {
+ cnt++;
+ }
+ indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);
+ cnt++;
+ }
+ for (auto &&[i, dim] : llvm::enumerate(gatherDims)) {
+ indexValues[dim] = args[i];
+ }
+
+ Value extract = b.create<tensor::ExtractOp>(loc, gatherOp.getSource(),
+ indexValues);
+ b.create<linalg::YieldOp>(loc, extract);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tensor::populateDecomposeTensorGatherPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<DecomposeTensorGatherOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/decompose-gather.mlir b/mlir/test/Dialect/Tensor/decompose-gather.mlir
new file mode 100644
index 00000000000000..587dfc8cc7e2fc
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-gather.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
+
+/// CHECK-LABEL: @gather_single_gather_dim
+func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
+ /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
+ /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2x2xf32>)
+ %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
+ return %1 : tensor<2x3x2x2x2xf32>
+}
+
+/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
+func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
+ /// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
+ /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY1]] : tensor<2x3x2x1x1x2xf32>)
+ %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
+ return %1 : tensor<2x3x2x1x2x2xf32>
+}
+
+/// CHECK-LABEL: @gather_multiple_gather_dim
+func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
+ /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
+ /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+ /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+ /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2xf32>)
+ %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
+ return %1 : tensor<2x3x2x2xf32>
+}
+
+/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
+func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
+ %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
+ return %1 : tensor<2x3x2x1x1x2xf32>
+}
+
+/// CHECK-LABEL: @gather_single_gather_dim_dynamic
+func.func @gather_single_gather_dim_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> {
+ /// CHECK: %[[DIM1:.*]] = tensor.dim
+ /// CHECK: %[[DIM2:.*]] = tensor.dim
+ /// CHECK: %[[DIM3:.*]] = tensor.dim
+ /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1]], %[[DIM2]], %[[DIM3]]) : tensor<2x3x?x?x?xf32>
+ /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x?x?x?xf32>)
+ %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<?x?x?x?xf32>, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32>
+ return %1 : tensor<2x3x?x?x?xf32>
+}
+
+/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic
+func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32> {
+ /// CHECK: %[[DIM1:.*]] = tensor.dim
+ /// CHECK: %[[DIM2:.*]] = tensor.dim
+ /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1]], %[[DIM2]]) : tensor<?x?x2x1x1x2xf32>
+ /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM1]], %[[DIM2]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
+ /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [%[[DIM1]], %[[DIM2]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
+ /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<?x?x1xindex>, tensor<?x?x1xindex>) outs(%[[EMPTY]] : tensor<?x?x2x1x1x2xf32>)
+ %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32>
+ return %1 : tensor<?x?x2x1x1x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.tensor.decompose_gather
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
``````````
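To illustrate how the index tensor is split when there are multiple gather dimensions, the decomposition of `@gather_multiple_gather_dim` above roughly produces the following (a sketch derived from the pattern and the CHECK lines; `%source`, `%indices`, and the other SSA names are illustrative):

```mlir
%empty = tensor.empty() : tensor<2x3x2x2xf32>
// Split the 2-wide index tensor into one tensor<2x3x1xindex> slice per
// gather dimension.
%idx0 = tensor.extract_slice %indices[0, 0, 0] [2, 3, 1] [1, 1, 1]
    : tensor<2x3x2xindex> to tensor<2x3x1xindex>
%idx1 = tensor.extract_slice %indices[0, 0, 1] [2, 3, 1] [1, 1, 1]
    : tensor<2x3x2xindex> to tensor<2x3x1xindex>
%result = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%idx0, %idx1 : tensor<2x3x1xindex>, tensor<2x3x1xindex>)
    outs(%empty : tensor<2x3x2x2xf32>) {
^bb0(%i: index, %j: index, %out: f32):
  %d2 = linalg.index 2 : index
  %d3 = linalg.index 3 : index
  %elem = tensor.extract %source[%d2, %i, %j, %d3] : tensor<2x2x2x2xf32>
  linalg.yield %elem : f32
} -> tensor<2x3x2x2xf32>
```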
</details>
https://github.com/llvm/llvm-project/pull/119805
More information about the Mlir-commits mailing list