[Mlir-commits] [mlir] [mlir][linalg] Add a folder for transpose(fill) -> fill (PR #83623)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 1 13:26:21 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
This is similar to the existing folder for linalg.copy: transposing a filled tensor is the same as filling the destination of the transpose with the same value.
---
Full diff: https://github.com/llvm/llvm-project/pull/83623.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+17-1)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+14)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 919f5130e1760f..6954eee93efd14 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
}
};
+/// Fold transpose(fill(v, x), y) -> fill(v, y); the permutation is a no-op.
+struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
+ rewriter.replaceOpWithNewOp<FillOp>(
+ transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
+ transposeOp.getDpsInitOperand(0)->get());
+ return success();
+ }
+ return failure(); // Input is not produced by a linalg.fill.
+ }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
- FoldInsertPadIntoFill>(context);
+ FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 206d7e9f1ce8df..19cea6c2066c92 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -993,6 +993,20 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
// -----
+// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK: %[[ZERO:.+]] = arith.constant 0.0
+// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0] // expected to fold into a linalg.fill of %arg1
+ return %transpose : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @broadcast_same_shape(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/83623
More information about the Mlir-commits mailing list