[Mlir-commits] [mlir] 76ea62a - [mlir] Fix folding into tensor.pad op.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Mar 30 02:35:16 PDT 2023
Author: Alexander Belyaev
Date: 2023-03-30T11:30:06+02:00
New Revision: 76ea62a2735a760545bfa98524e7a658a15268ac
URL: https://github.com/llvm/llvm-project/commit/76ea62a2735a760545bfa98524e7a658a15268ac
DIFF: https://github.com/llvm/llvm-project/commit/76ea62a2735a760545bfa98524e7a658a15268ac.diff
LOG: [mlir] Fix folding into tensor.pad op.
When low/high padding is folded in padOp, a tensor.cast back to the
original result type should be inserted. Right now, there is only a no-op
tensor.cast from the new type to the new type.
Differential Revision: https://reviews.llvm.org/D147210
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e7fb28794567e..7ee9325e5f8eb 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2879,10 +2879,12 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
auto inputDims = input.getType().cast<RankedTensorType>().getShape();
auto inputRank = inputDims.size();
- if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+ auto oldResultType =
+ dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
+ if (!oldResultType)
return failure();
- auto outputDims =
- padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+ auto outputDims = oldResultType.getShape();
// Extract the static info from the high and low operands.
SmallVector<int64_t> constOperandsLow;
@@ -2955,7 +2957,7 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
IRMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
- rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType,
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
newOp);
return success();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8a5e04750e7ce..0a42e2bb3a5c9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1140,7 +1140,7 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// -----
// CHECK-LABEL: func @pad_fold_static(
-// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDING:.*]] = arith.constant 4 : index
// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1148,16 +1148,16 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
-func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
- -> tensor<?xf32> {
+// CHECK: tensor.cast
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padding = arith.constant 4 : index
%padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %cst: f32
} : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
- %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
- return %result : tensor<?xf32>
+ return %padded : tensor<?x?x?x?xf32>
}
// -----
More information about the Mlir-commits
mailing list