[Mlir-commits] [mlir] cb393f4 - [MLIR][Shape] Canonicalize casted extent tensor operands
Frederik Gossen
llvmlistbot at llvm.org
Wed Apr 28 02:52:17 PDT 2021
Author: Frederik Gossen
Date: 2021-04-28T11:51:58+02:00
New Revision: cb393f4c99c1e35d79a1017e59b00014f01daf3e
URL: https://github.com/llvm/llvm-project/commit/cb393f4c99c1e35d79a1017e59b00014f01daf3e
DIFF: https://github.com/llvm/llvm-project/commit/cb393f4c99c1e35d79a1017e59b00014f01daf3e.diff
LOG: [MLIR][Shape] Canonicalize casted extent tensor operands
Both `shape.broadcast` and `shape.cstr_broadcastable` accept dynamic and static
extent tensors. If an operand is produced by a `tensor.cast` that only erases
static shape information, we can use the original (uncast) value instead.
Differential Revision: https://reviews.llvm.org/D101376
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 47fd322ba47c..fd012aa84d1c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -628,12 +628,45 @@ struct BroadcastFoldConstantOperandsPattern
return success();
}
};
+
+template <typename OpTy>
+/// Replaces extent-tensor operands of `OpTy` that come from an
+/// information-losing `tensor.cast` (ranked static -> ranked dynamic extent)
+/// with the cast's source. Casts that add static information are kept, and so
+/// are casts from unranked sources, which are not 1-D extent tensors.
+struct CanonicalizeCastExtentTensorOperandsPattern
+    : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Canonicalize operands.
+    bool anyChange = false;
+    auto canonicalizeOperand = [&](Value operand) -> Value {
+      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
+        // Only eliminate the cast if it holds no shape information: the
+        // result extent must be dynamic and the source must itself be a
+        // ranked tensor. An unranked source (e.g. tensor<*xindex>) would not
+        // be a valid extent tensor operand (NOTE(review): per the ops'
+        // Shape_ShapeOrExtentTensorType constraint), so such casts are kept.
+        auto resultType = castOp.getType().template dyn_cast<RankedTensorType>();
+        auto sourceType =
+            castOp.source().getType().template dyn_cast<RankedTensorType>();
+        bool isInformationLosingCast =
+            resultType && sourceType && resultType.isDynamicDim(0);
+        if (isInformationLosingCast) {
+          anyChange = true;
+          return castOp.source();
+        }
+      }
+      return operand;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::map_range(op.getOperands(), canonicalizeOperand));
+
+    // Rewrite op if any change required. Forward the original attributes so
+    // that optional attributes on the op are not silently dropped.
+    if (!anyChange)
+      return failure();
+    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
+                                      op->getAttrs());
+    return success();
+  }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
+  // Registers the shape.broadcast canonicalizations, including forwarding of
+  // operands through information-losing tensor.cast ops.
patterns.add<BroadcastFoldConstantOperandsPattern,
BroadcastForwardSingleOperandPattern,
+ CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
RemoveDuplicateOperandsPattern<BroadcastOp>,
RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
@@ -716,7 +749,8 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
// Canonicalization patterns have overlap with the considerations during
// folding in case additional shape information is inferred at some point that
// does not result in folding.
- patterns.add<CstrBroadcastableEqOps,
+ patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
+ CstrBroadcastableEqOps,
RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
@@ -1188,7 +1222,7 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
-struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::CastOp op,
@@ -1214,7 +1248,7 @@ struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
+  // ShapeOfCastedExtentTensor was renamed to ShapeOfCastExtentTensor; the set
+  // of registered patterns is otherwise unchanged.
+ patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 3876e9ca34fc..367ce7f6ba1a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1115,8 +1115,8 @@ func @fold_div_mixed() -> !shape.size {
// CHECK-LABEL: @fold_index_cast_on_index
func @fold_index_cast_on_index(%arg: index) -> index {
+  // size_to_index applied to a value that is already an index is expected to
+  // disappear after canonicalization.
// CHECK-NOT: size_to_index
- %casted = shape.size_to_index %arg : index
- return %casted : index
+ %0 = shape.size_to_index %arg : index
+ return %0 : index
}
// -----
@@ -1125,8 +1125,8 @@ func @fold_index_cast_on_index(%arg: index) -> index {
// CHECK-LABEL: @fold_to_extent_tensor_on_tensor
func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex> {
+  // to_extent_tensor on an operand that is already tensor<?xindex> with the
+  // same result type is expected to disappear after canonicalization.
// CHECK-NOT: to_extent_tensor
- %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
- return %casted : tensor<?xindex>
+ %0 = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
+ return %0 : tensor<?xindex>
}
// -----
@@ -1264,9 +1264,9 @@ func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
// -----
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
@@ -1276,9 +1276,9 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
// -----
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
// CHECK: return %[[RESULT]] : tensor<3xindex>
%0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
@@ -1288,8 +1288,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
// -----
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
// CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
%0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1298,8 +1298,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
// -----
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
// CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
%0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1335,3 +1335,21 @@ func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
%2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
"use"(%2) : (!shape.witness) -> ()
}
+
+// -----
+
+// CHECK-LABEL: @cast_extent_tensor_operands
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
+ %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
+  // The cast of %arg0 (dynamic -> static) adds shape information and must be
+  // kept. The cast of %arg1 (static -> dynamic) only erases information, so
+  // both cstr_broadcastable and broadcast should use %arg1 directly.
+ // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
+ // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+ // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+ // CHECK: return %[[WIT]], %[[RES]]
+ %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
+ %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
+ %2 = shape.cstr_broadcastable %0, %1 : tensor<3xindex>, tensor<?xindex>
+ %3 = shape.broadcast %0, %1 :tensor<3xindex>, tensor<?xindex>
+ -> tensor<?xindex>
+ return %2, %3 : !shape.witness, tensor<?xindex>
+}
More information about the Mlir-commits
mailing list