[Mlir-commits] [mlir] 5267ed0 - [ODS] Use Adaptor Traits for Type Inference
Amanda Tang
llvmlistbot at llvm.org
Tue Jul 18 10:58:57 PDT 2023
Author: Amanda Tang
Date: 2023-07-18T17:58:31Z
New Revision: 5267ed05bc4612e91409d63b4dbc4e01751acb75
URL: https://github.com/llvm/llvm-project/commit/5267ed05bc4612e91409d63b4dbc4e01751acb75
DIFF: https://github.com/llvm/llvm-project/commit/5267ed05bc4612e91409d63b4dbc4e01751acb75.diff
LOG: [ODS] Use Adaptor Traits for Type Inference
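Allow `inferReturnTypes` methods to be written directly against the op Adaptor by using the `InferTypeOpAdaptor` trait (renamed from `InferTypeOpInterfaceAdaptor`); `InferTypeOpAdaptorWithIsCompatible` additionally declares `isCompatibleReturnTypes`.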
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D155115
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0c3f96ff70b9c4..ea6e363a6c3257 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -903,7 +903,7 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
Pure,
SameVariadicResultSize,
ViewLikeOpInterface,
- InferTypeOpInterfaceAdaptor]> {
+ InferTypeOpAdaptor]> {
let summary = "Extracts a buffer base with offset and strides";
let description = [{
Extracts a base buffer, offset and strides. This op allows additional layers
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 58b720c6e39637..db7b41afb7ed37 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -713,9 +713,8 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds"]>,
- DeclareOpInterfaceMethods<InferTypeOpInterface>,
- SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
- NoRegionArguments]> {
+ InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
+ RecursiveMemoryEffects, NoRegionArguments]> {
let summary = "if-then-else operation";
let description = [{
The `scf.if` operation represents an if-then-else construct for
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index dce5be8207c82b..3109c4bdb36182 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -32,8 +32,7 @@ class Shape_Op<string mnemonic, list<Trait> traits = []> :
Op<ShapeDialect, mnemonic, traits>;
def Shape_AddOp : Shape_Op<"add",
- [Commutative, Pure,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Addition of sizes and indices";
let description = [{
Adds two sizes or indices. If either operand is an error it will be
@@ -51,12 +50,6 @@ def Shape_AddOp : Shape_Op<"add",
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
-
let hasFolder = 1;
let hasVerifier = 1;
}
@@ -109,7 +102,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, Pure]> {
}
def Shape_ConstShapeOp : Shape_Op<"const_shape",
- [ConstantLike, Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [ConstantLike, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Creates a constant shape or extent tensor";
let description = [{
Creates a constant shape or extent tensor. The individual extents are given
@@ -128,11 +121,6 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape",
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasCanonicalizer = 1;
-
- let extraClassDeclaration = [{
- // InferTypeOpInterface:
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_ConstSizeOp : Shape_Op<"const_size", [
@@ -158,8 +146,7 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
let hasFolder = 1;
}
-def Shape_DivOp : Shape_Op<"div", [Pure,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def Shape_DivOp : Shape_Op<"div", [Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Division of sizes and indices";
let description = [{
Divides two sizes or indices. If either operand is an error it will be
@@ -187,12 +174,6 @@ def Shape_DivOp : Shape_Op<"div", [Pure,
let hasFolder = 1;
let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Pure, Commutative]> {
@@ -287,7 +268,7 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
}
def Shape_RankOp : Shape_Op<"rank",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Gets the rank of a shape";
let description = [{
Returns the rank of the shape or extent tensor, i.e. the number of extents.
@@ -301,12 +282,6 @@ def Shape_RankOp : Shape_Op<"rank",
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
@@ -330,7 +305,7 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
}
def Shape_DimOp : Shape_Op<"dim",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Gets the specified extent from the shape of a shaped input";
let description = [{
Gets the extent indexed by `dim` from the shape of the `value` operand. If
@@ -354,17 +329,13 @@ def Shape_DimOp : Shape_Op<"dim",
let extraClassDeclaration = [{
/// Get the `index` value as integer if it is constant.
std::optional<int64_t> getConstantIndex();
-
- /// Returns when two result types are compatible for this op; method used
- /// by InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
}
def Shape_GetExtentOp : Shape_Op<"get_extent",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Gets the specified extent from a shape or extent tensor";
let description = [{
Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@@ -384,9 +355,6 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
let extraClassDeclaration = [{
/// Get the `dim` value as integer if it is constant.
std::optional<int64_t> getConstantDim();
- /// Returns when two result types are compatible for this op; method used
- /// by InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
@@ -413,8 +381,7 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [Pure]> {
}
def Shape_MaxOp : Shape_Op<"max",
- [Commutative, Pure,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Elementwise maximum";
let description = [{
Computes the elementwise maximum of two sizes or shapes with equal ranks.
@@ -431,16 +398,10 @@ def Shape_MaxOp : Shape_Op<"max",
}];
let hasFolder = 1;
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_MeetOp : Shape_Op<"meet",
- [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Commutative, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Returns the least general shape or size of its operands";
let description = [{
An operation that computes the least general shape or dim of input operands.
@@ -478,17 +439,10 @@ def Shape_MeetOp : Shape_Op<"meet",
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result)
}];
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_MinOp : Shape_Op<"min",
- [Commutative, Pure,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Elementwise minimum";
let description = [{
Computes the elementwise minimum of two sizes or shapes with equal ranks.
@@ -505,17 +459,10 @@ def Shape_MinOp : Shape_Op<"min",
}];
let hasFolder = 1;
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_MulOp : Shape_Op<"mul",
- [Commutative, Pure,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Multiplication of sizes and indices";
let description = [{
Multiplies two sizes or indices. If either operand is an error it will be
@@ -535,16 +482,10 @@ def Shape_MulOp : Shape_Op<"mul",
let hasFolder = 1;
let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_NumElementsOp : Shape_Op<"num_elements",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Returns the number of elements for a given shape";
let description = [{
Returns the number of elements for a given shape which is the product of
@@ -561,11 +502,6 @@ def Shape_NumElementsOp : Shape_Op<"num_elements",
let hasFolder = 1;
let hasVerifier = 1;
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_ReduceOp : Shape_Op<"reduce",
@@ -616,7 +552,7 @@ def Shape_ReduceOp : Shape_Op<"reduce",
}
def Shape_ShapeOfOp : Shape_Op<"shape_of",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ [Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Returns shape of a value or shaped type operand";
let description = [{
@@ -632,12 +568,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
}
def Shape_ValueOfOp : Shape_Op<"value_of", [Pure]> {
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7346315013b1c3..588998853e6995 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -463,7 +463,7 @@ def Vector_ShuffleOp :
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand v2 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+ InferTypeOpAdaptor]>,
Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
I64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
@@ -572,7 +572,7 @@ def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+ InferTypeOpAdaptorWithIsCompatible]>,
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
@@ -598,7 +598,6 @@ def Vector_ExtractOp :
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 747ec0a76f3f1d..b8a664fc6882f0 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -259,8 +259,8 @@ namespace mlir {
namespace OpTrait {
template <typename ConcreteType>
-class InferTypeOpInterfaceAdaptor
- : public TraitBase<ConcreteType, InferTypeOpInterfaceAdaptor> {};
+class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
+};
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index c9c1c6cc9ab01c..a458887b374543 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -186,35 +186,42 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
// Convenient trait to define a wrapper to inferReturnTypes that passes in the
// Op Adaptor directly
-def InferTypeOpInterfaceAdaptor : TraitList<
+class InferTypeOpAdaptorBase<code additionalDecls = [{}]> : TraitList<
[
// Op implements infer type op interface.
DeclareOpInterfaceMethods<InferTypeOpInterface>,
NativeOpTrait<
- /*name=*/"InferTypeOpInterfaceAdaptor",
+ /*name=*/"InferTypeOpAdaptor",
/*traits=*/[],
/*extraOpDeclaration=*/[{
- static LogicalResult
- inferReturnTypesAdaptor(MLIRContext *context,
- std::optional<Location> location,
+ static ::mlir::LogicalResult
+ inferReturnTypes(::mlir::MLIRContext *context,
+ std::optional<::mlir::Location> location,
Adaptor adaptor,
- SmallVectorImpl<Type> &inferredReturnTypes);
- }],
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
+ }] # additionalDecls,
/*extraOpDefinition=*/[{
- LogicalResult
- $cppClass::inferReturnTypes(MLIRContext *context,
- std::optional<Location> location,
- ValueRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+ ::mlir::LogicalResult
+ $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+ std::optional<::mlir::Location> location,
+ ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+ ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
$cppClass::Adaptor adaptor(operands, attributes, properties, regions);
- return $cppClass::inferReturnTypesAdaptor(context,
+ return $cppClass::inferReturnTypes(context,
location, adaptor, inferredReturnTypes);
}
}]
>
]>;
+def InferTypeOpAdaptor : InferTypeOpAdaptorBase;
+def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
+ [{
+ static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r);
+ }]
+>;
+
// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
@@ -231,13 +238,13 @@ class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
/*traits=*/[],
/*extraOpDeclaration=*/[{}],
/*extraOpDefinition=*/[{
- LogicalResult
- $cppClass::inferReturnTypes(MLIRContext *context,
- std::optional<Location> location,
- ValueRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- SmallVector<ShapedTypeComponents, 2> retComponents;
+ ::mlir::LogicalResult
+ $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+ std::optional<::mlir::Location> location,
+ ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+ ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
if (failed($cppClass::inferReturnTypeComponents(context, location,
operands, attributes, properties, regions,
retComponents)))
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 427efa0cd76822..5f35adf0ddaab1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1354,7 +1354,7 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
/// The number and type of the results are inferred from the
/// shape of the source.
-LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor(
+LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
ExtractStridedMetadataOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ddcdffebed1392..d8e0ba1c8cd6de 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1841,12 +1841,11 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
LogicalResult
IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
- ValueRange operands, DictionaryAttr attrs,
- OpaqueProperties properties, RegionRange regions,
+ IfOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (regions.empty())
+ if (adaptor.getRegions().empty())
return failure();
- Region *r = regions.front();
+ Region *r = &adaptor.getThenRegion();
if (r->empty())
return failure();
Block &b = r->front();
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c5c69dc98b82a0..a06dd446e258a4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -394,11 +394,10 @@ void AssumingOp::build(
//===----------------------------------------------------------------------===//
LogicalResult mlir::shape::AddOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (llvm::isa<SizeType>(operands[0].getType()) ||
- llvm::isa<SizeType>(operands[1].getType()))
+ MLIRContext *context, std::optional<Location> location,
+ AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
+ llvm::isa<SizeType>(adaptor.getRhs().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -916,18 +915,17 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+ MLIRContext *context, std::optional<Location> location,
+ ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
Builder b(context);
- Properties *prop = properties.as<Properties *>();
+ const Properties *prop = &adaptor.getProperties();
DenseIntElementsAttr shape;
// TODO: this is only exercised by the Python bindings codepath which does not
// support properties
if (prop)
shape = prop->shape;
else
- shape = attributes.getAs<DenseIntElementsAttr>("shape");
+ shape = adaptor.getAttributes().getAs<DenseIntElementsAttr>("shape");
if (!shape)
return emitOptionalError(location, "missing shape attribute");
inferredReturnTypes.assign({RankedTensorType::get(
@@ -1104,11 +1102,9 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::DimOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- DimOpAdaptor dimOp(operands);
- inferredReturnTypes.assign({dimOp.getIndex().getType()});
+ MLIRContext *context, std::optional<Location> location,
+ DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.assign({adaptor.getIndex().getType()});
return success();
}
@@ -1141,11 +1137,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::DivOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (llvm::isa<SizeType>(operands[0].getType()) ||
- llvm::isa<SizeType>(operands[1].getType()))
+ MLIRContext *context, std::optional<Location> location,
+ DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
+ llvm::isa<SizeType>(adaptor.getRhs().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1361,9 +1356,8 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
}
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+ MLIRContext *context, std::optional<Location> location,
+ GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({IndexType::get(context)});
return success();
}
@@ -1399,10 +1393,9 @@ OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands.empty())
+ MLIRContext *context, std::optional<Location> location,
+ MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (adaptor.getOperands().empty())
return failure();
auto isShapeType = [](Type arg) {
@@ -1411,7 +1404,7 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
return isExtentTensorType(arg);
};
- ValueRange::type_range types = operands.getTypes();
+ ValueRange::type_range types = adaptor.getOperands().getTypes();
Type acc = types.front();
for (auto t : drop_begin(types)) {
Type l = acc, r = t;
@@ -1535,10 +1528,9 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
LogicalResult mlir::shape::RankOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (llvm::isa<ShapeType>(operands[0].getType()))
+ MLIRContext *context, std::optional<Location> location,
+ RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1571,10 +1563,10 @@ OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+ MLIRContext *context, std::optional<Location> location,
+ NumElementsOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (llvm::isa<ShapeType>(operands[0].getType()))
+ if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1603,11 +1595,10 @@ OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType() == operands[1].getType())
- inferredReturnTypes.assign({operands[0].getType()});
+ MLIRContext *context, std::optional<Location> location,
+ MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
+ inferredReturnTypes.assign({adaptor.getLhs().getType()});
else
inferredReturnTypes.assign({SizeType::get(context)});
return success();
@@ -1635,11 +1626,10 @@ OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::MinOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType() == operands[1].getType())
- inferredReturnTypes.assign({operands[0].getType()});
+ MLIRContext *context, std::optional<Location> location,
+ MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
+ inferredReturnTypes.assign({adaptor.getLhs().getType()});
else
inferredReturnTypes.assign({SizeType::get(context)});
return success();
@@ -1672,11 +1662,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::MulOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (llvm::isa<SizeType>(operands[0].getType()) ||
- llvm::isa<SizeType>(operands[1].getType()))
+ MLIRContext *context, std::optional<Location> location,
+ MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
+ llvm::isa<SizeType>(adaptor.getRhs().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1759,13 +1748,12 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (llvm::isa<ValueShapeType>(operands[0].getType()))
+ MLIRContext *context, std::optional<Location> location,
+ ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
inferredReturnTypes.assign({ShapeType::get(context)});
else {
- auto shapedTy = llvm::cast<ShapedType>(operands[0].getType());
+ auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
int64_t rank =
shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
Type indexTy = IndexType::get(context);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f5dbeb250ed60d..075b139e2f3b1c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1146,15 +1146,15 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
- ValueRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange,
+ ExtractOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- ExtractOp::Adaptor op(operands, attributes, properties);
- auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
- if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
+ auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
+ if (static_cast<int64_t>(adaptor.getPosition().size()) ==
+ vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
- auto n = std::min<size_t>(op.getPosition().size(), vectorType.getRank());
+ auto n =
+ std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
inferredReturnTypes.push_back(VectorType::get(
vectorType.getShape().drop_front(n), vectorType.getElementType()));
}
@@ -2114,17 +2114,15 @@ LogicalResult ShuffleOp::verify() {
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
- ValueRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange,
+ ShuffleOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- ShuffleOp::Adaptor op(operands, attributes, properties);
- auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
+ auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
auto v1Rank = v1Type.getRank();
// Construct resulting type: leading dimension matches mask
// length, all trailing dimensions match the operands.
SmallVector<int64_t, 4> shape;
shape.reserve(v1Rank);
- shape.push_back(std::max<size_t>(1, op.getMask().size()));
+ shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
// In the 0-D case there is no trailing shape to append.
if (v1Rank > 0)
llvm::append_range(shape, v1Type.getShape().drop_front());
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index fba882d3ba1a3e..420d5d3e4c5962 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1385,6 +1385,19 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
return success();
}
+LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
+ MLIRContext *, std::optional<Location> location,
+ OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (adaptor.getX().getType() != adaptor.getY().getType()) {
+ return emitOptionalError(location, "operand type mismatch ",
+ adaptor.getX().getType(), " vs ",
+ adaptor.getY().getType());
+ }
+ inferredReturnTypes.assign({adaptor.getX().getType()});
+ return success();
+}
+
// TODO: We should be able to only define either inferReturnType or
// refineReturnType, currently only refineReturnType can be omitted.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 30f334c4404bea..7a3ae924afa9a3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -761,6 +761,12 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
let results = (outs AnyTensor);
}
+def OpWithInferTypeAdaptorInterfaceOp : TEST_Op<"op_with_infer_type_adaptor_if", [
+ InferTypeOpAdaptor]> {
+ let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+ let results = (outs AnyTensor);
+}
+
def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface,
["refineReturnTypes"]>]> {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 16c85818605d78..46788edcb4df58 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -485,6 +485,8 @@ struct TestReturnTypeDriver
// output would be in reverse order underneath `op` from which
// the attributes and regions are used.
invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
+ invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
+ op);
invokeCreateWithInferredReturnType<
OpWithShapedTypeInferTypeInterfaceOp>(op);
};
More information about the Mlir-commits mailing list