[Mlir-commits] [mlir] 9a82482 - [mlir][linalg] Fix pad tensor cast folding with changed type

Yi Zhang llvmlistbot at llvm.org
Thu Jul 29 14:51:56 PDT 2021


Author: Yi Zhang
Date: 2021-07-29T17:47:01-04:00
New Revision: 9a824823131600bca71406f533c2ba051c23c7d7

URL: https://github.com/llvm/llvm-project/commit/9a824823131600bca71406f533c2ba051c23c7d7
DIFF: https://github.com/llvm/llvm-project/commit/9a824823131600bca71406f533c2ba051c23c7d7.diff

LOG: [mlir][linalg] Fix pad tensor cast folding with changed type

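`PadTensorOp` has verification logic that requires a result dimension to be
static when all the padding values are static.
Folding a cast into the source operand of `PadTensorOp` can add static shape
information to that operand, which can turn a valid operation into an
invalid one.
Change the canonicalization pattern to fix this: when the inferred result
type differs from the existing one, create a new `PadTensorOp` with the
inferred type and cast its result back to the original result type.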

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5ce4b3d27842f..8b54b6de8f092 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1229,9 +1229,26 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
     if (!tensor::canFoldIntoConsumerOp(castOp))
       return failure();
 
-    rewriter.updateRootInPlace(padTensorOp, [&]() {
-      padTensorOp.sourceMutable().assign(castOp.source());
-    });
+    auto newResultType = PadTensorOp::inferResultType(
+        castOp.source().getType().cast<RankedTensorType>(),
+        extractFromI64ArrayAttr(padTensorOp.static_low()),
+        extractFromI64ArrayAttr(padTensorOp.static_high()));
+
+    if (newResultType == padTensorOp.getResultType()) {
+      rewriter.updateRootInPlace(padTensorOp, [&]() {
+        padTensorOp.sourceMutable().assign(castOp.source());
+      });
+    } else {
+      auto newOp = rewriter.create<PadTensorOp>(
+          padTensorOp->getLoc(), newResultType, padTensorOp.source(),
+          padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
+          padTensorOp.static_high(), /*output=*/nullptr);
+      BlockAndValueMapping mapper;
+      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          padTensorOp, padTensorOp.getResultType(), newOp);
+    }
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index ee7edbe9010fd..34dacd58a51e5 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -627,6 +627,55 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
   } : tensor<5x6xf32> to tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
 }
+
+// -----
+// CHECK-LABEL:   func @pad_tensor_after_cast_differnt_shape(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK:           %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
+// CHECK-SAME:        low[0, 0, 1, 1] high[0, 0, 1, 1]  {
+// CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:             linalg.yield %[[CST]] : f32
+// CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
+// CHECK:           %[[DYNAMIC:.*]] = tensor.cast %[[PADDED:.*]] :
+// CHECK-SAME:         tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK:           return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
+// CHECK:         }
+func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
+    -> tensor<?x?x?x?xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %padded = linalg.pad_tensor %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1]  {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst: f32
+  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+  return %padded: tensor<?x?x?x?xf32>
+}
+
+// -----
+// CHECK-LABEL:   func @pad_tensor_after_cast_same_shape(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
+// CHECK-SAME:      %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
+// CHECK:           %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
+// CHECK-SAME:        low[0, %[[PADDING]], 1, 1] high[0, %[[PADDING]], 1, 1]  {
+// CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:             linalg.yield %[[CST]] : f32
+// CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK:           return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
+// CHECK:         }
+func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+    -> tensor<?x?x?x?xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %padded = linalg.pad_tensor %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst: f32
+  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+  return %padded: tensor<?x?x?x?xf32>
+}
+
+// -----
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index


        


More information about the Mlir-commits mailing list