[Mlir-commits] [mlir] [mlir] [linalg] Add canonicalize pattern to swap transpose with broadcast (PR #97063)
donald chen
llvmlistbot at llvm.org
Tue Jul 9 05:57:08 PDT 2024
https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/97063
>From 9c8bbf9c8a3becfdcdb490b09c2ee926871e887a Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Thu, 27 Jun 2024 00:00:03 +0800
Subject: [PATCH] [mlir] [linalg] Add pattern to swap transpose with broadcast
Add a pattern that implements:
transpose(broadcast(input)) -> broadcast(transpose(input))
---
.../Dialect/Linalg/Transforms/Transforms.h | 14 +++-
.../mlir/Dialect/Utils/IndexingUtils.h | 8 ++
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Transforms/SwapTransposeWithBroadcast.cpp | 84 +++++++++++++++++++
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 25 ++++++
mlir/test/Dialect/Linalg/canonicalize.mlir | 2 +-
.../Linalg/swap-transpose-with-broadcast.mlir | 73 ++++++++++++++++
.../Dialect/Linalg/TestLinalgTransforms.cpp | 12 +++
8 files changed, 214 insertions(+), 5 deletions(-)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/SwapTransposeWithBroadcast.cpp
create mode 100644 mlir/test/Dialect/Linalg/swap-transpose-with-broadcast.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b0871a5dff5da..5b6876acc804c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1709,15 +1709,21 @@ void populateSplitReductionPattern(
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
bool transposeLHS = true);
+/// Patterns to convert transpose(broadcast(input)) to
+/// broadcast(transpose(input)).
+void populateSwapTransposeWithBroadcastPatterns(RewritePatternSet &patterns);
+
/// Patterns to block pack Linalg matmul ops.
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
/// Adds patterns that reduce the rank of named contraction ops that have
-/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
-/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
-/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
-/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`,
+/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For
+/// example a `linalg.batch_matmul` with unit batch size will convert to
+/// `linalg.matmul` and a `linalg.matvec` with unit spatial dim in lhs will
+/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index b774359552aa5..7849782e5442b 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -243,6 +243,14 @@ SmallVector<int64_t>
computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
ArrayRef<int64_t> desiredPositions);
+/// Returns a permutation vector that drops the input dims in
+/// dropPositions from inputPerm.
+///
+/// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would
+/// result in a {2, 0, 1} permutation vector.
+SmallVector<int64_t> dropDims(ArrayRef<int64_t> inputPerm,
+ ArrayRef<int64_t> dropPositions);
+
/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
// TODO: Port everything relevant to DenseArrayAttr and drop this util.
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..f4324d7701bfd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
SplitReduction.cpp
SubsetInsertionOpInterfaceImpl.cpp
SwapExtractSliceWithFillPatterns.cpp
+ SwapTransposeWithBroadcast.cpp
Tiling.cpp
TilingInterfaceImpl.cpp
Transforms.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapTransposeWithBroadcast.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapTransposeWithBroadcast.cpp
new file mode 100644
index 0000000000000..7d52c78b2d4df
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SwapTransposeWithBroadcast.cpp
@@ -0,0 +1,84 @@
+//===- SwapTransposeWithBroadcast.cpp - Swap transpose with broadcast op --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This is a pattern to swap broadcast with transpose.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "linalg-swap-transpose-with-broadcast"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// This pattern canonicalizes transpose by swapping the order of
+/// broadcast and transpose:
+/// transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ Value input = transposeOp.getInput();
+ BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+ if (!input.hasOneUse() || !broadcastOp)
+ return failure();
+
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+ ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+ // Get new perms and new dimensions.
+ SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
+ SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+ SmallVector<int64_t> resultDimensions;
+ for (unsigned i = 0; i < dimensions.size(); i++) {
+ resultDimensions.push_back(invertPerm[dimensions[i]]);
+ }
+
+ // Create transpose result.
+ Value broadcastInput = broadcastOp.getInput();
+ Location loc = transposeOp.getLoc();
+ MLIRContext *ctx = transposeOp.getContext();
+ SmallVector<OpFoldResult> dims;
+ auto broadcastInputTy =
+ mlir::cast<RankedTensorType>(broadcastInput.getType());
+ for (unsigned i = 0; i < broadcastInputTy.getRank(); i++) {
+ if (broadcastInputTy.isDynamicDim(i)) {
+ dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
+ ->getResult(0));
+ } else {
+ dims.push_back(IntegerAttr::get(IndexType::get(ctx),
+ broadcastInputTy.getDimSize(i)));
+ }
+ }
+ SmallVector<OpFoldResult> transposeResultShapes =
+ applyPermutation(dims, resultPerms);
+ Value transposeInit = rewriter.create<tensor::EmptyOp>(
+ transposeOp.getLoc(), transposeResultShapes,
+ broadcastInputTy.getElementType());
+
+ // Create broadcast(transpose(input)).
+ Value transposeResult =
+ rewriter
+ .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
+ resultPerms)
+ ->getResult(0);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateSwapTransposeWithBroadcastPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<SwapTransposeWithBroadcast>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index aba225be720c3..ddc1129a5a75f 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -252,6 +252,31 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
return res;
}
+SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
+ ArrayRef<int64_t> dropPositions) {
+ assert(inputPerm.size() >= dropPositions.size() &&
+ "expect inputPerm size large than position to drop");
+ SmallVector<int64_t> res;
+ for (unsigned inputIndex = 0; inputIndex < inputPerm.size(); ++inputIndex) {
+ int64_t targetIndex = inputPerm[inputIndex];
+ bool shouldDrop = false;
+ for (unsigned dropIndex = 0; dropIndex < dropPositions.size();
+ dropIndex++) {
+ if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
+ shouldDrop = true;
+ break;
+ }
+ if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
+ targetIndex--;
+ }
+ }
+ if (!shouldDrop) {
+ res.push_back(targetIndex);
+ }
+ }
+ return res;
+}
+
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront,
unsigned dropBack) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..41305cb5ef60e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
return %0 : tensor<2x3xf32>
}
-// ----
+// -----
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
diff --git a/mlir/test/Dialect/Linalg/swap-transpose-with-broadcast.mlir b/mlir/test/Dialect/Linalg/swap-transpose-with-broadcast.mlir
new file mode 100644
index 0000000000000..2235f9c189c04
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/swap-transpose-with-broadcast.mlir
@@ -0,0 +1,73 @@
+//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-swap-transpose-with-broadcast %s | FileCheck %s
+
+func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
+ %init1: tensor<1x2x3x4x5x6xf32>,
+ %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
+ // CHECK-LABEL: @broadcast_transpose_fold
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
+ // CHECK: %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
+ // CHECK: return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
+ %broadcast = linalg.broadcast
+ ins(%input : tensor<2x4x5xf32>)
+ outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+ dimensions = [0, 2, 5]
+ %transpose = linalg.transpose
+ ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
+ outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+ permutation = [0, 5, 1, 2, 4, 3]
+ func.return %transpose : tensor<1x6x2x3x5x4xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
+ %init1: tensor<1x?x3x?x5x6xf32>,
+ %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+ // CHECK-LABEL: @broadcast_transpose_fold_dynamic
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
+ // CHECK: %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+ // CHECK: return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+ %broadcast = linalg.broadcast
+ ins(%input : tensor<?x?x5xf32>)
+ outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+ dimensions = [0, 2, 5]
+ %transpose = linalg.transpose
+ ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
+ outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+ permutation = [0, 2, 3, 5, 4, 1]
+ func.return %transpose : tensor<1x3x?x6x5x?xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
+ %init1: tensor<2x4xf32>,
+ %init2: tensor<4x2xf32>) -> tensor<4x2xf32> {
+ // CHECK-LABEL: @broadcast_transpose_fold_2dim
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
+ // CHECK: return %[[BROADCAST]] : tensor<4x2xf32>
+ %broadcast = linalg.broadcast
+ ins(%input : tensor<2xf32>)
+ outs(%init1 : tensor<2x4xf32>)
+ dimensions = [1]
+ %transpose = linalg.transpose
+ ins(%broadcast : tensor<2x4xf32>)
+ outs(%init2 : tensor<4x2xf32>)
+ permutation = [1, 0]
+ func.return %transpose : tensor<4x2xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..7f30c24807d56 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -115,6 +115,10 @@ struct TestLinalgTransforms
llvm::cl::desc(
"Test patterns to swap tensor.extract_slice(linalg.fill())"),
llvm::cl::init(false)};
+ Option<bool> testSwapTransposeWithBroadcast{
+ *this, "test-swap-transpose-with-broadcast",
+ llvm::cl::desc("Test patterns to swap transpose(broadcast(input))"),
+ llvm::cl::init(false)};
Option<bool> testEraseUnusedOperandsAndResults{
*this, "test-erase-unused-operands-and-results",
llvm::cl::desc("Test patterns to erase unused operands and results"),
@@ -195,6 +199,12 @@ static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+static void applySwapTransposeWithBroadcast(func::FuncOp funcOp) {
+ RewritePatternSet patterns(funcOp.getContext());
+ populateSwapTransposeWithBroadcastPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateEraseUnusedOperandsAndResultsPatterns(patterns);
@@ -227,6 +237,8 @@ void TestLinalgTransforms::runOnOperation() {
return applyBubbleUpExtractSliceOpPattern(getOperation());
if (testSwapExtractSliceWithFill)
return applySwapExtractSliceWithFillPattern(getOperation());
+ if (testSwapTransposeWithBroadcast)
+ return applySwapTransposeWithBroadcast(getOperation());
if (testEraseUnusedOperandsAndResults)
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
if (testEraseUnnecessaryInputs)
More information about the Mlir-commits
mailing list