[Mlir-commits] [mlir] 57a7cd7 - [shape] Add inferReturnTypes to a couple ops.
Sean Silva
llvmlistbot at llvm.org
Fri Apr 24 16:10:41 PDT 2020
Author: Sean Silva
Date: 2020-04-24T16:10:20-07:00
New Revision: 57a7cd7a138fed24e109a02dbd8f7d464bf7e177
URL: https://github.com/llvm/llvm-project/commit/57a7cd7a138fed24e109a02dbd8f7d464bf7e177
DIFF: https://github.com/llvm/llvm-project/commit/57a7cd7a138fed24e109a02dbd8f7d464bf7e177.diff
LOG: [shape] Add inferReturnTypes to a couple ops.
- ShapeOfOp
- BroadcastOp
Differential Revision: https://reviews.llvm.org/D78822
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index f54456b862fa..fa277f4f89de 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -130,7 +130,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
let results = (outs Shape_SizeType:$result);
}
-def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
+def Shape_BroadcastOp : Shape_Op<"broadcast",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the broadcasted output shape of two inputs";
let description = [{
Computes the broadcasted output shape following:
@@ -317,7 +318,8 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
let regions = (region SizedRegion<1>:$body);
}
-def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
+def Shape_ShapeOfOp : Shape_Op<"shape_of",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns shape of a value or shaped type operand";
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 4a1c0f1d5128..10e766f3cc61 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -92,6 +92,14 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
// BroadcastOp
//===----------------------------------------------------------------------===//
+// Infers shape.broadcast's single result type: always a ShapeType. The
+// operands, attributes and regions are not inspected, so this always succeeds.
+LogicalResult BroadcastOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(ShapeType::get(context));
+ return success();
+}
+
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1])
return nullptr;
@@ -175,6 +183,14 @@ LogicalResult ConstSizeOp::inferReturnTypes(
// ShapeOfOp
//===----------------------------------------------------------------------===//
+// Infers shape.shape_of's single result type: always a ShapeType, whether the
+// argument is a shaped value or a ValueShape. The operands are not inspected.
+LogicalResult ShapeOfOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(ShapeType::get(context));
+ return success();
+}
+
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
auto type = getOperand().getType().dyn_cast<ShapedType>();
if (!type || !type.hasStaticShape())
More information about the Mlir-commits
mailing list