[Mlir-commits] [mlir] 3ce2ee2 - [mlir][ODS] Infer return types if the operands are variadic but the results are not

Benjamin Kramer llvmlistbot at llvm.org
Fri Feb 18 06:34:23 PST 2022


Author: Benjamin Kramer
Date: 2022-02-18T15:29:06+01:00
New Revision: 3ce2ee28f042c2a00d09c228c76f2692778bd607

URL: https://github.com/llvm/llvm-project/commit/3ce2ee28f042c2a00d09c228c76f2692778bd607
DIFF: https://github.com/llvm/llvm-project/commit/3ce2ee28f042c2a00d09c228c76f2692778bd607.diff

LOG: [mlir][ODS] Infer return types if the operands are variadic but the results are not

Clean up code that worked around this limitation.

Differential Revision: https://reviews.llvm.org/D120119

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/TableGen/Operator.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 19b08642c8a66..31a69522c86c3 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -192,8 +192,7 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect,
   }];
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq",
-    [NoSideEffect, Commutative, InferTypeOpInterface]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
     Takes one or more shape or extent tensor operands and determines whether
@@ -211,17 +210,6 @@ def Shape_ShapeEqOp : Shape_Op<"shape_eq",
   OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
     [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
   ];
-  let extraClassDeclaration = [{
-    // TODO: This should really be automatic. Figure out how to not need this defined.
-    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
-    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
-    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
-    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
-      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
-                                                             /*width=*/1));
-      return success();
-    };
-  }];
 
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
   let hasFolder = 1;
@@ -262,8 +250,7 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
   let assemblyFormat = "$input attr-dict `:` type($input)";
 }
 
-def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
-                                       [Commutative, InferTypeOpInterface]> {
+def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
   let summary = "Determines if 2+ shapes can be successfully broadcasted";
   let description = [{
     Given multiple input shapes or extent tensors, return a predicate specifying
@@ -289,17 +276,6 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
   OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
     [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
   ];
-  let extraClassDeclaration = [{
-    // TODO: This should really be automatic. Figure out how to not need this defined.
-    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
-    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
-    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
-    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
-      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
-                                                             /*width=*/1));
-      return success();
-    };
-  }];
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -850,12 +826,6 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]>
   let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
   let results = (outs Shape_WitnessType:$result);
 
-  // Only needed while tablegen is unable to generate this for ops with variadic
-  // arguments.
-  let builders = [
-    OpBuilder<(ins "ValueRange":$inputs)>,
-  ];
-
   let assemblyFormat = "$inputs attr-dict";
 
   let hasFolder = 1;
@@ -917,8 +887,7 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
-def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
-                                         [Commutative, InferTypeOpInterface]> {
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
   let summary = "Determines if 2+ shapes can be successfully broadcasted";
   let description = [{
     Given input shapes or extent tensors, return a witness specifying if they
@@ -944,23 +913,12 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
     [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
   ];
 
-  let extraClassDeclaration = [{
-    // TODO: This should really be automatic. Figure out how to not need this defined.
-    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
-    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
-    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
-    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
-      inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
-      return success();
-    };
-  }];
-
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
 
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
   let summary = "Determines if all input shapes are equal";
   let description = [{
     Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -978,17 +936,6 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
 
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
 
-  let extraClassDeclaration = [{
-    // TODO: This should really be automatic. Figure out how to not need this defined.
-    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
-    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
-    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
-    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
-      inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
-      return success();
-    };
-  }];
-
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aec8dd5b68823..2e7f06903824f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -803,8 +803,6 @@ def Vector_InsertMapOp :
       into vector<64x4x32xf32>
     ```
   }];
-  let builders = [OpBuilder<(ins "Value":$vector, "Value":$dest,
-                                "ValueRange":$ids)>];
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
       return vector().getType().cast<VectorType>();

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5c851f579ef85..0f633312eaddc 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -663,11 +663,6 @@ LogicalResult AssumingAllOp::verify() {
   return success();
 }
 
-void AssumingAllOp::build(OpBuilder &b, OperationState &state,
-                          ValueRange inputs) {
-  build(b, state, b.getType<WitnessType>(), inputs);
-}
-
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ddfe0d8442280..f6547e46d5418 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1900,11 +1900,6 @@ OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
 // InsertMapOp
 //===----------------------------------------------------------------------===//
 
-void InsertMapOp::build(OpBuilder &builder, OperationState &result,
-                        Value vector, Value dest, ValueRange ids) {
-  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
-}
-
 LogicalResult InsertMapOp::verify() {
   if (getSourceVectorType().getRank() != getResultType().getRank())
     return emitOpError("expected source and destination vectors of same rank");

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index a71ae4d642b13..2a0d49fcfccf6 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -327,9 +327,8 @@ void Operator::populateTypeInferenceInfo(
   if (getNumResults() == 0)
     return;
 
-  // Skip for ops with variadic operands/results.
-  // TODO: This can be relaxed.
-  if (isVariadic())
+  // Skip ops with variadic or optional results.
+  if (getNumVariableLengthResults() > 0)
     return;
 
   // Skip cases currently being custom generated.


        


More information about the Mlir-commits mailing list