[Mlir-commits] [mlir] 3f89e33 - [mlir] add pad_tensor(tensor.cast) -> pad_tensor canonicalizer
Alex Zinenko
llvmlistbot at llvm.org
Fri Sep 24 03:03:54 PDT 2021
Author: Alex Zinenko
Date: 2021-09-24T12:03:47+02:00
New Revision: 3f89e339bb185726a2a3eb127ac59c813b52c6fe
URL: https://github.com/llvm/llvm-project/commit/3f89e339bb185726a2a3eb127ac59c813b52c6fe
DIFF: https://github.com/llvm/llvm-project/commit/3f89e339bb185726a2a3eb127ac59c813b52c6fe.diff
LOG: [mlir] add pad_tensor(tensor.cast) -> pad_tensor canonicalizer
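This canonicalization pattern folds a tensor.cast that consumes the result of a
pad_tensor into the pad_tensor itself, when the cast only adds static type
information. It complements the existing pattern that folds a tensor.cast
producing the pad_tensor source; together they propagate constant type
information when possible. This contributes to the feasibility of pad hoisting.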
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D110343
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index f374a0613f7a1..e8df979ac9cfb 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -53,6 +53,10 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
namespace mlir {
namespace tensor {
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type. Returns false if
+/// either type is not a RankedTensorType, if the element types or ranks
+/// differ, or if any dimension that is static in `source` is dynamic in
+/// `target`. `target` is allowed to be more static than `source`.
+bool preservesStaticInformation(Type source, Type target);
+
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 75e4a1c91bcda..dfa4df8f58d69 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1482,11 +1482,41 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
return success();
}
};
+
+// Folds a tensor.cast that is the sole user of a PadTensorOp result back into
+// the PadTensorOp, provided the cast result type keeps all the static
+// information of the PadTensorOp result type (i.e. the cast can only make the
+// type more static). The PadTensorOp is recreated with the cast result type and
+// both the original PadTensorOp and the cast are replaced by it.
+struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Only fold when the cast is the single user: any other user would still
+    // need the original, less static, result.
+    if (!padTensorOp.result().hasOneUse())
+      return failure();
+    auto tensorCastOp =
+        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
+    if (!tensorCastOp)
+      return failure();
+    if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
+                                            tensorCastOp.dest().getType()))
+      return failure();
+
+    auto replacementOp = rewriter.create<PadTensorOp>(
+        padTensorOp.getLoc(), tensorCastOp.dest().getType(),
+        padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
+        padTensorOp.static_low(), padTensorOp.static_high());
+    // Move the padding-value region through the rewriter instead of calling
+    // Region::takeBody directly, so that the IR mutation is visible to the
+    // rewriter (patterns must perform all IR changes via the rewriter). The
+    // freshly built op has an empty region, so appending at its end is the
+    // same as taking the body.
+    rewriter.inlineRegionBefore(padTensorOp.region(), replacementOp.region(),
+                                replacementOp.region().end());
+
+    rewriter.replaceOp(padTensorOp, replacementOp.result());
+    rewriter.replaceOp(tensorCastOp, replacementOp.result());
+    return success();
+  }
+};
} // namespace
void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
-  results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
+  // Register all PadTensorOp canonicalization patterns in a single call.
+  results.add<FoldStaticZeroPadding, FoldSourceTensorCast,
+              FoldTargetTensorCast>(context);
}
/// Return the padding value of the PadTensorOp if it constant. In this context,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 14ce6c104d44f..2a55223ca9a8b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -31,6 +31,34 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
// CastOp
//===----------------------------------------------------------------------===//
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type: both must be
+/// ranked tensors with the same element type, rank and encoding, and every
+/// dimension that is static in `source` must be static with the same size in
+/// `target`. `target` is allowed to be more static than `source`.
+bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
+  auto sourceType = source.dyn_cast<RankedTensorType>();
+  auto targetType = target.dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !targetType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != targetType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != targetType.getRank())
+    return false;
+
+  // Requires same encoding: changing the encoding is not merely a gain or loss
+  // of static shape information.
+  if (sourceType.getEncoding() != targetType.getEncoding())
+    return false;
+
+  // Static information is lost if a dimension that is static in `source` is
+  // dynamic in `target` or has a different static size there. Dimensions that
+  // are dynamic in `source` may take any size in `target`.
+  for (auto dims : llvm::zip(sourceType.getShape(), targetType.getShape())) {
+    int64_t sourceDim = std::get<0>(dims);
+    int64_t targetDim = std::get<1>(dims);
+    if (!ShapedType::isDynamic(sourceDim) && sourceDim != targetDim)
+      return false;
+  }
+
+  return true;
+}
+
+
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
@@ -57,30 +85,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
if (!castOp)
return false;
- RankedTensorType sourceType =
- castOp.source().getType().dyn_cast<RankedTensorType>();
- RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
- // Requires RankedTensorType.
- if (!sourceType || !resultType)
- return false;
-
- // Requires same elemental type.
- if (sourceType.getElementType() != resultType.getElementType())
- return false;
-
- // Requires same rank.
- if (sourceType.getRank() != resultType.getRank())
- return false;
-
- // If cast is towards more static sizes along any dimension, don't fold.
- for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
- if (ShapedType::isDynamic(std::get<0>(t)) &&
- !ShapedType::isDynamic(std::get<1>(t)))
- return false;
- }
-
- return true;
+  // Can fold if the source of the cast has at least as much static
+  // information as the cast result.
+ return preservesStaticInformation(castOp.getType(),
+ castOp.source().getType());
}
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index fce08a1e04dca..42d640a60246c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -696,6 +696,39 @@ func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
// -----
+// A tensor.cast that only makes the linalg.pad_tensor result more static is
+// folded into the pad: the pad is expected to produce tensor<32x32xf32>
+// directly and the cast disappears.
+// CHECK-LABEL: @cast_of_pad_more_static
+func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
+ %cst = constant 0.000000e+00 : f32
+ // CHECK: %[[PAD:.*]] = linalg.pad_tensor
+ // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
+ %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<?x?xf32>
+ // CHECK-NOT: tensor.cast
+ %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
+ // CHECK: return %[[PAD]]
+ return %casted : tensor<32x32xf32>
+}
+
+// -----
+
+// A tensor.cast that turns a static dimension of the linalg.pad_tensor result
+// into a dynamic one (dimension 0 here: 32 -> ?) loses static information, so
+// it must not be folded into the pad and is expected to remain.
+// CHECK-LABEL: @cast_of_pad_less_static
+func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
+ %cst = constant 0.000000e+00 : f32
+ // CHECK: linalg.pad_tensor
+ %padded = linalg.pad_tensor %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index):
+ linalg.yield %cst : f32
+ } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
+ // CHECK: %[[CAST:.*]] = tensor.cast
+ %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
+ // CHECK: return %[[CAST]]
+ return %casted : tensor<?x32x32xf32>
+}
+
+// -----
+
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
%arg3 : index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index 13f12d83133be..b00855581efb7 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -140,8 +140,7 @@ func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
// CHECK: } else {
// CHECK: %[[SUBTENSOR:.*]] = tensor.extract_slice %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
// CHECK: %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
-// CHECK: %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
-// CHECK: scf.yield %[[CAST]]
+// CHECK: scf.yield %[[PADTENSOR]]
// CHECK: }
// CHECK: return %[[RESULT]]
func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index c1a761ce1c425..4aef50e6c96ba 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -289,7 +289,6 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK: else
// CHECK: tensor.extract_slice
// CHECK: linalg.pad_tensor
-// CHECK: tensor.cast
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
index 20615f27d1442..5556699cdae66 100644
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -111,8 +111,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
// TILE1: else
// TILE1: %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
// TILE1: %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-// TILE1: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
-// TILE1: scf.yield %[[CAST]] : tensor<14x3xf32>
+// TILE1: scf.yield %[[PAD]] : tensor<14x3xf32>
// TILE1: %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
// TILE1: scf.yield %[[R3]] : tensor<14x15xf32>
// TILE1: return %[[RESULT]] : tensor<14x15xf32>
More information about the Mlir-commits
mailing list