[Mlir-commits] [mlir] 057fc8e - [ODS] Use Adaptor Trait for Shaped Type Inference

Amanda Tang llvmlistbot at llvm.org
Thu Jul 20 12:41:14 PDT 2023


Author: Amanda Tang
Date: 2023-07-20T19:41:08Z
New Revision: 057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a

URL: https://github.com/llvm/llvm-project/commit/057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a
DIFF: https://github.com/llvm/llvm-project/commit/057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a.diff

LOG: [ODS] Use Adaptor Trait for Shaped Type Inference

Allow ops to author their inferReturnTypeComponents methods against the Op Adaptor by using the new InferShapedTypeOpAdaptor trait (or InferTensorTypeAdaptor for ops with tensor results).

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D155243

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7a16b37f0ca417..421d6b09424a59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -32,10 +32,7 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
 //===----------------------------------------------------------------------===//
 // Operator: argmax
 //===----------------------------------------------------------------------===//
-def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ArgMaxOp : Tosa_Op<"argmax", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Perform argmax on the input.";
 
   let description = [{
@@ -62,10 +59,7 @@ def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
 //===----------------------------------------------------------------------===//
 // Operator: avg_pool2d
 //===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs max pooling on the input.";
 
   let description = [{
@@ -95,10 +89,7 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
 //===----------------------------------------------------------------------===//
 // Operator: conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_Conv2DOp : Tosa_Op<"conv2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "2D Convolution Operator";
 
   let description = [{
@@ -128,10 +119,7 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_Conv3DOp : Tosa_Op<"conv3d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "3D Convolution operator";
 
   let description = [{
@@ -160,10 +148,8 @@ def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
 //===----------------------------------------------------------------------===//
 // Operator: depthwise_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d",
+    [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Depthwise 2D Convolution operator";
 
   let description = [{
@@ -193,10 +179,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
 //===----------------------------------------------------------------------===//
 // Operator: fft2d
 //===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_FFT2dOp : Tosa_Op<"fft2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs FFT2D operation on the input.";
 
   let description = [{
@@ -224,9 +207,7 @@ def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
 // Operator: fully_connected
 //===----------------------------------------------------------------------===//
 def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Fully Connected operator";
 
   let description = [{
@@ -251,10 +232,7 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
 //===----------------------------------------------------------------------===//
 // Operator: matmul
 //===----------------------------------------------------------------------===//
-def Tosa_MatMulOp : Tosa_Op<"matmul", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_MatMulOp : Tosa_Op<"matmul", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Matrix multiplication with bias";
 
   let description = [{
@@ -279,10 +257,7 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [
 //===----------------------------------------------------------------------===//
 // Operator: max_pool2d
 //===----------------------------------------------------------------------===//
-def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs max pooling on the input.";
 
   let description = [{
@@ -310,10 +285,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs RFFT2D operation on the input.";
 
   let description = [{
@@ -338,10 +310,8 @@ def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
 //===----------------------------------------------------------------------===//
 // Operator: transpose_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", 
+    [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Transpose 2D Convolution operator.";
 
   let description = [{
@@ -828,10 +798,7 @@ def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
 //===----------------------------------------------------------------------===//
 // Operator: table
 //===----------------------------------------------------------------------===//
-def Tosa_TableOp : Tosa_Op<"table", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_TableOp : Tosa_Op<"table", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Table lookup op";
 
   let description = [{
@@ -1214,7 +1181,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
 // Operator: reduce_all
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce All operator";
 
   let description = [{
@@ -1243,7 +1210,7 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
 // Operator: reduce_any
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Any operator";
 
   let description = [{
@@ -1272,7 +1239,7 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
 // Operator: reduce_max
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Max operator";
 
   let description = [{
@@ -1301,7 +1268,7 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
 // Operator: reduce_min
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Min operator";
 
   let description = [{
@@ -1330,7 +1297,7 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
 // Operator: reduce_prod
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Prod operator";
 
   let description = [{
@@ -1359,7 +1326,7 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
 // Operator: reduce_sum
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Sum operator";
 
   let description = [{
@@ -1393,7 +1360,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
 // Operator: concat
 //===----------------------------------------------------------------------===//
 def Tosa_ConcatOp : Tosa_Op<"concat", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Concatenates tensors along one dimension.";
 
   let description = [{
@@ -1423,10 +1390,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
 //===----------------------------------------------------------------------===//
 // Operator: pad
 //===----------------------------------------------------------------------===//
-def Tosa_PadOp : Tosa_Op<"pad", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_PadOp : Tosa_Op<"pad", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Pads a tensor with value specified.";
 
   let description = [{
@@ -1471,7 +1435,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
 // Operator: reshape
 //===----------------------------------------------------------------------===//
 def Tosa_ReshapeOp: Tosa_Op<"reshape", [
-  InferTensorType, Pure]> {
+  InferTensorTypeAdaptor, Pure]> {
   let summary = "Reshape operator";
 
   let description = [{
@@ -1528,9 +1492,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
 //===----------------------------------------------------------------------===//
 // Operator: slice
 //===----------------------------------------------------------------------===//
-def Tosa_SliceOp: Tosa_Op<"slice", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, Pure]> {
+def Tosa_SliceOp: Tosa_Op<"slice", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Slice operator";
 
   let description = [{
@@ -1556,10 +1518,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
 //===----------------------------------------------------------------------===//
 // Operator: tile
 //===----------------------------------------------------------------------===//
-def Tosa_TileOp: Tosa_Op<"tile", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-      Pure]> {
+def Tosa_TileOp: Tosa_Op<"tile", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Tile operator";
 
   let description = [{
@@ -1580,10 +1539,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
 //===----------------------------------------------------------------------===//
 // Operator: transpose
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_Op<"transpose", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-      Pure]> {
+def Tosa_TransposeOp : Tosa_Op<"transpose", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Transpose operator";
 
   let description = [{
@@ -1615,10 +1571,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
 //===----------------------------------------------------------------------===//
 // Operator: gather
 //===----------------------------------------------------------------------===//
-def Tosa_GatherOp : Tosa_Op<"gather", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-      Pure]> {
+def Tosa_GatherOp : Tosa_Op<"gather", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Gather operation,";
 
   let description = [{
@@ -1639,10 +1592,7 @@ def Tosa_GatherOp : Tosa_Op<"gather", [
 //===----------------------------------------------------------------------===//
 // Operator: scatter
 //===----------------------------------------------------------------------===//
-def Tosa_ScatterOp : Tosa_Op<"scatter", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-      Pure]> {
+def Tosa_ScatterOp : Tosa_Op<"scatter", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Scatter operation,";
 
   let description = [{
@@ -1669,10 +1619,7 @@ def Tosa_ScatterOp : Tosa_Op<"scatter", [
 //===----------------------------------------------------------------------===//
 // Operator: resize
 //===----------------------------------------------------------------------===//
-def Tosa_ResizeOp : Tosa_Op<"resize", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-      Pure]> {
+def Tosa_ResizeOp : Tosa_Op<"resize", [InferShapedTypeOpAdaptor, Pure]> {
 
   let summary = "Resize operation, supports various resize/upsample modes";
 
@@ -1898,9 +1845,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
 //===----------------------------------------------------------------------===//
 // Further described in docs/Rationale/RationaleTOSADialect.md .
 //===----------------------------------------------------------------------===//
-def Tosa_IfOp : Tosa_Op<"cond_if", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                                ["inferReturnTypeComponents"]>,
+def Tosa_IfOp : Tosa_Op<"cond_if",
+       [InferShapedTypeOpAdaptor,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveMemoryEffects]> {
   let summary = "Conditional if operator";
@@ -1933,8 +1879,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if", [
 //===----------------------------------------------------------------------===//
 def Tosa_WhileOp : Tosa_Op<"while_loop", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface>,
-       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                                 ["inferReturnTypeComponents"]>,
+       InferShapedTypeOpAdaptor,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveMemoryEffects]> {
   let summary = "output = input; While (Cond(output)) {output = Body(output)}";

diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index b8a664fc6882f0..67de05b0cb4ff3 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -262,6 +262,12 @@ template <typename ConcreteType>
 class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
 };
 
+/// Marker trait for ops using the ODS `InferShapedTypeOpAdaptor` trait, which
+/// declares an `inferReturnTypeComponents` overload that takes the op Adaptor.
+template <typename ConcreteType>
+class InferShapedTypeOpAdaptor
+    : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
+
 /// Tensor type inference trait that constructs a tensor from the inferred
 /// shape and elemental types.
 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.

diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index a458887b374543..1ceaf25a994e01 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -222,6 +222,42 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
   }]
 >;
 
+// Convenience trait defining an inferReturnTypeComponents wrapper that builds
+// the op's Adaptor and forwards to the adaptor-based overload declared below.
+class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
+  [
+    // Op implements the infer shaped type op interface.
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
+    NativeOpTrait<
+      /*name=*/"InferShapedTypeOpAdaptor",
+      /*traits=*/[],
+      /*extraOpDeclaration=*/[{
+        static ::mlir::LogicalResult
+        inferReturnTypeComponents(::mlir::MLIRContext *context,
+                                std::optional<::mlir::Location> location,
+                                Adaptor adaptor,
+                                ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes);
+      }],
+      /*extraOpDefinition=*/[{
+        ::mlir::LogicalResult
+        $cppClass::inferReturnTypeComponents(::mlir::MLIRContext *context,
+                          std::optional<::mlir::Location> location,
+                          ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
+                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+                          ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes) {
+          $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
+          return $cppClass::inferReturnTypeComponents(context,
+            location, adaptor, inferredReturnShapes);
+        }
+      }]
+    >
+  ]>;
+
+def InferShapedTypeOpAdaptor : InferShapedTypeOpAdaptorBase<[
+  "inferReturnTypeComponents"]>;
+def InferShapedTypeOpAdaptorWithReify : InferShapedTypeOpAdaptorBase<[
+  "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+
 // Convenience class grouping together type and shaped type op interfaces for
 // ops that have tensor return types.
 class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
@@ -260,6 +296,44 @@ def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;
 def InferTensorTypeWithReify: InferTensorTypeBase<[
     "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
 
+// Adaptor-based variant of InferTensorTypeBase grouping the type and shaped
+// type op interfaces for ops that have tensor return types.
+class InferTensorTypeAdaptorBase<list<string> overridenMethods = []> : TraitList<
+  [
+    // Op implements infer type op interface.
+    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    // The op implements the ShapedType type inference interface via the
+    // adaptor-based inferReturnTypeComponents wrapper.
+    InferShapedTypeOpAdaptorBase<overridenMethods>,
+    // The op produces tensors and will use the ShapedType type infer interface
+    // along with knowledge that it is producing Tensors to infer the type.
+    NativeOpTrait<
+      /*name=*/"InferTensorType",
+      /*traits=*/[],
+      /*extraOpDeclaration=*/[{}],
+      /*extraOpDefinition=*/[{
+        ::mlir::LogicalResult
+        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+                          std::optional<::mlir::Location> location,
+                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+          ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
+          if (::mlir::failed($cppClass::inferReturnTypeComponents(context, location,
+                                    operands, attributes, properties, regions,
+                                    retComponents)))
+            return ::mlir::failure();
+          return ::mlir::detail::inferReturnTensorTypes(retComponents,
+                                    inferredReturnTypes);
+        }
+      }]
+    >
+  ]>;
+
+def InferTensorTypeAdaptor : InferTensorTypeAdaptorBase<["inferReturnTypeComponents"]>;
+def InferTensorTypeAdaptorWithReify: InferTensorTypeAdaptorBase<[
+    "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+
 def ReifyRankedShapedTypeOpInterface :
     OpInterface<"ReifyRankedShapedTypeOpInterface"> {
   let description = [{

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7577c5893c775b..7b67040b4f6340 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -404,12 +405,10 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
 
 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ArgMaxOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  auto *prop = properties.as<Properties *>();
-  IntegerAttr axis = prop->axis;
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  IntegerAttr axis = adaptor.getProperties().axis;
   int32_t axisVal = axis.getValue().getSExtValue();
 
   if (!inputShape.hasRank()) {
@@ -431,10 +430,9 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
 
 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    RFFT2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
 
   if (!inputShape.hasRank())
     return failure();
@@ -458,26 +456,26 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
 
 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    FFT2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
-  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
+  ShapeAdaptor realShape(adaptor.getInputReal().getType());
+  ShapeAdaptor imagShape(adaptor.getInputImag().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(realShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(imagShape));
   return success();
 }
 
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ConcatOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // Infer all dimension sizes by reducing based on inputs.
-  auto *prop = properties.as<Properties *>();
-  int32_t axis = prop->axis.getValue().getSExtValue();
+  const Properties &prop = adaptor.getProperties();
+  int32_t axis = prop.axis.getValue().getSExtValue();
   llvm::SmallVector<int64_t> outputShape;
   bool hasRankedInput = false;
-  for (auto operand : operands) {
-    ShapeAdaptor operandShape = operands.getShape(operand);
+  for (auto operand : adaptor.getOperands()) {
+    ShapeAdaptor operandShape(operand.getType());
     if (!operandShape.hasRank())
       continue;
 
@@ -501,7 +499,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     hasRankedInput = true;
   }
   Type inputType =
-      llvm::cast<TensorType>(operands.getType()[0]).getElementType();
+      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
   if (!hasRankedInput) {
     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
@@ -509,8 +507,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
 
   // Determine the dimension size along the concatenation axis.
   int64_t concatDimSize = 0;
-  for (auto operand : operands) {
-    ShapeAdaptor operandShape = operands.getShape(operand);
+  for (auto operand : adaptor.getOperands()) {
+    ShapeAdaptor operandShape(operand.getType());
 
     // We need to know the length of the concatenation axis of all inputs to
     // determine the dimension size of the output shape.
@@ -553,12 +551,11 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 
 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    FullyConnectedOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  ShapeAdaptor weightShape = operands.getShape(1);
-  ShapeAdaptor biasShape = operands.getShape(2);
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  ShapeAdaptor weightShape(adaptor.getWeight().getType());
+  ShapeAdaptor biasShape(adaptor.getBias().getType());
 
   // All shapes are dynamic.
   SmallVector<int64_t> outShape;
@@ -585,11 +582,10 @@ LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
 
 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    MatMulOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor lhsShape = operands.getShape(0);
-  ShapeAdaptor rhsShape = operands.getShape(1);
+  ShapeAdaptor lhsShape(adaptor.getA().getType());
+  ShapeAdaptor rhsShape(adaptor.getB().getType());
 
   // All shapes are dynamic.
   SmallVector<int64_t> outShape;
@@ -612,11 +608,10 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
 
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  ShapeAdaptor paddingShape = operands.getShape(1);
+  ShapeAdaptor inputShape(adaptor.getInput1().getType());
+  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
   SmallVector<int64_t> outputShape;
 
   // If both inputs have unknown shape, we cannot determine the shape of the
@@ -641,7 +636,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 
   DenseIntElementsAttr paddings;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(operands[1], m_Constant(&paddings))) {
+  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
@@ -675,22 +670,18 @@ static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
 
 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  inferredReturnShapes.push_back(ShapedTypeComponents(
-      convertToMlirShape(SliceOpAdaptor(operands, attributes,
-                                        *properties.as<Properties *>(), regions)
-                             .getSize())));
+  SmallVector<int64_t> outputShape = convertToMlirShape(adaptor.getSize());
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   return success();
 }
 
 LogicalResult tosa::TableOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TableOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
 
   if (!inputShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
@@ -704,13 +695,10 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
 
 LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  TileOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
-                        regions);
   ArrayRef<int64_t> multiples = adaptor.getMultiples();
-  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
@@ -739,13 +727,10 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ReshapeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ReshapeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
-                           regions);
-  ShapeAdaptor inputShape = operands.getShape(0);
-  Type inputType = getElementTypeOrSelf(operands.getType()[0]);
+  ShapeAdaptor inputShape(adaptor.getInput1().getType());
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
   llvm::SmallVector<int64_t> newShapeValue =
       convertToMlirShape(adaptor.getNewShape());
 
@@ -814,11 +799,10 @@ LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TransposeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  ShapeAdaptor permsShape = operands.getShape(1);
+  ShapeAdaptor inputShape(adaptor.getInput1().getType());
+  ShapeAdaptor permsShape(adaptor.getPerms().getType());
 
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
@@ -869,7 +853,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
   // If the permuations are a constant we can directly determine the output
   // shape.
-  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
+  DenseIntElementsAttr attr;
+  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
+      attr.getType().getRank() == 1) {
+    ShapeAdaptor permShape = attr;
     outputShape.reserve(inputShape.getRank());
     for (int i = 0, s = inputShape.getRank(); i < s; i++) {
       outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
@@ -882,19 +869,18 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 
 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    GatherOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(3, ShapedType::kDynamic);
 
-  ShapeAdaptor valuesShape = operands.getShape(0);
+  ShapeAdaptor valuesShape(adaptor.getValues().getType());
   if (valuesShape.hasRank()) {
     outputShape[0] = valuesShape.getDimSize(0);
     outputShape[2] = valuesShape.getDimSize(2);
   }
 
-  ShapeAdaptor indicesShape = operands.getShape(1);
+  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
   if (indicesShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamic)
       outputShape[0] = indicesShape.getDimSize(0);
@@ -908,15 +894,12 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
 
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ResizeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ResizeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
-                          regions);
   llvm::SmallVector<int64_t, 4> outputShape;
   outputShape.resize(4, ShapedType::kDynamic);
 
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
   if (!inputShape.hasRank())
     return failure();
 
@@ -950,26 +933,25 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
 
 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ScatterOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(3, ShapedType::kDynamic);
 
-  ShapeAdaptor valuesInShape = operands.getShape(0);
+  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
   if (valuesInShape.hasRank()) {
     outputShape[0] = valuesInShape.getDimSize(0);
     outputShape[1] = valuesInShape.getDimSize(1);
     outputShape[2] = valuesInShape.getDimSize(2);
   }
 
-  ShapeAdaptor indicesShape = operands.getShape(1);
+  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
   if (indicesShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamic)
       outputShape[0] = indicesShape.getDimSize(0);
   }
 
-  ShapeAdaptor inputShape = operands.getShape(2);
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamic)
       outputShape[0] = inputShape.getDimSize(0);
@@ -1009,13 +991,13 @@ static LogicalResult ReduceInferReturnTypes(
 #define REDUCE_SHAPE_INFER(OP)                                                 \
   LogicalResult OP::inferReturnTypeComponents(                                 \
       MLIRContext *context, ::std::optional<Location> location,                \
-      ValueShapeRange operands, DictionaryAttr attributes,                     \
-      OpaqueProperties properties, RegionRange regions,                        \
+      OP::Adaptor adaptor,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
     Type inputType =                                                           \
-        llvm::cast<TensorType>(operands.getType()[0]).getElementType();        \
-    return ReduceInferReturnTypes(operands.getShape(0), inputType,             \
-                                  properties.as<Properties *>()->axis,         \
+        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
+    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
+    const Properties &prop = adaptor.getProperties();                          \
+    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                   inferredReturnShapes);                       \
   }                                                                            \
   COMPATIBLE_RETURN_TYPES(OP)
@@ -1092,10 +1074,9 @@ NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
 static LogicalResult poolingInferReturnTypes(
-    const ValueShapeRange &operands, DictionaryAttr attributes,
-    ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
+    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
+    ArrayRef<int64_t> pad,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(4, ShapedType::kDynamic);
 
@@ -1128,12 +1109,9 @@ static LogicalResult poolingInferReturnTypes(
 
 LogicalResult Conv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    Conv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
-  Conv2DOp::Adaptor adaptor(operands, attributes,
-                            *properties.as<Properties *>(), regions);
 
   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
@@ -1142,7 +1120,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
 
   // Input shape describes input width/height and batch.
 
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = inputShape.getDimSize(0);
     inputHeight = inputShape.getDimSize(1);
@@ -1150,7 +1128,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+  ShapeAdaptor weightShape(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     outputShape[3] = weightShape.getDimSize(0);
     weightHeight = weightShape.getDimSize(1);
@@ -1158,7 +1136,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
   }
 
   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+  ShapeAdaptor biasShape(adaptor.getBias().getType());
   if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? biasShape.getDimSize(0)
@@ -1193,12 +1171,9 @@ LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
 
 LogicalResult Conv3DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    Conv3DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
-  Conv3DOp::Adaptor adaptor(operands, attributes,
-                            *properties.as<Properties *>(), regions);
 
   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
@@ -1209,7 +1184,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
   int64_t weightDepth = ShapedType::kDynamic;
 
   // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = inputShape.getDimSize(0);
     inputDepth = inputShape.getDimSize(1);
@@ -1218,7 +1193,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+  ShapeAdaptor weightShape(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     outputShape[4] = weightShape.getDimSize(0);
     weightDepth = weightShape.getDimSize(1);
@@ -1227,7 +1202,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
   }
 
   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+  ShapeAdaptor biasShape(adaptor.getBias().getType());
   if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
     outputShape[4] = biasShape.getDimSize(0);
   }
@@ -1268,32 +1243,29 @@ LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
 
 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    AvgPool2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  Properties &prop = *properties.as<Properties *>();
-  return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
-                                 prop.pad, inferredReturnShapes);
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  const Properties &prop = adaptor.getProperties();
+  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
+                                 inferredReturnShapes);
 }
 
 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    MaxPool2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  Properties &prop = *properties.as<Properties *>();
-  return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
-                                 prop.pad, inferredReturnShapes);
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  const Properties &prop = adaptor.getProperties();
+  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
+                                 inferredReturnShapes);
 }
 
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    DepthwiseConv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
-  DepthwiseConv2DOp::Adaptor adaptor(operands, attributes,
-                                     *properties.as<Properties *>(), regions);
 
   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
@@ -1304,7 +1276,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
   int64_t depthChannels = ShapedType::kDynamic;
 
   // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = inputShape.getDimSize(0);
     inputHeight = inputShape.getDimSize(1);
@@ -1313,7 +1285,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+  ShapeAdaptor weightShape(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     weightHeight = weightShape.getDimSize(0);
     weightWidth = weightShape.getDimSize(1);
@@ -1331,7 +1303,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
   }
 
   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+  ShapeAdaptor biasShape(adaptor.getBias().getType());
   if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? biasShape.getDimSize(0)
@@ -1366,11 +1338,8 @@ LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
 
 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TransposeConv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  TransposeConv2DOp::Adaptor adaptor(operands, attributes,
-                                     *properties.as<Properties *>(), regions);
   // outputShape is mutable.
   llvm::SmallVector<int64_t> outputShape =
       convertToMlirShape(adaptor.getOutShape());
@@ -1381,7 +1350,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   int64_t weightHeight = ShapedType::kDynamic;
 
   // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = ShapedType::isDynamic(outputShape[0])
                          ? inputShape.getDimSize(0)
@@ -1391,7 +1360,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
+  ShapeAdaptor weightShape(adaptor.getFilter().getType());
   if (weightShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? weightShape.getDimSize(0)
@@ -1401,7 +1370,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   }
 
   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor biasShape(adaptor.getInput().getType());
   if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? biasShape.getDimSize(0)
@@ -1433,11 +1402,10 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
 
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    IfOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<tosa::YieldOp> yieldOps;
-  for (Region *region : regions) {
+  for (Region *region : adaptor.getRegions()) {
     for (auto &block : *region)
       if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
         yieldOps.push_back(returnOp);
@@ -1478,11 +1446,10 @@ LogicalResult IfOp::inferReturnTypeComponents(
 
 LogicalResult WhileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    WhileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<tosa::YieldOp> yieldOps;
-  for (auto &block : *regions[1])
+  for (auto &block : adaptor.getBody())
     if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
       yieldOps.push_back(returnOp);
 

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 420d5d3e4c5962..03aeac4c9dff80 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1437,9 +1437,8 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
   // Create return type consisting of the last element of the first operand.
   auto operandType = operands.front().getType();
   auto sval = dyn_cast<ShapedType>(operandType);
-  if (!sval) {
+  if (!sval)
     return emitOptionalError(location, "only shaped type operands allowed");
-  }
   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
   auto type = IntegerType::get(context, 17);
 
@@ -1458,6 +1457,35 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
   return success();
 }
 
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, std::optional<Location> location,
+    OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Create return type consisting of the last element of the first operand.
+  auto operandType = adaptor.getOperand1().getType();
+  auto sval = dyn_cast<ShapedType>(operandType);
+  if (!sval)
+    return emitOptionalError(location, "only shaped type operands allowed");
+  int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
+  auto type = IntegerType::get(context, 17);
+
+  Attribute encoding;
+  if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
+    encoding = rankedTy.getEncoding();
+  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
+  return success();
+}
+
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  shapes = SmallVector<Value, 1>{
+      builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
+  return success();
+}
+
 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
     OpBuilder &builder, ValueRange operands,
     llvm::SmallVectorImpl<Value> &shapes) {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7a3ae924afa9a3..389fac6f3fed6f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -780,6 +780,13 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
   let results = (outs AnyTensor);
 }
 
+def OpWithShapedTypeInferTypeAdaptorInterfaceOp : 
+      TEST_Op<"op_with_shaped_type_infer_type_adaptor_if",
+              [InferTensorTypeAdaptorWithReify]> {
+  let arguments = (ins AnyTensor:$operand1, AnyTensor:$operand2);
+  let results = (outs AnyTensor:$result);
+}
+
 def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
       [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
           ["reifyReturnTypeShapes"]>]> {


        


More information about the Mlir-commits mailing list