[Mlir-commits] [mlir] b6ab4f1 - [mlir][linalg] Fold linalg.pad_tensor if src type == result type
Matthias Springer
llvmlistbot at llvm.org
Tue Jun 15 01:25:25 PDT 2021
Author: Matthias Springer
Date: 2021-06-15T17:25:12+09:00
New Revision: b6ab4f1a8b6546b67dbcc3612f33c26d9b72a5cc
URL: https://github.com/llvm/llvm-project/commit/b6ab4f1a8b6546b67dbcc3612f33c26d9b72a5cc
DIFF: https://github.com/llvm/llvm-project/commit/b6ab4f1a8b6546b67dbcc3612f33c26d9b72a5cc.diff
LOG: [mlir][linalg] Fold linalg.pad_tensor if src type == result type
Fold PadTensorOp to its source if the source type and the result type are equal and have a static shape.
Differential Revision: https://reviews.llvm.org/D103778
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 9b6120e61eed..1b0177642ee9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -296,6 +296,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def Linalg_RangeOp :
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2b3ae8909541..985a9f7a09a2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1164,6 +1164,12 @@ Value PadTensorOp::getConstantPaddingValue() {
return padValue;
}
+OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
+  auto resultTy = getResultType();
+  bool sameStaticType = resultTy.hasStaticShape() && resultTy == getSourceType();
+  return sameStaticType ? OpFoldResult(source()) : OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 6bd2895fcdc9..029ac621ca4b 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -893,6 +893,22 @@ func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
// -----
+// CHECK-LABEL: func @pad_tensor_same_static_shape(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: return %[[ARG0]]
+func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %pad: index)
+    -> tensor<5x6xf32> {
+  %zero = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[%pad, 0] high[0, %pad] {
+    ^bb0(%i: index, %j: index):
+      linalg.yield %zero : f32
+  } : tensor<5x6xf32> to tensor<5x6xf32>
+  return %0 : tensor<5x6xf32>
+}
+
+// -----
+
func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
{
%c1 = constant 1 : index
More information about the Mlir-commits
mailing list