[Mlir-commits] [mlir] [mlir][tensor] add gather decompose pattern (PR #119805)

zhicong zhong llvmlistbot at llvm.org
Thu Dec 12 19:15:45 PST 2024

https://github.com/zhczhong created https://github.com/llvm/llvm-project/pull/119805

Currently, tensor.gather cannot be bufferized or lowered further. This PR adds a decomposition pattern that rewrites tensor.gather into a series of bufferizable ops (tensor.empty, linalg.generic, tensor.extract_slice).

>From 483eb7a37d4825ceb8dcf542fcd21e508a64e2c7 Mon Sep 17 00:00:00 2001
From: "Zhong, Zhicong" <zhicong.zhong at intel.com>
Date: Fri, 13 Dec 2024 03:12:30 +0000
Subject: [PATCH] add gather decompose pattern

 .../Tensor/TransformOps/TensorTransformOps.td |  11 ++
 .../Dialect/Tensor/Transforms/Transforms.h    |   7 +
 .../TransformOps/TensorTransformOps.cpp       |   5 +
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   1 +
 .../Tensor/Transforms/GatherOpPatterns.cpp    | 166 ++++++++++++++++++
 .../test/Dialect/Tensor/decompose-gather.mlir |  66 +++++++
 6 files changed, 256 insertions(+)
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
 create mode 100644 mlir/test/Dialect/Tensor/decompose-gather.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 81bab1b0c82f7a..2be2d019e11228 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -189,4 +189,15 @@ def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
       "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
+def ApplyDecomposeTensorGatherPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.decompose_gather",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.gather ops should be decomposed into a chain of
+    tensor.extract_slice and linalg.generic to extract the element from source.
+  }];
+  let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1a..fa73f74d0be66d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,13 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
 void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
                                        const ControlFoldFn &controlFn);
+/// Populates `patterns` with patterns that decompose `tensor.gather` into
+/// `tensor.empty` and `linalg.generic`, with the indices fed through a chain
+/// of `tensor.extract_slice` operations. This is intended to be used as a
+/// tensor -> linalg lowering that rewrites gather into a sequence of ops
+/// that can be bufferized.
+void populateDecomposeTensorGatherPatterns(RewritePatternSet &patterns);
 // Transform helpers
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 99199252710f99..cb2d01df40b8d8 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -143,6 +143,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
     tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
+void transform::ApplyDecomposeTensorGatherPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateDecomposeTensorGatherPatterns(patterns);
 // TypeConversionCastTensorShapeOp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index cc6275fee671aa..f1a23e5e3bfbfc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
+  GatherOpPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
new file mode 100644
index 00000000000000..5905ee049228a5
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
@@ -0,0 +1,166 @@
+//===- GatherOpPatterns.cpp - Patterns related to tensor.gather lowering --===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+using namespace mlir;
+using namespace mlir::tensor;
+namespace {
+/// Decompose `tensor.gather` into `linalg.generic`.
+/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
+/// tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
+/// Becomes
+/// %empty = tensor.empty() : tensor<1x7x128xf16>
+/// %slice = tensor.extract_slice %1[0, 0, 0] [1, 7, 1] [1, 1, 1]
+///     : tensor<1x7x1xindex> to tensor<1x7x1xindex>
+/// %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1,
+/// 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
+/// ["parallel", "parallel", "parallel"]} ins(%slice : tensor<1x7x1xindex>)
+/// outs(%empty : tensor<1x7x128xf16>) {
+///    ^bb0(%in: index, %out: f16):
+///      %17 = linalg.index 2 : index
+///      %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
+///      linalg.yield %extracted : f16
+///    } -> tensor<1x7x128xf16>
+struct DecomposeTensorGatherOp : public OpRewritePattern<tensor::GatherOp> {
+  using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
+  SmallVector<OpFoldResult> getDstMixedSizes(PatternRewriter &rewriter,
+                                             Location loc,
+                                             tensor::GatherOp gatherOp) const {
+    SmallVector<OpFoldResult> dstSize =
+        tensor::getMixedSizes(rewriter, loc, gatherOp.getResult());
+    SmallVector<OpFoldResult> indexSize =
+        tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices());
+    SmallVector<OpFoldResult> srcSize =
+        tensor::getMixedSizes(rewriter, loc, gatherOp.getSource());
+    SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
+    bool isShrinkDst = (indexSize.size() - 1) + srcSize.size() ==
+                       dstSize.size() + gatherDims.size();
+    for (size_t i = 0; i < indexSize.size() - 1; i++) {
+      dstSize[i] = indexSize[i];
+    }
+    auto cnt = 0;
+    for (size_t i = indexSize.size() - 1; i < dstSize.size(); i++) {
+      while (isShrinkDst && llvm::find(gatherDims, cnt) != gatherDims.end()) {
+        cnt++;
+      }
+      dstSize[i] = llvm::find(gatherDims, cnt) == gatherDims.end()
+                       ? srcSize[cnt]
+                       : getAsIndexOpFoldResult(rewriter.getContext(), 1);
+      cnt++;
+    }
+    return dstSize;
+  }
+  LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(gatherOp);
+    Location loc = gatherOp.getLoc();
+    SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
+    // create destination tensor for linalg out
+    RankedTensorType dstType = gatherOp.getResultType();
+    Value dstTensor = rewriter.create<tensor::EmptyOp>(
+        loc, getDstMixedSizes(rewriter, loc, gatherOp),
+        dstType.getElementType());
+    // split index tensor to create the linalg input
+    SmallVector<Value> indexTensors;
+    Value originIndexTensor = gatherOp.getIndices();
+    SmallVector<OpFoldResult> indexTensorSize =
+        tensor::getMixedSizes(rewriter, loc, originIndexTensor);
+    SmallVector<OpFoldResult> indexTensorStride(
+        indexTensorSize.size(),
+        getAsIndexOpFoldResult(rewriter.getContext(), 1));
+    SmallVector<OpFoldResult> indexTensorOffset(
+        indexTensorSize.size(),
+        getAsIndexOpFoldResult(rewriter.getContext(), 0));
+    indexTensorSize[indexTensorSize.size() - 1] =
+        getAsIndexOpFoldResult(rewriter.getContext(), 1);
+    for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+      indexTensorOffset[indexTensorSize.size() - 1] =
+          getAsIndexOpFoldResult(rewriter.getContext(), cnt);
+      Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
+          loc, originIndexTensor, indexTensorOffset, indexTensorSize,
+          indexTensorStride);
+      indexTensors.emplace_back(indexTensor);
+    }
+    // create the affine map
+    SmallVector<AffineMap> affineMaps;
+    SmallVector<AffineExpr> dimExprs;
+    size_t dstRank = dstType.getShape().size();
+    for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i)
+      dimExprs.push_back(rewriter.getAffineDimExpr(i));
+    dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
+    for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+      AffineMap currentMap =
+          AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs,
+                         rewriter.getContext());
+      affineMaps.emplace_back(currentMap);
+    }
+    affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));
+    // create iterator types array
+    SmallVector<utils::IteratorType> iteratorTypesArray(
+        dstRank, utils::IteratorType::parallel);
+    // check whether the gather op is valid
+    size_t srcRank = gatherOp.getSourceType().getShape().size();
+    assert(((indexTensorSize.size() - 1) + srcRank == dstRank ||
+            (indexTensorSize.size() - 1) + srcRank ==
+                dstRank + gatherDims.size()) &&
+           "Expected: index_size - 1 + source_size == dst_size or dst_szie - "
+           "gather_size. \n");
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
+        affineMaps, iteratorTypesArray,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          SmallVector<Value> indexValues(srcRank);
+          bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank ==
+                             dstRank + gatherDims.size();
+          int cnt = 0;
+          for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
+            while (isShrinkDst &&
+                   llvm::find(gatherDims, cnt) != gatherDims.end()) {
+              cnt++;
+            }
+            indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);
+            cnt++;
+          }
+          for (auto &&[i, dim] : llvm::enumerate(gatherDims)) {
+            indexValues[dim] = args[i];
+          }
+          Value extract = b.create<tensor::ExtractOp>(loc, gatherOp.getSource(),
+                                                      indexValues);
+          b.create<linalg::YieldOp>(loc, extract);
+        });
+    return success();
+  }
+} // namespace
+void mlir::tensor::populateDecomposeTensorGatherPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeTensorGatherOp>(patterns.getContext());
diff --git a/mlir/test/Dialect/Tensor/decompose-gather.mlir b/mlir/test/Dialect/Tensor/decompose-gather.mlir
new file mode 100644
index 00000000000000..587dfc8cc7e2fc
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-gather.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
+/// CHECK-LABEL: @gather_single_gather_dim
+func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
+    /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
+    /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
+  return %1 : tensor<2x3x2x2x2xf32>
+/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
+func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
+    /// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
+    /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY1:.*]] : tensor<2x3x2x1x2x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
+  return %1 : tensor<2x3x2x1x2x2xf32>
+/// CHECK-LABEL: @gather_multiple_gather_dim
+func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
+    // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
+    /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+    /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+      /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
+  return %1 : tensor<2x3x2x2xf32>
+/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
+func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
+  return %1 : tensor<2x3x2x1x1x2xf32>
+/// CHECK-LABEL: @gather_single_gather_dim_dynamic
+func.func @gather_single_gather_dim_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> {
+    /// CHECK: %[[DIM1:.*]] = tensor.dim
+    /// CHECK: %[[DIM2:.*]] = tensor.dim
+    /// CHECK: %[[DIM3:.*]] = tensor.dim
+    /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]], %[[DIM3:.*]]) : tensor<2x3x?x?x?xf32>
+    /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x?x?x?xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<?x?x?x?xf32>, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32>
+  return %1 : tensor<2x3x?x?x?xf32>
+/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic
+func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32> {
+    /// CHECK: %[[DIM1:.*]] = tensor.dim
+    /// CHECK: %[[DIM2:.*]] = tensor.dim
+    /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]]) : tensor<?x?x2x1x1x2xf32>
+    /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM1:.*]], %[[DIM2:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
+    /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [%[[DIM1:.*]], %[[DIM2:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
+    /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<?x?x1xindex>, tensor<?x?x1xindex>) outs(%[[EMPTY:.*]] : tensor<?x?x2x1x1x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32>
+  return %1 : tensor<?x?x2x1x1x2xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_gather
+    } : !transform.op<"func.func">
+    transform.yield
+  }

More information about the Mlir-commits mailing list