[Mlir-commits] [mlir] 87656a3 - [mlir][linalg] Fold TensorCast into PadTensorOp.

Tobias Gysi llvmlistbot at llvm.org
Mon Jul 19 08:58:10 PDT 2021


Author: Tobias Gysi
Date: 2021-07-19T15:57:38Z
New Revision: 87656a3134c7c03565efca85352a58541ce68789

URL: https://github.com/llvm/llvm-project/commit/87656a3134c7c03565efca85352a58541ce68789
DIFF: https://github.com/llvm/llvm-project/commit/87656a3134c7c03565efca85352a58541ce68789.diff

LOG: [mlir][linalg] Fold TensorCast into PadTensorOp.

Add a canonicalization pattern that folds a tensor.cast into the source of a PadTensorOp when the cast only removes static size information.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D106278

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 93d330d8d846a..16ff0711f4744 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1063,11 +1063,55 @@ struct FoldToDimOfOutputOperand : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
+
+// Fold a tensor.cast that only erases static size information into the source
+// of the consuming PadTensorOp. For example
+//
+//   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+//   %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<4x4xf32>
+//
+// is rewritten to pad %arg0 directly, so that the PadTensorOp sees the more
+// static source type.
+struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // getDefiningOp returns a null op if the source is not produced by a
+    // tensor.cast; canFoldIntoConsumerOp rejects a null op as well as casts
+    // that do not merely erase static size information.
+    auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
+    if (!tensor::canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    // A more static source can make the inferred result type more static than
+    // the declared one. The PadTensorOp verifier rejects a declared dimension
+    // that differs from a static inferred dimension, so bail out rather than
+    // produce IR that fails verification.
+    auto resultType =
+        padTensorOp.getResult().getType().cast<RankedTensorType>();
+    RankedTensorType inferredType = PadTensorOp::inferResultType(
+        castOp.source().getType().cast<RankedTensorType>(),
+        extractFromI64ArrayAttr(padTensorOp.static_low()),
+        extractFromI64ArrayAttr(padTensorOp.static_high()));
+    for (int64_t dim = 0, rank = inferredType.getRank(); dim < rank; ++dim) {
+      if (!inferredType.isDynamicDim(dim) &&
+          inferredType.getDimSize(dim) != resultType.getDimSize(dim))
+        return failure();
+    }
+
+    rewriter.updateRootInPlace(padTensorOp, [&]() {
+      padTensorOp.sourceMutable().assign(castOp.source());
+    });
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand>(context);
+  results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand,
+              FoldSourceTensorCast>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index c453255d39485..b24876cfe255d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -772,6 +772,22 @@ func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @fold_pad_tensor_source_cast(
+//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<4x?xf32>
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   %[[RESULT:.*]] = linalg.pad_tensor %[[ARG0]]
+func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+  %pad_value = constant 0.0 : f32
+  %erased = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+  %padded = linalg.pad_tensor %erased low[0, 0] high[0, 1] {
+    ^bb0(%i: index, %j: index):  // no predecessors
+      linalg.yield %pad_value : f32
+  } : tensor<?x?xf32> to tensor<4x4xf32>
+  return %padded : tensor<4x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_static_zero_cast(
 //  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
 //   CHECK-NOT:   linalg.pad_tensor


        


More information about the Mlir-commits mailing list