[Mlir-commits] [mlir] 76ea62a - [mlir] Fix folding into tensor.pad op.

Alexander Belyaev llvmlistbot at llvm.org
Thu Mar 30 02:35:16 PDT 2023


Author: Alexander Belyaev
Date: 2023-03-30T11:30:06+02:00
New Revision: 76ea62a2735a760545bfa98524e7a658a15268ac

URL: https://github.com/llvm/llvm-project/commit/76ea62a2735a760545bfa98524e7a658a15268ac
DIFF: https://github.com/llvm/llvm-project/commit/76ea62a2735a760545bfa98524e7a658a15268ac.diff

LOG: [mlir] Fix folding into tensor.pad op.

When low/high padding is folded into padOp, a tensor.cast back to the
original result type should be inserted. Currently, the pattern inserts a
no-op tensor.cast from the new result type to that same new type.

Differential Revision: https://reviews.llvm.org/D147210

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e7fb28794567e..7ee9325e5f8eb 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2879,10 +2879,12 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
     auto inputDims = input.getType().cast<RankedTensorType>().getShape();
     auto inputRank = inputDims.size();
 
-    if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+    auto oldResultType =
+        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
+    if (!oldResultType)
       return failure();
-    auto outputDims =
-        padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+    auto outputDims = oldResultType.getShape();
 
     // Extract the static info from the high and low operands.
     SmallVector<int64_t> constOperandsLow;
@@ -2955,7 +2957,7 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
 
     IRMapping mapper;
     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType,
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                 newOp);
 
     return success();

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8a5e04750e7ce..0a42e2bb3a5c9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1140,7 +1140,7 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 // -----
 
 // CHECK-LABEL:   func @pad_fold_static(
-// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[PADDING:.*]] = arith.constant 4 : index
 // CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1148,16 +1148,16 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 // CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
 // CHECK:             tensor.yield %[[CST]] : f32
 // CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
-func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
-    -> tensor<?xf32> {
+// CHECK:           tensor.cast
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %padding = arith.constant 4 : index
   %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
     ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
     tensor.yield %cst: f32
   } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
-  %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
-  return %result : tensor<?xf32>
+  return %padded : tensor<?x?x?x?xf32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list