[Mlir-commits] [mlir] 41e5dbe - Enables inferring return types for Shape op if possible

Chia-hung Duan llvmlistbot at llvm.org
Wed Aug 18 14:38:09 PDT 2021


Author: Chia-hung Duan
Date: 2021-08-18T21:36:55Z
New Revision: 41e5dbe0fa95933c60bd70eda65af0f2d0243e39

URL: https://github.com/llvm/llvm-project/commit/41e5dbe0fa95933c60bd70eda65af0f2d0243e39
DIFF: https://github.com/llvm/llvm-project/commit/41e5dbe0fa95933c60bd70eda65af0f2d0243e39.diff

LOG: Enables inferring return types for Shape op if possible

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D102565

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index d415bb8b56225..6b39fbffe9d43 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -28,7 +28,9 @@ include "mlir/IR/SymbolInterfaces.td"
 class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<ShapeDialect, mnemonic, traits>;
 
-def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+def Shape_AddOp : Shape_Op<"add",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Addition of sizes and indices";
   let description = [{
     Adds two sizes or indices. If either operand is an error it will be
@@ -47,6 +49,12 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
   }];
 
   let verifier = [{ return verifySizeOrIndexOp(*this); }];
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
@@ -77,6 +85,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
+  let builders = [OpBuilder<(ins "Value":$shape)>];
+
   let assemblyFormat = [{
     $shapes attr-dict `:` type($shapes) `->` type($result)
   }];
@@ -145,7 +155,8 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
   let hasFolder = 1;
 }
 
-def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
+def Shape_DivOp : Shape_Op<"div", [NoSideEffect,
+                           DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Division of sizes and indices";
   let description = [{
     Divides two sizes or indices. If either operand is an error it will be
@@ -173,10 +184,16 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
 
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
-                                            InferTypeOpInterface]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq",
+    [NoSideEffect, Commutative, InferTypeOpInterface]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
     Takes one or more shape or extent tensor operands and determines whether
@@ -290,7 +307,8 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
 }
 
-def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
+def Shape_RankOp : Shape_Op<"rank",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the rank of a shape";
   let description = [{
     Returns the rank of the shape or extent tensor, i.e. the number of extents.
@@ -304,6 +322,12 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -324,7 +348,8 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
+def Shape_GetExtentOp : Shape_Op<"get_extent",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the specified extent from a shape or extent tensor";
   let description = [{
     Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@@ -344,6 +369,9 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
   let extraClassDeclaration = [{
     /// Get the `dim` value as integer if it is constant.
     Optional<int64_t> getConstantDim();
+    /// Returns whether two result types are compatible for this op; method
+    /// used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
 
   let hasFolder = 1;
@@ -369,7 +397,8 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
   let hasCanonicalizer = 1;
 }
 
-def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
+def Shape_JoinOp : Shape_Op<"join",
+    [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the least general shape.shape of its operands";
   let description = [{
     An operation that computes the least general shape of input operands.
@@ -405,9 +434,17 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
     $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
       type($arg0) `,` type($arg1) `->` type($result)
   }];
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
+def Shape_MaxOp : Shape_Op<"max",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Elementwise maximum";
   let description = [{
     Computes the elementwise maximum of two sizes or shapes with equal ranks.
@@ -424,9 +461,17 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
   }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
+def Shape_MinOp : Shape_Op<"min",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Elementwise minimum";
   let description = [{
     Computes the elementwise minimum of two sizes or shapes with equal ranks.
@@ -443,9 +488,17 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
   }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
+def Shape_MulOp : Shape_Op<"mul",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Multiplication of sizes and indices";
   let description = [{
     Multiplies two sizes or indices. If either operand is an error it will be
@@ -465,9 +518,16 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
 
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
+def Shape_NumElementsOp : Shape_Op<"num_elements",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the number of elements for a given shape";
   let description = [{
     Returns the number of elements for a given shape which is the product of its
@@ -480,12 +540,15 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeOrIndexType:$result);
 
-  let builders = [OpBuilder<(ins "Value":$shape)>];
-
   let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
 
   let hasFolder = 1;
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ReduceOp : Shape_Op<"reduce",
@@ -535,7 +598,8 @@ def Shape_ReduceOp : Shape_Op<"reduce",
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
+def Shape_ShapeOfOp : Shape_Op<"shape_of",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns shape of a value or shaped type operand";
 
   let description = [{
@@ -548,11 +612,15 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
 
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
-  let builders = [OpBuilder<(ins "Value":$arg)>];
-
   let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // Returns whether two result types are compatible for this op; method used
+    // by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {

diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index fe7c8eeb2e134..1f604e25bf910 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -34,7 +34,9 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
       The method takes an optional location which, if set, will be used to
       report errors on. The operands and attributes correspond to those with
       which an Operation would be created (e.g., as used in Operation::create)
-      and the regions of the op.
+      and the regions of the op. Be aware that this method is supposed to be
+      called with valid arguments, e.g., operands are verified, or it may result
+      in undefined behavior.
       }],
       /*retTy=*/"::mlir::LogicalResult",
       /*methodName=*/"inferReturnTypes",

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f75bfc5894b6a..7c17455cb3ae6 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -89,6 +89,16 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
   return success();
 }
 
+template <typename... Ty>
+static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
+  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
+}
+
+template <typename... Ty, typename... ranges>
+static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
+  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
+}
+
 //===----------------------------------------------------------------------===//
 // InlinerInterface
 //===----------------------------------------------------------------------===//
@@ -404,6 +414,27 @@ void AssumingOp::build(
   result.addTypes(assumingTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::shape::AddOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  // SizeType is compatible with IndexType.
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
@@ -955,6 +986,23 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
   return IntegerAttr::get(indexTy, quotient);
 }
 
+LogicalResult mlir::shape::DivOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  // SizeType is compatible with IndexType.
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeEqOp
 //===----------------------------------------------------------------------===//
@@ -1096,6 +1144,20 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
   }
 }
 
+LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
+                                                       TypeRange r) {
+  // SizeType is compatible with IndexType.
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
 //===----------------------------------------------------------------------===//
 // IsBroadcastableOp
 //===----------------------------------------------------------------------===//
@@ -1114,6 +1176,38 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// JoinOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::shape::JoinOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({operands[0].getType()});
+  return success();
+}
+
+bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l == r)
+    return true;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (lhs != rhs)
+    return false;
+
+  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
+    return true;
+
+  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
@@ -1173,6 +1267,22 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<RankShapeOfCanonicalizationPattern>(context);
 }
 
+LogicalResult mlir::shape::RankOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  // SizeType is compatible with IndexType.
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
@@ -1191,14 +1301,21 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
   return builder.getIndexAttr(product.getLimitedValue());
 }
 
-void NumElementsOp::build(OpBuilder &builder, OperationState &result,
-                          Value shape) {
-  if (shape.getType().isa<ShapedType>()) {
-    auto type = builder.getIndexType();
-    return build(builder, result, type, shape);
-  }
-  auto type = SizeType::get(builder.getContext());
-  return build(builder, result, type, shape);
+LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
+                                                         TypeRange r) {
+  // SizeType is compatible with IndexType.
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1212,6 +1329,27 @@ OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
   return nullptr;
 }
 
+LogicalResult mlir::shape::MaxOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType() == operands[1].getType())
+    inferredReturnTypes.assign({operands[0].getType()});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+    return true;
+  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // MinOp
 //===----------------------------------------------------------------------===//
@@ -1223,6 +1361,27 @@ OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
   return nullptr;
 }
 
+LogicalResult mlir::shape::MinOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType() == operands[1].getType())
+    inferredReturnTypes.assign({operands[0].getType()});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+    return true;
+  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
@@ -1239,6 +1398,22 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
   return IntegerAttr::get(indexTy, folded);
 }
 
+LogicalResult mlir::shape::MulOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  // SizeType is compatible with IndexType.
+  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
+}
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
@@ -1251,18 +1426,6 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   return builder.getIndexTensorAttr(type.getShape());
 }
 
-void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
-    int64_t rank =
-        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
-    Type indexTy = builder.getIndexType();
-    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
-    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
-  }
-  Type shapeTy = builder.getType<ShapeType>();
-  return ShapeOfOp::build(builder, result, shapeTy, arg);
-}
-
 namespace {
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
@@ -1317,6 +1480,44 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
 }
 
+LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ValueShapeType>())
+    inferredReturnTypes.assign({ShapeType::get(context)});
+  else {
+    auto shapedTy = operands[0].getType().cast<ShapedType>();
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    Type indexTy = IndexType::get(context);
+    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+    inferredReturnTypes.assign({extentTensorTy});
+  }
+  return success();
+}
+
+bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l == r)
+    return true;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
+    return false;
+
+  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+    // Shape type is compatible with all other valid return types.
+    return true;
+
+  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index c605e25b3873c..030926a9cce4b 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -97,6 +97,14 @@ func @shape_of(%value_arg : !shape.value_shape,
 
 // -----
 
+func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
+  // expected-error @+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
+  %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
+  return
+}
+
+// -----
+
 func @rank(%arg : !shape.shape) {
   // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %0 = shape.rank %arg : !shape.shape -> index


        


More information about the Mlir-commits mailing list