[Mlir-commits] [mlir] [mlir] [linalg] Add pattern to swap transpose with broadcast (PR #97063)

donald chen llvmlistbot at llvm.org
Tue Jul 16 20:51:24 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/97063

>From aa2fd98fc97ee1f931288d82a7fa5911dd1d6a96 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Thu, 27 Jun 2024 00:00:03 +0800
Subject: [PATCH 1/2] [mlir] [linalg] Add canonicalize pattern to swap
 transpose with broadcast

Add a canonicalization pattern that implements the following rewrite:

  transpose(broadcast(input)) -> broadcast(transpose(input))

This reduces the cost of the transpose, which now runs on the pre-broadcast input.
---
 .../mlir/Dialect/Utils/IndexingUtils.h        |  8 ++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 60 ++++++++++++++-
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      | 25 +++++++
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 75 ++++++++++++++++++-
 4 files changed, 166 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index b774359552aa5..7849782e5442b 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -243,6 +243,14 @@ SmallVector<int64_t>
 computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                          ArrayRef<int64_t> desiredPositions);
 
+/// Returns a permutation vector that drops the input dims listed in
+/// dropPositions from inputPerm and renumbers the remaining dims.
+///
+/// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would
+/// result in a {2, 0, 1} permutation vector.
+SmallVector<int64_t> dropDims(ArrayRef<int64_t> inputPerm,
+                              ArrayRef<int64_t> dropPositions);
+
 /// Helper to return a subset of `arrayAttr` as a vector of int64_t.
 // TODO: Port everything relevant to DenseArrayAttr and drop this util.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..2ccb6e98a34a7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1890,9 +1890,67 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+/// This pattern canonicalizes transpose by swapping the order of
+/// broadcast and transpose:
+///   transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transposeOp.getInput();
+    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+    if (!input.hasOneUse() || !broadcastOp)
+      return failure();
+
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+    // Get new perms and new dimensions.
+    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
+    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+    SmallVector<int64_t> resultDimensions;
+    for (unsigned i = 0; i < dimensions.size(); i++) {
+      resultDimensions.push_back(invertPerm[dimensions[i]]);
+    }
+
+    // Create transpose result.
+    Value broadcastInput = broadcastOp.getInput();
+    Location loc = transposeOp.getLoc();
+    MLIRContext *ctx = transposeOp.getContext();
+    SmallVector<OpFoldResult> dims;
+    auto broadcastInputTy =
+        mlir::cast<RankedTensorType>(broadcastInput.getType());
+    for (unsigned i = 0; i < broadcastInputTy.getRank(); i++) {
+      if (broadcastInputTy.isDynamicDim(i)) {
+        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
+                           ->getResult(0));
+      } else {
+        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
+                                        broadcastInputTy.getDimSize(i)));
+      }
+    }
+    SmallVector<OpFoldResult> transposeResultShapes =
+        applyPermutation(dims, resultPerms);
+    Value transposeInit = rewriter.create<tensor::EmptyOp>(
+        transposeOp.getLoc(), transposeResultShapes,
+        broadcastInputTy.getElementType());
+
+    // Create broadcast(transpose(input)).
+    Value transposeResult =
+        rewriter
+            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
+                                 resultPerms)
+            ->getResult(0);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldTransposeWithTranspose>(context);
+  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index aba225be720c3..ddc1129a5a75f 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -252,6 +252,31 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
   return res;
 }
 
+SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
+                                    ArrayRef<int64_t> dropPositions) {
+  assert(inputPerm.size() >= dropPositions.size() &&
+         "expect inputPerm size large than position to drop");
+  SmallVector<int64_t> res;
+  for (unsigned inputIndex = 0; inputIndex < inputPerm.size(); ++inputIndex) {
+    int64_t targetIndex = inputPerm[inputIndex];
+    bool shouldDrop = false;
+    for (unsigned dropIndex = 0; dropIndex < dropPositions.size();
+         dropIndex++) {
+      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
+        shouldDrop = true;
+        break;
+      }
+      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
+        targetIndex--;
+      }
+    }
+    if (!shouldDrop) {
+      res.push_back(targetIndex);
+    }
+  }
+  return res;
+}
+
 SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                           unsigned dropFront,
                                           unsigned dropBack) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..d34bc8c1c54f6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   return %0 : tensor<2x3xf32>
 }
 
-// ----
+// -----
 
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,76 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+// -----
+
+func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
+                                    %init1: tensor<1x2x3x4x5x6xf32>,
+                                    %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
+  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
+  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
+  //       CHECK:   return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2x4x5xf32>)
+      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
+      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+      permutation = [0, 5, 1, 2, 4, 3]
+  func.return %transpose : tensor<1x6x2x3x5x4xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
+                                            %init1: tensor<1x?x3x?x5x6xf32>,
+                                            %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_dynamic
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+  //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+  //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+  //       CHECK:   %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
+  //       CHECK:   %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
+  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
+  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+  //       CHECK:   return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<?x?x5xf32>)
+      outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
+      outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+      permutation = [0, 2, 3, 5, 4, 1]
+  func.return %transpose : tensor<1x3x?x6x5x?xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
+                                         %init1: tensor<2x4xf32>,
+                                         %init2: tensor<4x2xf32>) -> tensor<4x2xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_2dim
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
+  //       CHECK:   return %[[BROADCAST]] : tensor<4x2xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2xf32>)
+      outs(%init1 : tensor<2x4xf32>)
+      dimensions = [1]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<2x4xf32>)
+      outs(%init2 : tensor<4x2xf32>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<4x2xf32>
+}

>From 79a7960d9e87311e609455e75a39a65f9cd20875 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Wed, 17 Jul 2024 11:51:00 +0800
Subject: [PATCH 2/2] update docs

---
 mlir/docs/Canonicalization.md | 61 ++++++++++++++++++++++++++++++++++-
 1 file changed, 60 insertions(+), 1 deletion(-)

diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
index d1cba572af212..dba58ad22c9e9 100644
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -33,6 +33,11 @@ together.
 
 Some important things to think about w.r.t. canonicalization patterns:
 
+*   The goal of canonicalization is to make subsequent optimizations more
+    effective. Therefore, performance improvements are not necessary for
+    canonicalization. But it is generally better to define canonicalization
+    patterns that do not harm performance.
+
 *   Pass pipelines should not rely on the canonicalizer pass for correctness.
     They should work correctly with all instances of the canonicalization pass
     removed.
@@ -51,6 +56,60 @@ Some important things to think about w.r.t. canonicalization patterns:
 *   It is always good to eliminate operations entirely when possible, e.g. by
     folding known identities (like "x + 0 = x").
 
+*   Canonicalization isn't a great place to put patterns with expensive
+    compile time (i.e. that have O(n) complexity) or complicated cost models.
+
+*   Canonicalization shouldn't drop the semantics of the original operation.
+
+For example, a pattern that transforms
+
+```
+  %res = vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>
+```
+
+to
+
+```
+  %res = vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
+```
+
+is not a good canonicalization pattern because it drops the transpose semantics.
+
+
+A pattern that transforms (linalg.transpose is the only use of %broadcast)
+
+```
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2x4x5xf32>)
+      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
+      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+      permutation = [0, 5, 1, 2, 4, 3]
+```
+
+to
+
+```
+  %transpose = linalg.transpose
+      ins(%input : tensor<2x4x5xf32>)
+      outs(%tmp_init : tensor<2x5x4xf32>)
+      permutation = [0, 2, 1]
+  %broadcast = linalg.broadcast
+      ins(%transpose : tensor<2x5x4xf32>)
+      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+      dimensions = [0, 3, 1]
+```
+
+is a good canonicalization pattern because:
+
+1. This pattern converges.
+2. This pattern always transforms the program towards reducing the amount of
+   computational data, which is a clear lattice.
+3. This is not a one-off pattern; new matches may be generated during the
+   application process.
+
 ## Globally Applied Rules
 
 These transformations are applied to all levels of IR:
@@ -189,7 +248,7 @@ each of the operands, returning the corresponding constant attribute. These
 operands are those that implement the `ConstantLike` trait. If any of the
 operands are non-constant, a null `Attribute` value is provided instead. For
 example, if MyOp provides three operands [`a`, `b`, `c`], but only `b` is
-constant then `adaptor` will return Attribute() for `getA()` and `getC()`, 
+constant then `adaptor` will return Attribute() for `getA()` and `getC()`,
 and b-value for `getB()`.
 
 Also above, is the use of `OpFoldResult`. This class represents the possible



More information about the Mlir-commits mailing list