[Mlir-commits] [mlir] 5267ed0 - [ODS] Use Adaptor Traits for Type Inference

Amanda Tang llvmlistbot at llvm.org
Tue Jul 18 10:58:57 PDT 2023


Author: Amanda Tang
Date: 2023-07-18T17:58:31Z
New Revision: 5267ed05bc4612e91409d63b4dbc4e01751acb75

URL: https://github.com/llvm/llvm-project/commit/5267ed05bc4612e91409d63b4dbc4e01751acb75
DIFF: https://github.com/llvm/llvm-project/commit/5267ed05bc4612e91409d63b4dbc4e01751acb75.diff

LOG: [ODS] Use Adaptor Traits for Type Inference

Authored inferReturnTypes methods with the Op Adaptor by using the InferTypeOpAdaptor.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D155115

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0c3f96ff70b9c4..ea6e363a6c3257 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -903,7 +903,7 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
     Pure,
     SameVariadicResultSize,
     ViewLikeOpInterface,
-    InferTypeOpInterfaceAdaptor]> {
+    InferTypeOpAdaptor]> {
   let summary = "Extracts a buffer base with offset and strides";
   let description = [{
     Extracts a base buffer, offset and strides. This op allows additional layers

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 58b720c6e39637..db7b41afb7ed37 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -713,9 +713,8 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
 
 def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
     "getNumRegionInvocations", "getRegionInvocationBounds"]>,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>,
-    SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
-    NoRegionArguments]> {
+    InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
+    RecursiveMemoryEffects, NoRegionArguments]> {
   let summary = "if-then-else operation";
   let description = [{
     The `scf.if` operation represents an if-then-else construct for

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index dce5be8207c82b..3109c4bdb36182 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -32,8 +32,7 @@ class Shape_Op<string mnemonic, list<Trait> traits = []> :
     Op<ShapeDialect, mnemonic, traits>;
 
 def Shape_AddOp : Shape_Op<"add",
-    [Commutative, Pure,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Addition of sizes and indices";
   let description = [{
     Adds two sizes or indices. If either operand is an error it will be
@@ -51,12 +50,6 @@ def Shape_AddOp : Shape_Op<"add",
     $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
   }];
 
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
-
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -109,7 +102,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, Pure]> {
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape",
-    [ConstantLike, Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [ConstantLike, Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Creates a constant shape or extent tensor";
   let description = [{
     Creates a constant shape or extent tensor. The individual extents are given
@@ -128,11 +121,6 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape",
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasCanonicalizer = 1;
-
-  let extraClassDeclaration = [{
-    // InferTypeOpInterface:
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_ConstSizeOp : Shape_Op<"const_size", [
@@ -158,8 +146,7 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
   let hasFolder = 1;
 }
 
-def Shape_DivOp : Shape_Op<"div", [Pure,
-                           DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def Shape_DivOp : Shape_Op<"div", [Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Division of sizes and indices";
   let description = [{
     Divides two sizes or indices. If either operand is an error it will be
@@ -187,12 +174,6 @@ def Shape_DivOp : Shape_Op<"div", [Pure,
 
   let hasFolder = 1;
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Pure, Commutative]> {
@@ -287,7 +268,7 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
 }
 
 def Shape_RankOp : Shape_Op<"rank",
-    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Gets the rank of a shape";
   let description = [{
     Returns the rank of the shape or extent tensor, i.e. the number of extents.
@@ -301,12 +282,6 @@ def Shape_RankOp : Shape_Op<"rank",
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
@@ -330,7 +305,7 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
 }
 
 def Shape_DimOp : Shape_Op<"dim",
-    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Gets the specified extent from the shape of a shaped input";
   let description = [{
     Gets the extent indexed by `dim` from the shape of the `value` operand. If
@@ -354,17 +329,13 @@ def Shape_DimOp : Shape_Op<"dim",
   let extraClassDeclaration = [{
     /// Get the `index` value as integer if it is constant.
     std::optional<int64_t> getConstantIndex();
-
-    /// Returns when two result types are compatible for this op; method used
-    /// by InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
 
   let hasFolder = 1;
 }
 
 def Shape_GetExtentOp : Shape_Op<"get_extent",
-    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Gets the specified extent from a shape or extent tensor";
   let description = [{
     Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@@ -384,9 +355,6 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
   let extraClassDeclaration = [{
     /// Get the `dim` value as integer if it is constant.
     std::optional<int64_t> getConstantDim();
-    /// Returns when two result types are compatible for this op; method used
-    /// by InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
 
   let hasFolder = 1;
@@ -413,8 +381,7 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [Pure]> {
 }
 
 def Shape_MaxOp : Shape_Op<"max",
-    [Commutative, Pure,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Elementwise maximum";
   let description = [{
     Computes the elementwise maximum of two sizes or shapes with equal ranks.
@@ -431,16 +398,10 @@ def Shape_MaxOp : Shape_Op<"max",
   }];
 
   let hasFolder = 1;
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_MeetOp : Shape_Op<"meet",
-    [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Commutative, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Returns the least general shape or size of its operands";
   let description = [{
     An operation that computes the least general shape or dim of input operands.
@@ -478,17 +439,10 @@ def Shape_MeetOp : Shape_Op<"meet",
     $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
       type($arg0) `,` type($arg1) `->` type($result)
   }];
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_MinOp : Shape_Op<"min",
-    [Commutative, Pure,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Elementwise minimum";
   let description = [{
     Computes the elementwise minimum of two sizes or shapes with equal ranks.
@@ -505,17 +459,10 @@ def Shape_MinOp : Shape_Op<"min",
   }];
 
   let hasFolder = 1;
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_MulOp : Shape_Op<"mul",
-    [Commutative, Pure,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Multiplication of sizes and indices";
   let description = [{
     Multiplies two sizes or indices. If either operand is an error it will be
@@ -535,16 +482,10 @@ def Shape_MulOp : Shape_Op<"mul",
 
   let hasFolder = 1;
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_NumElementsOp : Shape_Op<"num_elements",
-    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Returns the number of elements for a given shape";
   let description = [{
     Returns the number of elements for a given shape which is the product of
@@ -561,11 +502,6 @@ def Shape_NumElementsOp : Shape_Op<"num_elements",
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_ReduceOp : Shape_Op<"reduce",
@@ -616,7 +552,7 @@ def Shape_ReduceOp : Shape_Op<"reduce",
 }
 
 def Shape_ShapeOfOp : Shape_Op<"shape_of",
-    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "Returns shape of a value or shaped type operand";
 
   let description = [{
@@ -632,12 +568,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    // Returns when two result types are compatible for this op; method used by
-    // InferTypeOpInterface
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
-  }];
 }
 
 def Shape_ValueOfOp : Shape_Op<"value_of", [Pure]> {

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7346315013b1c3..588998853e6995 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -463,7 +463,7 @@ def Vector_ShuffleOp :
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      PredOpTrait<"second operand v2 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 1>>,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+     InferTypeOpAdaptor]>,
      Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
                     I64ArrayAttr:$mask)>,
      Results<(outs AnyVector:$vector)> {
@@ -572,7 +572,7 @@ def Vector_ExtractOp :
   Vector_Op<"extract", [Pure,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+     InferTypeOpAdaptorWithIsCompatible]>,
     Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
     Results<(outs AnyType)> {
   let summary = "extract operation";
@@ -598,7 +598,6 @@ def Vector_ExtractOp :
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
-    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
   let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
   let hasCanonicalizer = 1;

diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 747ec0a76f3f1d..b8a664fc6882f0 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -259,8 +259,8 @@ namespace mlir {
 namespace OpTrait {
 
 template <typename ConcreteType>
-class InferTypeOpInterfaceAdaptor
-    : public TraitBase<ConcreteType, InferTypeOpInterfaceAdaptor> {};
+class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
+};
 
 /// Tensor type inference trait that constructs a tensor from the inferred
 /// shape and elemental types.

diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index c9c1c6cc9ab01c..a458887b374543 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -186,35 +186,42 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
 
 // Convenient trait to define a wrapper to inferReturnTypes that passes in the
 // Op Adaptor directly
-def InferTypeOpInterfaceAdaptor : TraitList<
+class InferTypeOpAdaptorBase<code additionalDecls = [{}]> : TraitList<
   [
     // Op implements infer type op interface.
     DeclareOpInterfaceMethods<InferTypeOpInterface>,
     NativeOpTrait<
-      /*name=*/"InferTypeOpInterfaceAdaptor",
+      /*name=*/"InferTypeOpAdaptor",
       /*traits=*/[],
       /*extraOpDeclaration=*/[{
-        static LogicalResult
-        inferReturnTypesAdaptor(MLIRContext *context,
-                                std::optional<Location> location,
+        static ::mlir::LogicalResult
+        inferReturnTypes(::mlir::MLIRContext *context,
+                                std::optional<::mlir::Location> location,
                                 Adaptor adaptor,
-                                SmallVectorImpl<Type> &inferredReturnTypes);
-      }],
+                                ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
+      }] # additionalDecls,
       /*extraOpDefinition=*/[{
-        LogicalResult
-        $cppClass::inferReturnTypes(MLIRContext *context,
-                          std::optional<Location> location,
-                          ValueRange operands, DictionaryAttr attributes,
-                          OpaqueProperties properties, RegionRange regions,
-                          SmallVectorImpl<Type> &inferredReturnTypes) {
+        ::mlir::LogicalResult
+        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+                          std::optional<::mlir::Location> location,
+                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
           $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
-          return $cppClass::inferReturnTypesAdaptor(context,
+          return $cppClass::inferReturnTypes(context,
             location, adaptor, inferredReturnTypes);
         }
       }]
     >
   ]>;
 
+def InferTypeOpAdaptor : InferTypeOpAdaptorBase;
+def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
+  [{
+    static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r);
+  }]
+>;
+
 // Convenience class grouping together type and shaped type op interfaces for
 // ops that have tensor return types.
 class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
@@ -231,13 +238,13 @@ class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
       /*traits=*/[],
       /*extraOpDeclaration=*/[{}],
       /*extraOpDefinition=*/[{
-        LogicalResult
-        $cppClass::inferReturnTypes(MLIRContext *context,
-                          std::optional<Location> location,
-                          ValueRange operands, DictionaryAttr attributes,
-                          OpaqueProperties properties, RegionRange regions,
-                          SmallVectorImpl<Type> &inferredReturnTypes) {
-          SmallVector<ShapedTypeComponents, 2> retComponents;
+        ::mlir::LogicalResult
+        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+                          std::optional<::mlir::Location> location,
+                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+          ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
           if (failed($cppClass::inferReturnTypeComponents(context, location,
                                     operands, attributes, properties, regions,
                                     retComponents)))

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 427efa0cd76822..5f35adf0ddaab1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1354,7 +1354,7 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
 
 /// The number and type of the results are inferred from the
 /// shape of the source.
-LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor(
+LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location,
     ExtractStridedMetadataOp::Adaptor adaptor,
     SmallVectorImpl<Type> &inferredReturnTypes) {

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ddcdffebed1392..d8e0ba1c8cd6de 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1841,12 +1841,11 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
 
 LogicalResult
 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
-                       ValueRange operands, DictionaryAttr attrs,
-                       OpaqueProperties properties, RegionRange regions,
+                       IfOp::Adaptor adaptor,
                        SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (regions.empty())
+  if (adaptor.getRegions().empty())
     return failure();
-  Region *r = regions.front();
+  Region *r = &adaptor.getThenRegion();
   if (r->empty())
     return failure();
   Block &b = r->front();

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c5c69dc98b82a0..a06dd446e258a4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -394,11 +394,10 @@ void AssumingOp::build(
 //===----------------------------------------------------------------------===//
 
 LogicalResult mlir::shape::AddOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (llvm::isa<SizeType>(operands[0].getType()) ||
-      llvm::isa<SizeType>(operands[1].getType()))
+    MLIRContext *context, std::optional<Location> location,
+    AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
+      llvm::isa<SizeType>(adaptor.getRhs().getType()))
     inferredReturnTypes.assign({SizeType::get(context)});
   else
     inferredReturnTypes.assign({IndexType::get(context)});
@@ -916,18 +915,17 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 }
 
 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
+    MLIRContext *context, std::optional<Location> location,
+    ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
   Builder b(context);
-  Properties *prop = properties.as<Properties *>();
+  const Properties *prop = &adaptor.getProperties();
   DenseIntElementsAttr shape;
   // TODO: this is only exercised by the Python bindings codepath which does not
   // support properties
   if (prop)
     shape = prop->shape;
   else
-    shape = attributes.getAs<DenseIntElementsAttr>("shape");
+    shape = adaptor.getAttributes().getAs<DenseIntElementsAttr>("shape");
   if (!shape)
     return emitOptionalError(location, "missing shape attribute");
   inferredReturnTypes.assign({RankedTensorType::get(
@@ -1104,11 +1102,9 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult mlir::shape::DimOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  DimOpAdaptor dimOp(operands);
-  inferredReturnTypes.assign({dimOp.getIndex().getType()});
+    MLIRContext *context, std::optional<Location> location,
+    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({adaptor.getIndex().getType()});
   return success();
 }
 
@@ -1141,11 +1137,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult mlir::shape::DivOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (llvm::isa<SizeType>(operands[0].getType()) ||
-      llvm::isa<SizeType>(operands[1].getType()))
+    MLIRContext *context, std::optional<Location> location,
+    DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
+      llvm::isa<SizeType>(adaptor.getRhs().getType()))
     inferredReturnTypes.assign({SizeType::get(context)});
   else
     inferredReturnTypes.assign({IndexType::get(context)});
@@ -1361,9 +1356,8 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
 }
 
 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
+    MLIRContext *context, std::optional<Location> location,
+    GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
   inferredReturnTypes.assign({IndexType::get(context)});
   return success();
 }
@@ -1399,10 +1393,9 @@ OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (operands.empty())
+    MLIRContext *context, std::optional<Location> location,
+    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getOperands().empty())
     return failure();
 
   auto isShapeType = [](Type arg) {
@@ -1411,7 +1404,7 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
     return isExtentTensorType(arg);
   };
 
-  ValueRange::type_range types = operands.getTypes();
+  ValueRange::type_range types = adaptor.getOperands().getTypes();
   Type acc = types.front();
   for (auto t : drop_begin(types)) {
     Type l = acc, r = t;
@@ -1535,10 +1528,9 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 }
 
 LogicalResult mlir::shape::RankOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (llvm::isa<ShapeType>(operands[0].getType()))
+    MLIRContext *context, std::optional<Location> location,
+    RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
     inferredReturnTypes.assign({SizeType::get(context)});
   else
     inferredReturnTypes.assign({IndexType::get(context)});
@@ -1571,10 +1563,10 @@ OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    MLIRContext *context, std::optional<Location> location,
+    NumElementsOp::Adaptor adaptor,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (llvm::isa<ShapeType>(operands[0].getType()))
+  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
     inferredReturnTypes.assign({SizeType::get(context)});
   else
     inferredReturnTypes.assign({IndexType::get(context)});
@@ -1603,11 +1595,10 @@ OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (operands[0].getType() == operands[1].getType())
-    inferredReturnTypes.assign({operands[0].getType()});
+    MLIRContext *context, std::optional<Location> location,
+    MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
+    inferredReturnTypes.assign({adaptor.getLhs().getType()});
   else
     inferredReturnTypes.assign({SizeType::get(context)});
   return success();
@@ -1635,11 +1626,10 @@ OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult mlir::shape::MinOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (operands[0].getType() == operands[1].getType())
-    inferredReturnTypes.assign({operands[0].getType()});
+    MLIRContext *context, std::optional<Location> location,
+    MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
+    inferredReturnTypes.assign({adaptor.getLhs().getType()});
   else
     inferredReturnTypes.assign({SizeType::get(context)});
   return success();
@@ -1672,11 +1662,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult mlir::shape::MulOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (llvm::isa<SizeType>(operands[0].getType()) ||
-      llvm::isa<SizeType>(operands[1].getType()))
+    MLIRContext *context, std::optional<Location> location,
+    MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
+      llvm::isa<SizeType>(adaptor.getRhs().getType()))
     inferredReturnTypes.assign({SizeType::get(context)});
   else
     inferredReturnTypes.assign({IndexType::get(context)});
@@ -1759,13 +1748,12 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  if (llvm::isa<ValueShapeType>(operands[0].getType()))
+    MLIRContext *context, std::optional<Location> location,
+    ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
     inferredReturnTypes.assign({ShapeType::get(context)});
   else {
-    auto shapedTy = llvm::cast<ShapedType>(operands[0].getType());
+    auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
     int64_t rank =
         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
     Type indexTy = IndexType::get(context);

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f5dbeb250ed60d..075b139e2f3b1c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1146,15 +1146,15 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
 
 LogicalResult
 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
-                            ValueRange operands, DictionaryAttr attributes,
-                            OpaqueProperties properties, RegionRange,
+                            ExtractOp::Adaptor adaptor,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
-  ExtractOp::Adaptor op(operands, attributes, properties);
-  auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
-  if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
+  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
+  if (static_cast<int64_t>(adaptor.getPosition().size()) ==
+      vectorType.getRank()) {
     inferredReturnTypes.push_back(vectorType.getElementType());
   } else {
-    auto n = std::min<size_t>(op.getPosition().size(), vectorType.getRank());
+    auto n =
+        std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
     inferredReturnTypes.push_back(VectorType::get(
         vectorType.getShape().drop_front(n), vectorType.getElementType()));
   }
@@ -2114,17 +2114,15 @@ LogicalResult ShuffleOp::verify() {
 
 LogicalResult
 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
-                            ValueRange operands, DictionaryAttr attributes,
-                            OpaqueProperties properties, RegionRange,
+                            ShuffleOp::Adaptor adaptor,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
-  ShuffleOp::Adaptor op(operands, attributes, properties);
-  auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
+  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
   auto v1Rank = v1Type.getRank();
   // Construct resulting type: leading dimension matches mask
   // length, all trailing dimensions match the operands.
   SmallVector<int64_t, 4> shape;
   shape.reserve(v1Rank);
-  shape.push_back(std::max<size_t>(1, op.getMask().size()));
+  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
   // In the 0-D case there is no trailing shape to append.
   if (v1Rank > 0)
     llvm::append_range(shape, v1Type.getShape().drop_front());

diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index fba882d3ba1a3e..420d5d3e4c5962 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1385,6 +1385,19 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
   return success();
 }
 
+LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location> location,
+    OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getX().getType() != adaptor.getY().getType()) {
+    return emitOptionalError(location, "operand type mismatch ",
+                             adaptor.getX().getType(), " vs ",
+                             adaptor.getY().getType());
+  }
+  inferredReturnTypes.assign({adaptor.getX().getType()});
+  return success();
+}
+
 // TODO: We should be able to only define either inferReturnType or
 // refineReturnType, currently only refineReturnType can be omitted.
 LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 30f334c4404bea..7a3ae924afa9a3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -761,6 +761,12 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
   let results = (outs AnyTensor);
 }
 
+def OpWithInferTypeAdaptorInterfaceOp : TEST_Op<"op_with_infer_type_adaptor_if", [
+    InferTypeOpAdaptor]> {
+  let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+  let results = (outs AnyTensor);
+}
+
 def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [
     DeclareOpInterfaceMethods<InferTypeOpInterface,
         ["refineReturnTypes"]>]> {

diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 16c85818605d78..46788edcb4df58 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -485,6 +485,8 @@ struct TestReturnTypeDriver
         // output would be in reverse order underneath `op` from which
         // the attributes and regions are used.
         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
+        invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
+            op);
         invokeCreateWithInferredReturnType<
             OpWithShapedTypeInferTypeInterfaceOp>(op);
       };


        


More information about the Mlir-commits mailing list