[Mlir-commits] [mlir] 9eb3e56 - [ODS] Make the getType() method on a OneResult instruction return a specific type.
Chris Lattner
llvmlistbot at llvm.org
Sat Dec 26 13:52:48 PST 2020
Author: Chris Lattner
Date: 2020-12-26T13:52:40-08:00
New Revision: 9eb3e564d3b1c772a64eef6ecaa3b1705d065218
URL: https://github.com/llvm/llvm-project/commit/9eb3e564d3b1c772a64eef6ecaa3b1705d065218
DIFF: https://github.com/llvm/llvm-project/commit/9eb3e564d3b1c772a64eef6ecaa3b1705d065218.diff
LOG: [ODS] Make the getType() method on a OneResult instruction return a specific type.
Implement Bug 46698, making ODS synthesize a getType() method that returns a
specific C++ class for OneResult ops where we know that class. This eliminates
a common source of casts in things like:
myOp.getType().cast<FIRRTLType>().getPassive()
because we know that myOp always returns a FIRRTLType. This also encourages
op authors to type their results more tightly (which is also good for
verification).
I chose to implement this by splitting the OneResult trait into itself plus a
OneTypedResult trait, given that many things are using `hasTrait<OneResult>`
to conditionalize various logic.
While this change makes many ops get more specific getType() results, it
is generally drop-in compatible with the previous behavior because 'x.cast<T>()'
is allowed when x is already known to be a T. The one exception to this is that
we need declarations of the types used by ops, which is why a couple headers
needed additional #includes.
I updated a few things in tree to remove the now-redundant `.cast<>`'s, but there
are probably many more that can be removed.
Differential Revision: https://reviews.llvm.org/D93790
Added:
Modified:
mlir/docs/Tutorials/Toy/Ch-2.md
mlir/examples/standalone/include/Standalone/StandaloneOps.h
mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/TableGen/Type.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/TableGen/Type.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md
index be76dc351912..6c9b93ae4066 100644
--- a/mlir/docs/Tutorials/Toy/Ch-2.md
+++ b/mlir/docs/Tutorials/Toy/Ch-2.md
@@ -210,7 +210,9 @@ class ConstantOp : public mlir::Op<ConstantOp,
/// The ConstantOp takes no inputs.
mlir::OpTrait::ZeroOperands,
/// The ConstantOp returns a single result.
- mlir::OpTrait::OneResult> {
+ mlir::OpTrait::OneResult,
+ /// The result of getType is `Type`.
+ mlir::OpTraits::OneTypedResult<Type>::Impl> {
public:
/// Inherit the constructors from the base Op class.
diff --git a/mlir/examples/standalone/include/Standalone/StandaloneOps.h b/mlir/examples/standalone/include/Standalone/StandaloneOps.h
index 5a8c5d1040e6..a56c2867b1c8 100644
--- a/mlir/examples/standalone/include/Standalone/StandaloneOps.h
+++ b/mlir/examples/standalone/include/Standalone/StandaloneOps.h
@@ -9,6 +9,7 @@
#ifndef STANDALONE_STANDALONEOPS_H
#define STANDALONE_STANDALONEOPS_H
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
index aae3dbdf179f..eddee61d6b19 100644
--- a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_
#define MLIR_DIALECT_AVX512_AVX512DIALECT_H_
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
index 76153af97689..18535353a104 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
#define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 0ae572c38f49..857a652f17d9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -175,8 +175,12 @@ class Constraint<Pred pred, string desc = ""> {
// are considered as uncategorized constraints.
// Subclass for constraints on a type.
-class TypeConstraint<Pred predicate, string description = ""> :
- Constraint<predicate, description>;
+class TypeConstraint<Pred predicate, string description = "",
+ string cppClassNameParam = "::mlir::Type"> :
+ Constraint<predicate, description> {
+ // The name of the C++ Type class if known, or Type if not.
+ string cppClassName = cppClassNameParam;
+}
// Subclass for constraints on an attribute.
class AttrConstraint<Pred predicate, string description = ""> :
@@ -285,8 +289,9 @@ class Dialect {
//===----------------------------------------------------------------------===//
// A type, carries type constraints.
-class Type<Pred condition, string descr = ""> :
- TypeConstraint<condition, descr> {
+class Type<Pred condition, string descr = "",
+ string cppClassName = "::mlir::Type"> :
+ TypeConstraint<condition, descr, cppClassName> {
string typeDescription = "";
string builderCall = "";
}
@@ -299,8 +304,9 @@ class TypeAlias<Type t, string description = t.description> :
}
// A type of a specific dialect.
-class DialectType<Dialect d, Pred condition, string descr = ""> :
- Type<condition, descr> {
+class DialectType<Dialect d, Pred condition, string descr = "",
+ string cppClassName = "::mlir::Type"> :
+ Type<condition, descr, cppClassName> {
Dialect dialect = d;
}
@@ -331,11 +337,13 @@ class BuildableType<code builder> {
def AnyType : Type<CPred<"true">, "any type">;
// None type
-def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type">,
+def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
+ "::mlir::NoneType">,
BuildableType<"$_builder.getType<::mlir::NoneType>()">;
// Any type from the given list
-class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
+class AnyTypeOf<list<Type> allowedTypes, string description = "",
+ string cppClassName = "::mlir::Type"> : Type<
// Satisfy any of the allowed type's condition
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
!if(!eq(description, ""),
@@ -345,7 +353,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
-def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer">;
+def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer",
+ "::mlir::IntegerType">;
// Any integer type (regardless of signedness semantics) of a specific width.
class AnyI<int width>
@@ -355,7 +364,8 @@ class AnyI<int width>
class AnyIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, AnyI<w>),
- StrJoinInt<widths, "/">.result # "-bit integer">;
+ StrJoinInt<widths, "/">.result # "-bit integer",
+ "::mlir::IntegerType">;
def AnyI1 : AnyI<1>;
def AnyI8 : AnyI<8>;
@@ -365,12 +375,13 @@ def AnyI64 : AnyI<64>;
// Any signless integer type irrespective of its width.
def AnySignlessInteger : Type<
- CPred<"$_self.isSignlessInteger()">, "signless integer">;
+ CPred<"$_self.isSignlessInteger()">, "signless integer",
+ "::mlir::IntegerType">;
// Signless integer type of a specific width.
class I<int width>
: Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
- width # "-bit signless integer">,
+ width # "-bit signless integer", "::mlir::IntegerType">,
BuildableType<"$_builder.getIntegerType(" # width # ")"> {
int bitwidth = width;
}
@@ -392,7 +403,7 @@ def AnySignedInteger : Type<
// Signed integer type of a specific width.
class SI<int width>
: Type<CPred<"$_self.isSignedInteger(" # width # ")">,
- width # "-bit signed integer">,
+ width # "-bit signed integer", "::mlir::IntegerType">,
BuildableType<
"$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> {
int bitwidth = width;
@@ -415,7 +426,7 @@ def AnyUnsignedInteger : Type<
// Unsigned integer type of a specific width.
class UI<int width>
: Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
- width # "-bit unsigned integer">,
+ width # "-bit unsigned integer", "::mlir::IntegerType">,
BuildableType<
"$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> {
int bitwidth = width;
@@ -432,18 +443,20 @@ def UI32 : UI<32>;
def UI64 : UI<64>;
// Index type.
-def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index">,
+def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index",
+ "::mlir::IndexType">,
BuildableType<"$_builder.getIndexType()">;
// Floating point types.
// Any float type irrespective of its width.
-def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point">;
+def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point",
+ "::mlir::FloatType">;
// Float type of a specific width.
class F<int width>
: Type<CPred<"$_self.isF" # width # "()">,
- width # "-bit float">,
+ width # "-bit float", "::mlir::FloatType">,
BuildableType<"$_builder.getF" # width # "Type()"> {
int bitwidth = width;
}
@@ -465,16 +478,17 @@ class Complex<Type type>
SubstLeaves<"$_self",
"$_self.cast<::mlir::ComplexType>().getElementType()",
type.predicate>]>,
- "complex type with " # type.description # " elements"> {
+ "complex type with " # type.description # " elements",
+ "::mlir::ComplexType"> {
Type elementType = type;
}
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
- "complex-type">;
+ "complex-type", "::mlir::ComplexType">;
class OpaqueType<string dialect, string name, string description>
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
- description>,
+ description, "::mlir::OpaqueType">,
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
"$_builder.getIdentifier(\"" # dialect # "\"), \""
# name # "\")">;
@@ -483,17 +497,17 @@ class OpaqueType<string dialect, string name, string description>
// Any function type.
def FunctionType : Type<CPred<"$_self.isa<::mlir::FunctionType>()">,
- "function type">;
+ "function type", "::mlir::FunctionType">;
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
- string descr> :
+ string descr, string cppClassName = "::mlir::Type"> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<And<[containerPred,
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
- descr # " of " # etype.description # " values"> {
+ descr # " of " # etype.description # " values", cppClassName> {
// The type of elements in the container.
Type elementType = etype;
@@ -502,9 +516,11 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
}
class ShapedContainerType<list<Type> allowedTypes,
- Pred containerPred, string descr> :
+ Pred containerPred, string descr,
+ string cppClassName = "::mlir::Type"> :
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
- "$_self.cast<::mlir::ShapedType>().getElementType()", descr>;
+ "$_self.cast<::mlir::ShapedType>().getElementType()", descr,
+ cppClassName>;
// Whether a shaped type is ranked.
def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">;
@@ -520,7 +536,8 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
// Vector types.
class VectorOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
+ ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
+ "::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
// `allowedRanks` list
@@ -534,7 +551,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
// Any vector where the rank is from the given `allowedRanks` list
class VectorOfRank<list<int> allowedRanks> : Type<
IsVectorOfRankPred<allowedRanks>,
- " of ranks " # StrJoinInt<allowedRanks, "/">.result>;
+ " of ranks " # StrJoinInt<allowedRanks, "/">.result, "::mlir::VectorType">;
// Any vector where the rank is from the given `allowedRanks` list and the type
// is from the given `allowedTypes` list
@@ -543,7 +560,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
And<[VectorOf<allowedTypes>.predicate,
VectorOfRank<allowedRanks>.predicate]>,
VectorOf<allowedTypes>.description #
- VectorOfRank<allowedRanks>.description>;
+ VectorOfRank<allowedRanks>.description,
+ "::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
// `allowedLengths` list
@@ -558,7 +576,8 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
// `allowedLengths` list
class VectorOfLength<list<int> allowedLengths> : Type<
IsVectorOfLengthPred<allowedLengths>,
- " of length " # StrJoinInt<allowedLengths, "/">.result>;
+ " of length " # StrJoinInt<allowedLengths, "/">.result,
+ "::mlir::VectorType">;
// Any vector where the number of elements is from the given
@@ -569,30 +588,34 @@ class VectorOfLengthAndType<list<int> allowedLengths,
And<[VectorOf<allowedTypes>.predicate,
VectorOfLength<allowedLengths>.predicate]>,
VectorOf<allowedTypes>.description #
- VectorOfLength<allowedLengths>.description>;
+ VectorOfLength<allowedLengths>.description,
+ "::mlir::VectorType">;
def AnyVector : VectorOf<[AnyType]>;
// Shaped types.
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
+def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
+ "::mlir::ShapedType">;
// Tensor types.
// Any tensor type whose element type is from the given `allowedTypes` list
class TensorOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
+ ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
+ "::mlir::TensorType">;
def AnyTensor : TensorOf<[AnyType]>;
def AnyRankedTensor :
ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>,
- "ranked tensor">;
+ "ranked tensor", "::mlir::TensorType">;
// TODO: Have an easy way to add another constraint to a type.
class StaticShapeTensorOf<list<Type> allowedTypes>
: Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
- "statically shaped " # TensorOf<allowedTypes>.description>;
+ "statically shaped " # TensorOf<allowedTypes>.description,
+ "::mlir::TensorType">;
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
@@ -612,7 +635,7 @@ def F64Tensor : TensorOf<[F64]>;
class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
- TensorOf<allowedTypes>.description>;
+ TensorOf<allowedTypes>.description, "::mlir::TensorType">;
class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
@@ -623,12 +646,14 @@ class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
// Unranked Memref type
def AnyUnrankedMemRef :
ShapedContainerType<[AnyType],
- IsUnrankedMemRefTypePred, "unranked.memref">;
+ IsUnrankedMemRefTypePred, "unranked.memref",
+ "::mlir::MemRefType">;
// Memref type.
// Memrefs are blocks of data with fixed type and rank.
class MemRefOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref">;
+ ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
+ "::mlir::MemRefType">;
def AnyMemRef : MemRefOf<[AnyType]>;
@@ -679,7 +704,7 @@ class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
MemRefOf<allowedTypes>.description>;
// This represents a generic tuple without any constraints on element type.
-def AnyTuple : Type<IsTupleTypePred, "tuple">;
+def AnyTuple : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;
// A container type that has other types embedded in it, but (unlike
// ContainerType) can hold elements with a mix of types. Requires a call that
@@ -2414,9 +2439,7 @@ def replaceWithValue;
// the given C++ base class.
class TypeDef<Dialect dialect, string name,
string baseCppClass = "::mlir::Type">
- : DialectType<dialect, CPred<"">> {
- // The name of the C++ Type class.
- string cppClassName = name # "Type";
+ : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type"> {
// The name of the C++ base class to use for this Type.
string cppBaseClassName = baseCppClass;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ddfee8e3a1c7..11dc4b77b677 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -28,10 +28,6 @@ namespace mlir {
class Builder;
class OpBuilder;
-namespace OpTrait {
-template <typename ConcreteType> class OneResult;
-}
-
/// This class represents success/failure for operation parsing. It is
/// essentially a simple wrapper class around LogicalResult that allows for
/// explicit conversion to bool. This allows for the parser to chain together
@@ -188,7 +184,8 @@ class OpState {
void setAttrs(DictionaryAttr newAttrs) { state->setAttrs(newAttrs); }
/// Set the dialect attributes for this operation, and preserve all dependent.
- template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
+ template <typename DialectAttrs>
+ void setDialectAttrs(DialectAttrs &&attrs) {
state->setDialectAttrs(std::forward<DialectAttrs>(attrs));
}
@@ -424,7 +421,8 @@ class OneOperand : public TraitBase<ConcreteType, OneOperand> {
///
/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
///
-template <unsigned N> class NOperands {
+template <unsigned N>
+class NOperands {
public:
static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
@@ -443,7 +441,8 @@ template <unsigned N> class NOperands {
///
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
///
-template <unsigned N> class AtLeastNOperands {
+template <unsigned N>
+class AtLeastNOperands {
public:
template <typename ConcreteType>
class Impl : public detail::MultiOperandTraitBase<ConcreteType,
@@ -517,7 +516,8 @@ class OneRegion : public TraitBase<ConcreteType, OneRegion> {
/// This class provides the API for ops that are known to have a specified
/// number of regions.
-template <unsigned N> class NRegions {
+template <unsigned N>
+class NRegions {
public:
static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
@@ -533,7 +533,8 @@ template <unsigned N> class NRegions {
/// This class provides APIs for ops that are known to have at least a specified
/// number of regions.
-template <unsigned N> class AtLeastNRegions {
+template <unsigned N>
+class AtLeastNRegions {
public:
template <typename ConcreteType>
class Impl : public detail::MultiRegionTraitBase<ConcreteType,
@@ -582,7 +583,8 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
/// Replace all uses of results of this operation with the provided 'values'.
/// 'values' may correspond to an existing operation, or a range of 'Value'.
- template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
+ template <typename ValuesT>
+ void replaceAllUsesWith(ValuesT &&values) {
this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
}
@@ -610,20 +612,19 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
} // end namespace detail
/// This class provides return value APIs for ops that are known to have a
-/// single result.
+/// single result. ResultType is the concrete type returned by getType().
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
Value getResult() { return this->getOperation()->getResult(0); }
- Type getType() { return getResult().getType(); }
/// If the operation returns a single value, then the Op can be implicitly
/// converted to an Value. This yields the value of the only result.
operator Value() { return getResult(); }
- /// Replace all uses of 'this' value with the new value, updating anything in
- /// the IR that uses 'this' to use the other value instead. When this returns
- /// there are zero uses of 'this'.
+ /// Replace all uses of 'this' value with the new value, updating anything
+ /// in the IR that uses 'this' to use the other value instead. When this
+ /// returns there are zero uses of 'this'.
void replaceAllUsesWith(Value newValue) {
getResult().replaceAllUsesWith(newValue);
}
@@ -638,12 +639,33 @@ class OneResult : public TraitBase<ConcreteType, OneResult> {
}
};
+/// This trait is used for return value APIs for ops that are known to have a
+/// specific type other than `Type`. This allows the "getType()" member to be
+/// more specific for an op. This should be used in conjunction with OneResult,
+/// and occur in the trait list before OneResult.
+template <typename ResultType>
+class OneTypedResult {
+public:
+ /// This class provides return value APIs for ops that are known to have a
+ /// single result. ResultType is the concrete type returned by getType().
+ template <typename ConcreteType>
+ class Impl
+ : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
+ public:
+ ResultType getType() {
+ auto resultTy = this->getOperation()->getResult(0).getType();
+ return resultTy.template cast<ResultType>();
+ }
+ };
+};
+
/// This class provides the API for ops that are known to have a specified
/// number of results. This is used as a trait like this:
///
/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
///
-template <unsigned N> class NResults {
+template <unsigned N>
+class NResults {
public:
static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
@@ -662,7 +684,8 @@ template <unsigned N> class NResults {
///
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
///
-template <unsigned N> class AtLeastNResults {
+template <unsigned N>
+class AtLeastNResults {
public:
template <typename ConcreteType>
class Impl : public detail::MultiResultTraitBase<ConcreteType,
@@ -1573,7 +1596,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
using has_fold = decltype(
std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
- template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
+ template <typename T>
+ using detect_has_fold = llvm::is_detected<has_fold, T>;
/// Trait to check if T provides a 'print' method.
template <typename T, typename... Args>
using has_print =
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 2653b90196f7..37024f51bdc9 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -47,6 +47,9 @@ class TypeConstraint : public Constraint {
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> getBuilderCall() const;
+
+ // Return the C++ class name for this type (which may just be ::mlir::Type).
+ StringRef getCPPClassName() const;
};
// Wrapper class with helper methods for accessing Types defined in TableGen.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b315417a420b..e023b6da460b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -612,7 +612,7 @@ class VectorCreateMaskOpConversion
LogicalResult
matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = op->getResult(0).getType().cast<VectorType>();
+ auto dstType = op.getType();
int64_t rank = dstType.getRank();
if (rank == 1) {
rewriter.replaceOp(
@@ -1091,8 +1091,7 @@ class VectorTypeCastOpConversion
auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
castOp.getOperand().getType().cast<MemRefType>();
- MemRefType targetMemRefType =
- castOp.getResult().getType().cast<MemRefType>();
+ MemRefType targetMemRefType = castOp.getType();
// Only static shape casts supported atm.
if (!sourceMemRefType.hasStaticShape() ||
@@ -1459,7 +1458,7 @@ class VectorExtractStridedSliceOpConversion
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- auto dstType = op.getResult().getType().cast<VectorType>();
+ auto dstType = op.getType();
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 539e00d58dbf..a6a353c97977 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -67,7 +67,7 @@ static MaskFormat get1DMaskFormat(Value mask) {
ArrayAttr masks = m.mask_dim_sizes();
assert(masks.size() == 1);
int64_t i = masks[0].cast<IntegerAttr>().getInt();
- int64_t u = m.getType().cast<VectorType>().getDimSize(0);
+ int64_t u = m.getType().getDimSize(0);
if (i >= u)
return MaskFormat::AllTrue;
if (i <= 0)
@@ -849,7 +849,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
return Value();
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
- return type.getShape().take_back(n+1).front();
+ return type.getShape().take_back(n + 1).front();
};
int64_t destinationRank =
extractOp.getType().isa<VectorType>()
@@ -1870,9 +1870,8 @@ class StridedSliceConstantFolder final
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
if (!dense)
return failure();
- auto newAttr = DenseElementsAttr::get(
- extractStridedSliceOp.getType().cast<VectorType>(),
- dense.getSplatValue());
+ auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
+ dense.getSplatValue());
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 5ba82b39a5a6..f1708db113d4 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -999,8 +999,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
return failure();
auto operandSourceVectorType =
sourceShapeCastOp.source().getType().cast<VectorType>();
- auto operandResultVectorType =
- sourceShapeCastOp.result().getType().cast<VectorType>();
+ auto operandResultVectorType = sourceShapeCastOp.getType();
// Check if shape cast operations invert each other.
if (operandSourceVectorType != resultVectorType ||
@@ -1397,7 +1396,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto dstType = op.getResult().getType().cast<VectorType>();
+ auto dstType = op.getType();
auto eltType = dstType.getElementType();
auto dimSizes = op.mask_dim_sizes();
int64_t rank = dimSizes.size();
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index 5fe6bbbd5e83..826821937dca 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -53,6 +53,11 @@ Optional<StringRef> TypeConstraint::getBuilderCall() const {
.Default([](auto *) { return llvm::None; });
}
+// Return the C++ class name for this type (which may just be ::mlir::Type).
+StringRef TypeConstraint::getCPPClassName() const {
+ return def->getValueAsString("cppClassName");
+}
+
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
StringRef Type::getTypeDescription() const {
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 40e1c355daf8..995b4cd05cd5 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2137,11 +2137,18 @@ void OpEmitter::genTraits() {
unsigned numVariadicRegions = op.getNumVariadicRegions();
addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
- // Add result size trait.
+ // Add result size traits.
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
+ // For single result ops with a known specific type, generate a OneTypedResult
+ // trait.
+ if (numResults == 1 && numVariadicResults == 0) {
+ auto cppName = op.getResults().begin()->constraint.getCPPClassName();
+ opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
+ }
+
// Add successor size trait.
unsigned numSuccessors = op.getNumSuccessors();
unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
More information about the Mlir-commits
mailing list