[Mlir-commits] [mlir] dca5361 - [MLIR][Shape] Concretize broadcast result type if possible
Frederik Gossen
llvmlistbot at llvm.org
Wed Apr 28 02:58:54 PDT 2021
Author: Frederik Gossen
Date: 2021-04-28T11:58:32+02:00
New Revision: dca536103592cf1e92aa8316ed23f33d75da25bc
URL: https://github.com/llvm/llvm-project/commit/dca536103592cf1e92aa8316ed23f33d75da25bc
DIFF: https://github.com/llvm/llvm-project/commit/dca536103592cf1e92aa8316ed23f33d75da25bc.diff
LOG: [MLIR][Shape] Concretize broadcast result type if possible
As a canonicalization, infer the resulting shape rank if possible.
Differential Revision: https://reviews.llvm.org/D101377
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 570719eff64d..08c5d5ddbc82 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -29,7 +29,8 @@ class PatternRewriter;
namespace shape {
/// Alias type for extent tensors.
-RankedTensorType getExtentTensorType(MLIRContext *ctx);
+RankedTensorType getExtentTensorType(MLIRContext *ctx,
+ int64_t rank = ShapedType::kDynamicSize);
// Check if a type is an extent tensor, e.g., tensor<?xindex>.
bool isExtentTensorType(Type);
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index fd012aa84d1c..ac67a62a0aef 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -27,8 +27,8 @@ namespace {
#include "ShapeCanonicalization.inc"
}
-RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
- return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
+/// Returns the 1-D `index` tensor type whose single dimension has size `rank`
+/// (i.e. the rank of the shape the extent tensor describes). With the default
+/// `ShapedType::kDynamicSize` from the declaration this is `tensor<?xindex>`.
+RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
+ return RankedTensorType::get({rank}, IndexType::get(ctx));
}
bool shape::isExtentTensorType(Type type) {
@@ -660,11 +660,42 @@ struct CanonicalizeCastExtentTensorOperandsPattern
return success();
}
};
+
+/// Canonicalization pattern for `shape.broadcast`: if the result is a ranked
+/// tensor with a dynamic leading dimension (e.g. `tensor<?xindex>`) and every
+/// ranked operand has a static leading dimension, recreate the op with result
+/// type `tensor<N x index>`, where N is the largest operand size, and
+/// `tensor.cast` that result back to the original type so existing users
+/// still see the type they had before.
+struct BroadcastConcretizeResultTypePattern
+ : public OpRewritePattern<BroadcastOp> {
+ using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp op,
+ PatternRewriter &rewriter) const override {
+ // Only concretize dynamic extent tensor result types. An already static
+ // result is left alone, which also keeps this pattern from firing again on
+ // the op it creates below.
+ auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+ if (!resultTy || !resultTy.isDynamicDim(0))
+ return failure();
+
+ // Infer resulting shape rank if possible: the new result size is the
+ // largest operand size. Because maxRank starts at 0, a broadcast with no
+ // ranked operands is rewritten to produce `tensor<0xindex>`.
+ // NOTE(review): operands whose type is not a RankedTensorType are skipped
+ // and do not contribute to maxRank; this presumably relies on the op
+ // verifier rejecting such operands when the result is a tensor. Confirm.
+ int64_t maxRank = 0;
+ for (Value shape : op.shapes()) {
+ if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+ // Cannot infer resulting shape rank if any operand's extent tensor has a
+ // dynamic size, i.e. describes a shape of statically unknown rank.
+ if (extentTensorTy.isDynamicDim(0))
+ return failure();
+ maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
+ }
+ }
+
+ // Rebuild the broadcast with the concrete result type, then cast back to
+ // the original (dynamic) result type for the existing users.
+ auto newOp = rewriter.create<BroadcastOp>(
+ op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<BroadcastFoldConstantOperandsPattern,
+ patterns.add<BroadcastConcretizeResultTypePattern,
+ BroadcastFoldConstantOperandsPattern,
BroadcastForwardSingleOperandPattern,
CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
RemoveDuplicateOperandsPattern<BroadcastOp>,
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 367ce7f6ba1a..6e0243839132 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1344,7 +1344,8 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
%arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
// CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
// CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
- // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+ // CHECK: %[[UNCAST_RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+ // CHECK: %[[RES:.*]] = tensor.cast %[[UNCAST_RES]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[WIT]], %[[RES]]
%0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
%1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
@@ -1353,3 +1354,17 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
-> tensor<?xindex>
return %2, %3 : !shape.witness, tensor<?xindex>
}
+
+// -----
+
+// Broadcasting a 2-extent and a 3-extent shape into a dynamically sized
+// result: the broadcast should be rebuilt with the static result type
+// `tensor<3xindex>` and then cast back to `tensor<?xindex>`.
+// CHECK-LABEL: @concretize_broadcast_result_type
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
+ %arg1 : tensor<3xindex>) -> tensor<?xindex> {
+ // CHECK: %[[CONCR:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
+ // CHECK: %[[RES:.*]] = tensor.cast %[[CONCR]] : tensor<3xindex> to tensor<?xindex>
+ // CHECK: return %[[RES]]
+ %0 = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex>
+ -> tensor<?xindex>
+ return %0 : tensor<?xindex>
+}
More information about the Mlir-commits mailing list