[Mlir-commits] [mlir] [mlir][Tensor] NFC: Move concat operation decomposition into a method of the concat operation. (PR #116004)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 12:13:10 PST 2024

https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/116004

>From e33c018f48fddbe8187120c15f4feb7fcb81d128 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Tue, 12 Nov 2024 22:47:11 -0800
Subject: [PATCH] [mlir][Tensor] Move concat operation decomposition into a
 method of the concat operation.

Currently, the implementation lives inside a rewrite pattern, so it cannot be
used without a pattern rewriter. Move the decomposition into a method of the
operation so that it can also be used outside of pattern rewrites.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  3 ++
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 48 +++++++++++++++++
 .../Tensor/Transforms/ConcatOpPatterns.cpp    | 53 +++----------------
 .../test/Dialect/Tensor/decompose-concat.mlir | 49 +++++++++--------
 4 files changed, 84 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..b73da8bb6af59c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
     int64_t getRank() {
       return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
+    // Method to decompose the operation into a sequence of insert_slices.
+    FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 147120e0e34203..616d4a7d0a0ab5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -615,6 +615,54 @@ LogicalResult ConcatOp::verify() {
   return success();
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+  size_t numInputs = getInputs().size();
+  uint64_t concatDim = getDim();
+  SmallVector<SmallVector<OpFoldResult>> inputShapes;
+  inputShapes.reserve(numInputs);
+  SmallVector<OpFoldResult> concatOffsets;
+  concatOffsets.reserve(numInputs);
+  SmallVector<OpFoldResult> outputShape;
+  AffineExpr addExpr =
+      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+  OpFoldResult zero = builder.getIndexAttr(0);
+  Location loc = getLoc();
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    SmallVector<OpFoldResult> inputShape =
+        tensor::getMixedSizes(builder, input.getLoc(), input);
+    if (index == 0) {
+      outputShape = inputShape;
+      concatOffsets.push_back(zero);
+    } else {
+      concatOffsets.push_back(outputShape[concatDim]);
+      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+          builder, loc, addExpr,
+          {outputShape[concatDim], inputShape[concatDim]});
+    }
+    inputShapes.emplace_back(std::move(inputShape));
+  }
+  Value replacement = builder.create<tensor::EmptyOp>(
+      loc, outputShape, getType().getElementType());
+  int64_t rank = getType().getRank();
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> offsets(rank, zero);
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    offsets[concatDim] = concatOffsets[index];
+    auto insertSlice = builder.create<tensor::InsertSliceOp>(
+        loc, input, replacement, offsets, inputShapes[index], strides);
+    replacement = insertSlice.getResult();
+  }
+  if (replacement.getType() != getType()) {
+    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
+  }
+  return SmallVector<Value>{replacement};
 ConcatOp::reifyResultShapes(OpBuilder &builder,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 7c8403c9609d84..a2a860fcb38abb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-    Location loc = concatOp.getLoc();
-    FailureOr<Value> dest =
-        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
-    if (failed(dest))
-      return failure();
-    auto empty = dest->getDefiningOp<tensor::EmptyOp>();
-    if (!empty)
-      return failure();
-    int64_t dim = concatOp.getDim();
-    Value dimValue =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
-    int64_t rank = concatOp.getResultType().getRank();
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    // Compute the partial sums for the slice offsets.
-    AffineExpr sum = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> partialSums = {sum};
-    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
-    for (auto [idx, input] :
-         llvm::enumerate(concatOp.getInputs().drop_back())) {
-      sum = sum + rewriter.getAffineDimExpr(idx + 1);
-      partialSums.push_back(sum);
-      offsetStrides.push_back(
-          rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+    FailureOr<SmallVector<Value>> decomposed =
+        concatOp.decomposeOperation(rewriter);
+    if (failed(decomposed)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "failed to get the decomposed insert slices");
-    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
-                                        partialSums, rewriter.getContext());
-    SmallVector<OpFoldResult> dimOffsets =
-        affine::makeComposedFoldedMultiResultAffineApply(
-            rewriter, loc, partialSumMap, offsetStrides);
-    // Construct the chain of insert_slice ops into the destination.
-    Value result = *dest;
-    for (auto [input, offset] :
-         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
-      SmallVector<OpFoldResult> sizes =
-          tensor::getMixedSizes(rewriter, loc, input);
-      offsets[dim] = offset;
-      result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, input, result, offsets, sizes, strides);
-    }
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        concatOp, concatOp.getResultType(), result);
+    rewriter.replaceOp(concatOp, decomposed.value()[0]);
     return success();
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index c0f23b8eddbd52..2b1cb138ecda5b 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -1,24 +1,23 @@
-// RUN: mlir-opt -split-input-file -transform-interpreter -cse  %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
 func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
-//   CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4)>
 // CHECK-LABEL: func @decompose_dynamic_concat(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-//   CHECK-DAG:     %[[C8:.+]] = arith.constant 8 : index
 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
-//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//       CHECK:     %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-//       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
-//       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
-//       CHECK:     %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-//       CHECK:     return %[[CONCAT]] : tensor<?x?xf32>
+//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:     %[[DIM0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT_SIZE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%[[DIM0]]]
+//       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
+//       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[DIM]], %[[DIM0]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
+//       CHECK:     %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
+//       CHECK:     return %[[CAST]] : tensor<?x?xf32>
 func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
                             %arg1 : tensor<2xf32>,
@@ -42,12 +41,14 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
              : (tensor<1x?x64xf32>, tensor<1x?x64xf32>) -> tensor<1x?x128xf32>
   return %0 : tensor<1x?x128xf32>
-// CHECK-LABEL: func @decompose_static_concat_dim
+// CHECK-LABEL: func @decompose_static_concat_dim(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>)
 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//       CHECK:     %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
+//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x64xf32>
+//       CHECK:     %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x64xf32>
 //       CHECK:    tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
 //       CHECK:    tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
-//       CHECK:     %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
 //       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
 //       CHECK:    return %[[CONCAT]] : tensor<1x?x128xf32>
@@ -58,19 +59,23 @@ func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
              : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
   return %0 : tensor<1x?x128xf32>
-// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
+// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>)
 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
-//       CHECK:     %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
-//       CHECK:     tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
-//       CHECK:     %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     %[[T0_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     %[[T0_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     %[[T1_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     %[[T1_DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     %[[CONCAT_DIM:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[T0_DIM2]], %[[T1_DIM2]]]
+//       CHECK:     tensor.empty(%[[T0_DIM1]], %[[CONCAT_DIM]]) : tensor<1x?x?xf32>
 //       CHECK:     tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
-//  CHECK-SAME:       tensor<1x?x?xf32> into tensor<1x?x128xf32>
-//       CHECK:     %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
-//       CHECK:     %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//  CHECK-SAME:       tensor<1x?x?xf32> into tensor<1x?x?xf32>
 //       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
-//  CHECK-SAME:        tensor<1x?x?xf32> into tensor<1x?x128xf32>
-//       CHECK:     return %[[CONCAT]] : tensor<1x?x128xf32>
+//  CHECK-SAME:        tensor<1x?x?xf32> into tensor<1x?x?xf32>
+//       CHECK:     %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<1x?x?xf32> to tensor<1x?x128xf32>
+//       CHECK:     return %[[CAST]] : tensor<1x?x128xf32>
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {

More information about the Mlir-commits mailing list