[Mlir-commits] [mlir] 057fc8e - [ODS] Use Adaptor Trait for Shaped Type Inference
Amanda Tang
llvmlistbot at llvm.org
Thu Jul 20 12:41:14 PDT 2023
Author: Amanda Tang
Date: 2023-07-20T19:41:08Z
New Revision: 057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a
URL: https://github.com/llvm/llvm-project/commit/057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a
DIFF: https://github.com/llvm/llvm-project/commit/057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a.diff
LOG: [ODS] Use Adaptor Trait for Shaped Type Inference
Author inferReturnTypeComponents methods with the Op Adaptor by using the InferShapedTypeOpAdaptor.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D155243
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7a16b37f0ca417..421d6b09424a59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -32,10 +32,7 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
//===----------------------------------------------------------------------===//
// Operator: argmax
//===----------------------------------------------------------------------===//
-def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_ArgMaxOp : Tosa_Op<"argmax", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Perform argmax on the input.";
let description = [{
@@ -62,10 +59,7 @@ def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
//===----------------------------------------------------------------------===//
// Operator: avg_pool2d
//===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs max pooling on the input.";
let description = [{
@@ -95,10 +89,7 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
//===----------------------------------------------------------------------===//
// Operator: conv2d
//===----------------------------------------------------------------------===//
-def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_Conv2DOp : Tosa_Op<"conv2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "2D Convolution Operator";
let description = [{
@@ -128,10 +119,7 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
//===----------------------------------------------------------------------===//
// Operator: conv3d
//===----------------------------------------------------------------------===//
-def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_Conv3DOp : Tosa_Op<"conv3d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "3D Convolution operator";
let description = [{
@@ -160,10 +148,8 @@ def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
//===----------------------------------------------------------------------===//
// Operator: depthwise_conv2d
//===----------------------------------------------------------------------===//
-def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d",
+ [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Depthwise 2D Convolution operator";
let description = [{
@@ -193,10 +179,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
//===----------------------------------------------------------------------===//
// Operator: fft2d
//===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_FFT2dOp : Tosa_Op<"fft2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs FFT2D operation on the input.";
let description = [{
@@ -224,9 +207,7 @@ def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
// Operator: fully_connected
//===----------------------------------------------------------------------===//
def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferShapedTypeOpAdaptor, Pure]> {
let summary = "Fully Connected operator";
let description = [{
@@ -251,10 +232,7 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
//===----------------------------------------------------------------------===//
// Operator: matmul
//===----------------------------------------------------------------------===//
-def Tosa_MatMulOp : Tosa_Op<"matmul", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_MatMulOp : Tosa_Op<"matmul", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Matrix multiplication with bias";
let description = [{
@@ -279,10 +257,7 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [
//===----------------------------------------------------------------------===//
// Operator: max_pool2d
//===----------------------------------------------------------------------===//
-def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs max pooling on the input.";
let description = [{
@@ -310,10 +285,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Performs RFFT2D operation on the input.";
let description = [{
@@ -338,10 +310,8 @@ def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
//===----------------------------------------------------------------------===//
// Operator: transpose_conv2d
//===----------------------------------------------------------------------===//
-def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d",
+ [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Transpose 2D Convolution operator.";
let description = [{
@@ -828,10 +798,7 @@ def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
//===----------------------------------------------------------------------===//
// Operator: table
//===----------------------------------------------------------------------===//
-def Tosa_TableOp : Tosa_Op<"table", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_TableOp : Tosa_Op<"table", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Table lookup op";
let description = [{
@@ -1214,7 +1181,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
// Operator: reduce_all
//===----------------------------------------------------------------------===//
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce All operator";
let description = [{
@@ -1243,7 +1210,7 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
// Operator: reduce_any
//===----------------------------------------------------------------------===//
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Any operator";
let description = [{
@@ -1272,7 +1239,7 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
// Operator: reduce_max
//===----------------------------------------------------------------------===//
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Max operator";
let description = [{
@@ -1301,7 +1268,7 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
// Operator: reduce_min
//===----------------------------------------------------------------------===//
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Min operator";
let description = [{
@@ -1330,7 +1297,7 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
// Operator: reduce_prod
//===----------------------------------------------------------------------===//
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Prod operator";
let description = [{
@@ -1359,7 +1326,7 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
// Operator: reduce_sum
//===----------------------------------------------------------------------===//
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Reduce Sum operator";
let description = [{
@@ -1393,7 +1360,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
// Operator: concat
//===----------------------------------------------------------------------===//
def Tosa_ConcatOp : Tosa_Op<"concat", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Concatenates tensors along one dimension.";
let description = [{
@@ -1423,10 +1390,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
//===----------------------------------------------------------------------===//
// Operator: pad
//===----------------------------------------------------------------------===//
-def Tosa_PadOp : Tosa_Op<"pad", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_PadOp : Tosa_Op<"pad", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Pads a tensor with value specified.";
let description = [{
@@ -1471,7 +1435,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
// Operator: reshape
//===----------------------------------------------------------------------===//
def Tosa_ReshapeOp: Tosa_Op<"reshape", [
- InferTensorType, Pure]> {
+ InferTensorTypeAdaptor, Pure]> {
let summary = "Reshape operator";
let description = [{
@@ -1528,9 +1492,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
//===----------------------------------------------------------------------===//
// Operator: slice
//===----------------------------------------------------------------------===//
-def Tosa_SliceOp: Tosa_Op<"slice", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>, Pure]> {
+def Tosa_SliceOp: Tosa_Op<"slice", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Slice operator";
let description = [{
@@ -1556,10 +1518,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
//===----------------------------------------------------------------------===//
// Operator: tile
//===----------------------------------------------------------------------===//
-def Tosa_TileOp: Tosa_Op<"tile", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_TileOp: Tosa_Op<"tile", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Tile operator";
let description = [{
@@ -1580,10 +1539,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
//===----------------------------------------------------------------------===//
// Operator: transpose
//===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_Op<"transpose", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_TransposeOp : Tosa_Op<"transpose", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Transpose operator";
let description = [{
@@ -1615,10 +1571,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
//===----------------------------------------------------------------------===//
// Operator: gather
//===----------------------------------------------------------------------===//
-def Tosa_GatherOp : Tosa_Op<"gather", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_GatherOp : Tosa_Op<"gather", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Gather operation,";
let description = [{
@@ -1639,10 +1592,7 @@ def Tosa_GatherOp : Tosa_Op<"gather", [
//===----------------------------------------------------------------------===//
// Operator: scatter
//===----------------------------------------------------------------------===//
-def Tosa_ScatterOp : Tosa_Op<"scatter", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_ScatterOp : Tosa_Op<"scatter", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Scatter operation,";
let description = [{
@@ -1669,10 +1619,7 @@ def Tosa_ScatterOp : Tosa_Op<"scatter", [
//===----------------------------------------------------------------------===//
// Operator: resize
//===----------------------------------------------------------------------===//
-def Tosa_ResizeOp : Tosa_Op<"resize", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+def Tosa_ResizeOp : Tosa_Op<"resize", [InferShapedTypeOpAdaptor, Pure]> {
let summary = "Resize operation, supports various resize/upsample modes";
@@ -1898,9 +1845,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
//===----------------------------------------------------------------------===//
// Further described in docs/Rationale/RationaleTOSADialect.md .
//===----------------------------------------------------------------------===//
-def Tosa_IfOp : Tosa_Op<"cond_if", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+def Tosa_IfOp : Tosa_Op<"cond_if",
+ [InferShapedTypeOpAdaptor,
SingleBlockImplicitTerminator<"YieldOp">,
RecursiveMemoryEffects]> {
let summary = "Conditional if operator";
@@ -1933,8 +1879,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if", [
//===----------------------------------------------------------------------===//
def Tosa_WhileOp : Tosa_Op<"while_loop", [
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+ InferShapedTypeOpAdaptor,
SingleBlockImplicitTerminator<"YieldOp">,
RecursiveMemoryEffects]> {
let summary = "output = input; While (Cond(output)) {output = Body(output)}";
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index b8a664fc6882f0..67de05b0cb4ff3 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -262,6 +262,10 @@ template <typename ConcreteType>
class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
};
+template <typename ConcreteType>
+class InferShapedTypeOpAdaptor
+ : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
+
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index a458887b374543..1ceaf25a994e01 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -222,6 +222,42 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
}]
>;
+// Convenience trait that defines a wrapper around inferReturnTypeComponents
+// which passes in the Op Adaptor directly.
+class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
+ [
+ // Op implements the infer shaped type op interface.
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
+ NativeOpTrait<
+ /*name=*/"InferShapedTypeOpAdaptor",
+ /*traits=*/[],
+ /*extraOpDeclaration=*/[{
+ static ::mlir::LogicalResult
+ inferReturnTypeComponents(::mlir::MLIRContext *context,
+ std::optional<::mlir::Location> location,
+ Adaptor adaptor,
+ ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes);
+ }],
+ /*extraOpDefinition=*/[{
+ ::mlir::LogicalResult
+ $cppClass::inferReturnTypeComponents(::mlir::MLIRContext *context,
+ std::optional<::mlir::Location> location,
+ ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
+ ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes) {
+ $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
+ return $cppClass::inferReturnTypeComponents(context,
+ location, adaptor, inferredReturnShapes);
+ }
+ }]
+ >
+ ]>;
+
+def InferShapedTypeOpAdaptor : InferShapedTypeOpAdaptorBase<[
+ "inferReturnTypeComponents"]>;
+def InferShapedTypeOpAdaptorWithReify : InferShapedTypeOpAdaptorBase<[
+ "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+
// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
@@ -260,6 +296,44 @@ def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;
def InferTensorTypeWithReify: InferTensorTypeBase<[
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+// Adaptor-based convenience class grouping together type and shaped type op
+// interfaces for ops that have tensor return types.
+class InferTensorTypeAdaptorBase<list<string> overridenMethods = []> : TraitList<
+ [
+ // Op implements infer type op interface.
+ DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ // The op will have methods implementing the ShapedType type inference
+ // interface.
+ InferShapedTypeOpAdaptorBase<overridenMethods>,
+ // The op produces tensors and will use the ShapedType type infer interface
+ // along with knowledge that it is producing Tensors to infer the type.
+ NativeOpTrait<
+ /*name=*/"InferTensorType",
+ /*traits=*/[],
+ /*extraOpDeclaration=*/[{}],
+ /*extraOpDefinition=*/[{
+ LogicalResult
+ $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+ std::optional<::mlir::Location> location,
+ ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+ ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ SmallVector<ShapedTypeComponents, 2> retComponents;
+ if (failed($cppClass::inferReturnTypeComponents(context, location,
+ operands, attributes, properties, regions,
+ retComponents)))
+ return failure();
+ return ::mlir::detail::inferReturnTensorTypes(retComponents,
+ inferredReturnTypes);
+ }
+ }]
+ >
+ ]>;
+
+def InferTensorTypeAdaptor : InferTensorTypeAdaptorBase<["inferReturnTypeComponents"]>;
+def InferTensorTypeAdaptorWithReify: InferTensorTypeAdaptorBase<[
+ "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+
def ReifyRankedShapedTypeOpInterface :
OpInterface<"ReifyRankedShapedTypeOpInterface"> {
let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7577c5893c775b..7b67040b4f6340 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -404,12 +405,10 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ ArgMaxOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape = operands.getShape(0);
- auto *prop = properties.as<Properties *>();
- IntegerAttr axis = prop->axis;
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ IntegerAttr axis = adaptor.getProperties().axis;
int32_t axisVal = axis.getValue().getSExtValue();
if (!inputShape.hasRank()) {
@@ -431,10 +430,9 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ RFFT2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape = operands.getShape(0);
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (!inputShape.hasRank())
return failure();
@@ -458,26 +456,26 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ FFT2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
- inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
return success();
}
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ ConcatOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Infer all dimension sizes by reducing based on inputs.
- auto *prop = properties.as<Properties *>();
- int32_t axis = prop->axis.getValue().getSExtValue();
+ const Properties &prop = adaptor.getProperties();
+ int32_t axis = prop.axis.getValue().getSExtValue();
llvm::SmallVector<int64_t> outputShape;
bool hasRankedInput = false;
- for (auto operand : operands) {
- ShapeAdaptor operandShape = operands.getShape(operand);
+ for (auto operand : adaptor.getOperands()) {
+ ShapeAdaptor operandShape(operand.getType());
if (!operandShape.hasRank())
continue;
@@ -501,7 +499,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
hasRankedInput = true;
}
Type inputType =
- llvm::cast<TensorType>(operands.getType()[0]).getElementType();
+ llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
if (!hasRankedInput) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
@@ -509,8 +507,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
// Determine the dimension size along the concatenation axis.
int64_t concatDimSize = 0;
- for (auto operand : operands) {
- ShapeAdaptor operandShape = operands.getShape(operand);
+ for (auto operand : adaptor.getOperands()) {
+ ShapeAdaptor operandShape(operand.getType());
// We need to know the length of the concatenation axis of all inputs to
// determine the dimension size of the output shape.
@@ -553,12 +551,11 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ FullyConnectedOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape = operands.getShape(0);
- ShapeAdaptor weightShape = operands.getShape(1);
- ShapeAdaptor biasShape = operands.getShape(2);
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ ShapeAdaptor weightShape(adaptor.getWeight().getType());
+ ShapeAdaptor biasShape(adaptor.getBias().getType());
// All shapes are dynamic.
SmallVector<int64_t> outShape;
@@ -585,11 +582,10 @@ LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ MatMulOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor lhsShape = operands.getShape(0);
- ShapeAdaptor rhsShape = operands.getShape(1);
+ ShapeAdaptor lhsShape(adaptor.getA().getType());
+ ShapeAdaptor rhsShape(adaptor.getB().getType());
// All shapes are dynamic.
SmallVector<int64_t> outShape;
@@ -612,11 +608,10 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ PadOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape = operands.getShape(0);
- ShapeAdaptor paddingShape = operands.getShape(1);
+ ShapeAdaptor inputShape(adaptor.getInput1().getType());
+ ShapeAdaptor paddingShape(adaptor.getPadding().getType());
SmallVector<int64_t> outputShape;
// If both inputs have unknown shape, we cannot determine the shape of the
@@ -641,7 +636,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
DenseIntElementsAttr paddings;
// If the paddings value is not a constant, all dimensions must be dynamic.
- if (!matchPattern(operands[1], m_Constant(&paddings))) {
+ if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
@@ -675,22 +670,18 @@ static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(ShapedTypeComponents(
- convertToMlirShape(SliceOpAdaptor(operands, attributes,
- *properties.as<Properties *>(), regions)
- .getSize())));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
return success();
}
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ TableOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape = operands.getShape(0);
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
@@ -704,13 +695,10 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ TileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- TileOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
- regions);
ArrayRef<int64_t> multiples = adaptor.getMultiples();
- ShapeAdaptor inputShape = operands.getShape(0);
+ ShapeAdaptor inputShape(adaptor.getInput1().getType());
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {
outputShape.resize(multiples.size(), ShapedType::kDynamic);
@@ -739,13 +727,10 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ ReshapeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ReshapeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
- regions);
- ShapeAdaptor inputShape = operands.getShape(0);
- Type inputType = getElementTypeOrSelf(operands.getType()[0]);
+ ShapeAdaptor inputShape(adaptor.getInput1().getType());
+ Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
llvm::SmallVector<int64_t> newShapeValue =
convertToMlirShape(adaptor.getNewShape());
@@ -814,11 +799,10 @@ LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape = operands.getShape(0);
- ShapeAdaptor permsShape = operands.getShape(1);
+ ShapeAdaptor inputShape(adaptor.getInput1().getType());
+ ShapeAdaptor permsShape(adaptor.getPerms().getType());
// If input rank and permutation length is unknown, the output rank is
// unknown.
@@ -869,7 +853,10 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
// If the permutations are a constant we can directly determine the output
// shape.
- if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
+ DenseIntElementsAttr attr;
+ if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
+ attr.getType().getRank() == 1) {
+ ShapeAdaptor permShape = attr;
outputShape.reserve(inputShape.getRank());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
@@ -882,19 +869,18 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ GatherOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
- ShapeAdaptor valuesShape = operands.getShape(0);
+ ShapeAdaptor valuesShape(adaptor.getValues().getType());
if (valuesShape.hasRank()) {
outputShape[0] = valuesShape.getDimSize(0);
outputShape[2] = valuesShape.getDimSize(2);
}
- ShapeAdaptor indicesShape = operands.getShape(1);
+ ShapeAdaptor indicesShape(adaptor.getIndices().getType());
if (indicesShape.hasRank()) {
if (outputShape[0] == ShapedType::kDynamic)
outputShape[0] = indicesShape.getDimSize(0);
@@ -908,15 +894,12 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ ResizeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ResizeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
- regions);
llvm::SmallVector<int64_t, 4> outputShape;
outputShape.resize(4, ShapedType::kDynamic);
- ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (!inputShape.hasRank())
return failure();
@@ -950,26 +933,25 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ ScatterOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
- ShapeAdaptor valuesInShape = operands.getShape(0);
+ ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
if (valuesInShape.hasRank()) {
outputShape[0] = valuesInShape.getDimSize(0);
outputShape[1] = valuesInShape.getDimSize(1);
outputShape[2] = valuesInShape.getDimSize(2);
}
- ShapeAdaptor indicesShape = operands.getShape(1);
+ ShapeAdaptor indicesShape(adaptor.getIndices().getType());
if (indicesShape.hasRank()) {
if (outputShape[0] == ShapedType::kDynamic)
outputShape[0] = indicesShape.getDimSize(0);
}
- ShapeAdaptor inputShape = operands.getShape(2);
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
if (outputShape[0] == ShapedType::kDynamic)
outputShape[0] = inputShape.getDimSize(0);
@@ -1009,13 +991,13 @@ static LogicalResult ReduceInferReturnTypes(
#define REDUCE_SHAPE_INFER(OP) \
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
- ValueShapeRange operands, DictionaryAttr attributes, \
- OpaqueProperties properties, RegionRange regions, \
+ OP::Adaptor adaptor, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
Type inputType = \
- llvm::cast<TensorType>(operands.getType()[0]).getElementType(); \
- return ReduceInferReturnTypes(operands.getShape(0), inputType, \
- properties.as<Properties *>()->axis, \
+ llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
+ ShapeAdaptor inputShape(adaptor.getInput().getType()); \
+ const Properties &prop = adaptor.getProperties(); \
+ return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
inferredReturnShapes); \
} \
COMPATIBLE_RETURN_TYPES(OP)
@@ -1092,10 +1074,9 @@ NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER
static LogicalResult poolingInferReturnTypes(
- const ValueShapeRange &operands, DictionaryAttr attributes,
- ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
+ ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
+ ArrayRef<int64_t> pad,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape = operands.getShape(0);
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(4, ShapedType::kDynamic);
@@ -1128,12 +1109,9 @@ static LogicalResult poolingInferReturnTypes(
LogicalResult Conv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ Conv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
- Conv2DOp::Adaptor adaptor(operands, attributes,
- *properties.as<Properties *>(), regions);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
@@ -1142,7 +1120,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
// Input shape describes input width/height and batch.
- ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);
@@ -1150,7 +1128,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+ ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
outputShape[3] = weightShape.getDimSize(0);
weightHeight = weightShape.getDimSize(1);
@@ -1158,7 +1136,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
- ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+ ShapeAdaptor biasShape(adaptor.getBias().getType());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)
@@ -1193,12 +1171,9 @@ LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ Conv3DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
- Conv3DOp::Adaptor adaptor(operands, attributes,
- *properties.as<Properties *>(), regions);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
@@ -1209,7 +1184,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
int64_t weightDepth = ShapedType::kDynamic;
// Input shape describes input width/height and batch.
- ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputDepth = inputShape.getDimSize(1);
@@ -1218,7 +1193,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+ ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
outputShape[4] = weightShape.getDimSize(0);
weightDepth = weightShape.getDimSize(1);
@@ -1227,7 +1202,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
- ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+ ShapeAdaptor biasShape(adaptor.getBias().getType());
if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
outputShape[4] = biasShape.getDimSize(0);
}
@@ -1268,32 +1243,29 @@ LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ AvgPool2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- Properties &prop = *properties.as<Properties *>();
- return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
- prop.pad, inferredReturnShapes);
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ const Properties &prop = adaptor.getProperties();
+ return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
+ inferredReturnShapes);
}
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ MaxPool2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- Properties &prop = *properties.as<Properties *>();
- return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
- prop.pad, inferredReturnShapes);
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ const Properties &prop = adaptor.getProperties();
+ return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
+ inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ DepthwiseConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
- DepthwiseConv2DOp::Adaptor adaptor(operands, attributes,
- *properties.as<Properties *>(), regions);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
@@ -1304,7 +1276,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
int64_t depthChannels = ShapedType::kDynamic;
// Input shape describes input width/height and batch.
- ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);
@@ -1313,7 +1285,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+ ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
weightHeight = weightShape.getDimSize(0);
weightWidth = weightShape.getDimSize(1);
@@ -1331,7 +1303,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
- ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+ ShapeAdaptor biasShape(adaptor.getBias().getType());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)
@@ -1366,11 +1338,8 @@ LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ TransposeConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- TransposeConv2DOp::Adaptor adaptor(operands, attributes,
- *properties.as<Properties *>(), regions);
// outputShape is mutable.
llvm::SmallVector<int64_t> outputShape =
convertToMlirShape(adaptor.getOutShape());
@@ -1381,7 +1350,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
int64_t weightHeight = ShapedType::kDynamic;
// Input shape describes input width/height and batch.
- ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
if (inputShape.hasRank()) {
outputShape[0] = ShapedType::isDynamic(outputShape[0])
? inputShape.getDimSize(0)
@@ -1391,7 +1360,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
+ ShapeAdaptor weightShape(adaptor.getFilter().getType());
if (weightShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? weightShape.getDimSize(0)
@@ -1401,7 +1370,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
- ShapeAdaptor biasShape = operands.getShape(adaptor.getInput());
+ ShapeAdaptor biasShape(adaptor.getInput().getType());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)
@@ -1433,11 +1402,10 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ IfOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
- for (Region *region : regions) {
+ for (Region *region : adaptor.getRegions()) {
for (auto &block : *region)
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);
@@ -1478,11 +1446,10 @@ LogicalResult IfOp::inferReturnTypeComponents(
LogicalResult WhileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
+ WhileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
- for (auto &block : *regions[1])
+ for (auto &block : adaptor.getBody())
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 420d5d3e4c5962..03aeac4c9dff80 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1437,9 +1437,8 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
// Create return type consisting of the last element of the first operand.
auto operandType = operands.front().getType();
auto sval = dyn_cast<ShapedType>(operandType);
- if (!sval) {
+ if (!sval)
return emitOptionalError(location, "only shaped type operands allowed");
- }
int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
auto type = IntegerType::get(context, 17);
@@ -1458,6 +1457,35 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
return success();
}
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
+ MLIRContext *context, std::optional<Location> location,
+ OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ // Create return type consisting of the last element of the first operand.
+ auto operandType = adaptor.getOperand1().getType();
+ auto sval = dyn_cast<ShapedType>(operandType);
+ if (!sval)
+ return emitOptionalError(location, "only shaped type operands allowed");
+ int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
+ auto type = IntegerType::get(context, 17);
+
+ Attribute encoding;
+ if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
+ encoding = rankedTy.getEncoding();
+ inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
+ return success();
+}
+
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
+ OpBuilder &builder, ValueRange operands,
+ llvm::SmallVectorImpl<Value> &shapes) {
+ shapes = SmallVector<Value, 1>{
+ builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
+ return success();
+}
+
LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7a3ae924afa9a3..389fac6f3fed6f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -780,6 +780,13 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
let results = (outs AnyTensor);
}
+def OpWithShapedTypeInferTypeAdaptorInterfaceOp :
+ TEST_Op<"op_with_shaped_type_infer_type_adaptor_if",
+ [InferTensorTypeAdaptorWithReify]> {
+ let arguments = (ins AnyTensor:$operand1, AnyTensor:$operand2);
+ let results = (outs AnyTensor:$result);
+}
+
def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes"]>]> {
More information about the Mlir-commits mailing list