[Mlir-commits] [mlir] 9eb3e56 - [ODS] Make the getType() method on a OneResult instruction return a specific type.

Chris Lattner llvmlistbot at llvm.org
Sat Dec 26 13:52:48 PST 2020


Author: Chris Lattner
Date: 2020-12-26T13:52:40-08:00
New Revision: 9eb3e564d3b1c772a64eef6ecaa3b1705d065218

URL: https://github.com/llvm/llvm-project/commit/9eb3e564d3b1c772a64eef6ecaa3b1705d065218
DIFF: https://github.com/llvm/llvm-project/commit/9eb3e564d3b1c772a64eef6ecaa3b1705d065218.diff

LOG: [ODS] Make the getType() method on a OneResult instruction return a specific type.

Implement Bug 46698, making ODS synthesize a getType() method that returns a
specific C++ class for OneResult ops where we know that class.  This eliminates
a common source of casts in things like:

   myOp.getType().cast<FIRRTLType>().getPassive()

because we know that myOp always returns a FIRRTLType.  This also encourages
op authors to type their results more tightly (which is also good for
verification).

I chose to implement this by splitting the OneResult trait into itself plus a
OneTypedResult trait, given that many things are using `hasTrait<OneResult>`
to conditionalize various logic.

While this change makes many, many ops get more specific getType() results, it
is generally drop-in compatible with the previous behavior because 'x.cast<T>()'
is allowed when x is already known to be a T.  The one exception to this is that
we need declarations of the types used by ops, which is why a couple of headers
needed additional #includes.

I updated a few things in tree to remove the now-redundant `.cast<>`'s, but there
are probably many more that can be removed.

Differential Revision: https://reviews.llvm.org/D93790

Added: 
    

Modified: 
    mlir/docs/Tutorials/Toy/Ch-2.md
    mlir/examples/standalone/include/Standalone/StandaloneOps.h
    mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
    mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/TableGen/Type.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/TableGen/Type.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md
index be76dc351912..6c9b93ae4066 100644
--- a/mlir/docs/Tutorials/Toy/Ch-2.md
+++ b/mlir/docs/Tutorials/Toy/Ch-2.md
@@ -210,7 +210,9 @@ class ConstantOp : public mlir::Op<ConstantOp,
                      /// The ConstantOp takes no inputs.
                      mlir::OpTrait::ZeroOperands,
                      /// The ConstantOp returns a single result.
-                     mlir::OpTrait::OneResult> {
+                     mlir::OpTrait::OneResult,
+                     /// The result of getType is `Type`.
+                     mlir::OpTrait::OneTypedResult<Type>::Impl> {
 
  public:
   /// Inherit the constructors from the base Op class.

diff  --git a/mlir/examples/standalone/include/Standalone/StandaloneOps.h b/mlir/examples/standalone/include/Standalone/StandaloneOps.h
index 5a8c5d1040e6..a56c2867b1c8 100644
--- a/mlir/examples/standalone/include/Standalone/StandaloneOps.h
+++ b/mlir/examples/standalone/include/Standalone/StandaloneOps.h
@@ -9,6 +9,7 @@
 #ifndef STANDALONE_STANDALONEOPS_H
 #define STANDALONE_STANDALONEOPS_H
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
index aae3dbdf179f..eddee61d6b19 100644
--- a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_
 #define MLIR_DIALECT_AVX512_AVX512DIALECT_H_
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
index 76153af97689..18535353a104 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
 #define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 0ae572c38f49..857a652f17d9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -175,8 +175,12 @@ class Constraint<Pred pred, string desc = ""> {
 // are considered as uncategorized constraints.
 
 // Subclass for constraints on a type.
-class TypeConstraint<Pred predicate, string description = ""> :
-    Constraint<predicate, description>;
+class TypeConstraint<Pred predicate, string description = "",
+                     string cppClassNameParam = "::mlir::Type"> :
+    Constraint<predicate, description> {
+  // The name of the C++ Type class if known, or ::mlir::Type if not.
+  string cppClassName = cppClassNameParam;
+}
 
 // Subclass for constraints on an attribute.
 class AttrConstraint<Pred predicate, string description = ""> :
@@ -285,8 +289,9 @@ class Dialect {
 //===----------------------------------------------------------------------===//
 
 // A type, carries type constraints.
-class Type<Pred condition, string descr = ""> :
-    TypeConstraint<condition, descr> {
+class Type<Pred condition, string descr = "",
+           string cppClassName = "::mlir::Type"> :
+    TypeConstraint<condition, descr, cppClassName> {
   string typeDescription = "";
   string builderCall = "";
 }
@@ -299,8 +304,9 @@ class TypeAlias<Type t, string description = t.description> :
 }
 
 // A type of a specific dialect.
-class DialectType<Dialect d, Pred condition, string descr = ""> :
-    Type<condition, descr> {
+class DialectType<Dialect d, Pred condition, string descr = "",
+                  string cppClassName = "::mlir::Type"> :
+    Type<condition, descr, cppClassName> {
   Dialect dialect = d;
 }
 
@@ -331,11 +337,13 @@ class BuildableType<code builder> {
 def AnyType : Type<CPred<"true">, "any type">;
 
 // None type
-def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type">,
+def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
+                    "::mlir::NoneType">,
       BuildableType<"$_builder.getType<::mlir::NoneType>()">;
 
 // Any type from the given list
-class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
+class AnyTypeOf<list<Type> allowedTypes, string description = "",
+                string cppClassName = "::mlir::Type"> : Type<
     // Satisfy any of the allowed type's condition
     Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
     !if(!eq(description, ""),
@@ -345,7 +353,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
-def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer">;
+def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer",
+                      "::mlir::IntegerType">;
 
 // Any integer type (regardless of signedness semantics) of a specific width.
 class AnyI<int width>
@@ -355,7 +364,8 @@ class AnyI<int width>
 
 class AnyIntOfWidths<list<int> widths> :
     AnyTypeOf<!foreach(w, widths, AnyI<w>),
-              StrJoinInt<widths, "/">.result # "-bit integer">;
+              StrJoinInt<widths, "/">.result # "-bit integer",
+              "::mlir::IntegerType">;
 
 def AnyI1  : AnyI<1>;
 def AnyI8  : AnyI<8>;
@@ -365,12 +375,13 @@ def AnyI64 : AnyI<64>;
 
 // Any signless integer type irrespective of its width.
 def AnySignlessInteger : Type<
-  CPred<"$_self.isSignlessInteger()">, "signless integer">;
+  CPred<"$_self.isSignlessInteger()">, "signless integer",
+        "::mlir::IntegerType">;
 
 // Signless integer type of a specific width.
 class I<int width>
     : Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
-                  width # "-bit signless integer">,
+                  width # "-bit signless integer", "::mlir::IntegerType">,
       BuildableType<"$_builder.getIntegerType(" # width # ")"> {
   int bitwidth = width;
 }
@@ -392,7 +403,7 @@ def AnySignedInteger : Type<
 // Signed integer type of a specific width.
 class SI<int width>
     : Type<CPred<"$_self.isSignedInteger(" # width # ")">,
-                  width # "-bit signed integer">,
+                  width # "-bit signed integer", "::mlir::IntegerType">,
       BuildableType<
         "$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> {
   int bitwidth = width;
@@ -415,7 +426,7 @@ def AnyUnsignedInteger : Type<
 // Unsigned integer type of a specific width.
 class UI<int width>
     : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
-                  width # "-bit unsigned integer">,
+                  width # "-bit unsigned integer", "::mlir::IntegerType">,
       BuildableType<
         "$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> {
   int bitwidth = width;
@@ -432,18 +443,20 @@ def UI32 : UI<32>;
 def UI64 : UI<64>;
 
 // Index type.
-def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index">,
+def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index",
+                 "::mlir::IndexType">,
             BuildableType<"$_builder.getIndexType()">;
 
 // Floating point types.
 
 // Any float type irrespective of its width.
-def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point">;
+def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point",
+                    "::mlir::FloatType">;
 
 // Float type of a specific width.
 class F<int width>
     : Type<CPred<"$_self.isF" # width # "()">,
-                width # "-bit float">,
+           width # "-bit float", "::mlir::FloatType">,
       BuildableType<"$_builder.getF" # width # "Type()"> {
   int bitwidth = width;
 }
@@ -465,16 +478,17 @@ class Complex<Type type>
           SubstLeaves<"$_self",
                       "$_self.cast<::mlir::ComplexType>().getElementType()",
            type.predicate>]>,
-           "complex type with " # type.description # " elements"> {
+           "complex type with " # type.description # " elements",
+           "::mlir::ComplexType"> {
   Type elementType = type;
 }
 
 def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
-                            "complex-type">;
+                      "complex-type", "::mlir::ComplexType">;
 
 class OpaqueType<string dialect, string name, string description>
   : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
-         description>,
+         description, "::mlir::OpaqueType">,
     BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
                   "$_builder.getIdentifier(\"" # dialect # "\"), \""
                   # name # "\")">;
@@ -483,17 +497,17 @@ class OpaqueType<string dialect, string name, string description>
 
 // Any function type.
 def FunctionType : Type<CPred<"$_self.isa<::mlir::FunctionType>()">,
-                              "function type">;
+                              "function type", "::mlir::FunctionType">;
 
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
-                    string descr> :
+                    string descr, string cppClassName = "::mlir::Type"> :
     // First, check the container predicate.  Then, substitute the extracted
     // element into the element type checker.
     Type<And<[containerPred,
                 SubstLeaves<"$_self", !cast<string>(elementTypeCall),
                 etype.predicate>]>,
-         descr # " of " # etype.description # " values"> {
+         descr # " of " # etype.description # " values", cppClassName> {
   // The type of elements in the container.
   Type elementType = etype;
 
@@ -502,9 +516,11 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
 }
 
 class ShapedContainerType<list<Type> allowedTypes,
-                          Pred containerPred, string descr> :
+                          Pred containerPred, string descr,
+                          string cppClassName = "::mlir::Type"> :
     ContainerType<AnyTypeOf<allowedTypes>, containerPred,
-                  "$_self.cast<::mlir::ShapedType>().getElementType()", descr>;
+                  "$_self.cast<::mlir::ShapedType>().getElementType()", descr,
+                  cppClassName>;
 
 // Whether a shaped type is ranked.
 def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">;
@@ -520,7 +536,8 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
 // Vector types.
 
 class VectorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
+  ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
+                      "::mlir::VectorType">;
 
 // Whether the number of elements of a vector is from the given
 // `allowedRanks` list
@@ -534,7 +551,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
 // Any vector where the rank is from the given `allowedRanks` list
 class VectorOfRank<list<int> allowedRanks> : Type<
   IsVectorOfRankPred<allowedRanks>,
-  " of ranks " # StrJoinInt<allowedRanks, "/">.result>;
+  " of ranks " # StrJoinInt<allowedRanks, "/">.result, "::mlir::VectorType">;
 
 // Any vector where the rank is from the given `allowedRanks` list and the type
 // is from the given `allowedTypes` list
@@ -543,7 +560,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
   And<[VectorOf<allowedTypes>.predicate,
        VectorOfRank<allowedRanks>.predicate]>,
   VectorOf<allowedTypes>.description #
-  VectorOfRank<allowedRanks>.description>;
+  VectorOfRank<allowedRanks>.description,
+  "::mlir::VectorType">;
 
 // Whether the number of elements of a vector is from the given
 // `allowedLengths` list
@@ -558,7 +576,8 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
 // `allowedLengths` list
 class VectorOfLength<list<int> allowedLengths> : Type<
   IsVectorOfLengthPred<allowedLengths>,
-  " of length " # StrJoinInt<allowedLengths, "/">.result>;
+  " of length " # StrJoinInt<allowedLengths, "/">.result,
+  "::mlir::VectorType">;
 
 
 // Any vector where the number of elements is from the given
@@ -569,30 +588,34 @@ class VectorOfLengthAndType<list<int> allowedLengths,
   And<[VectorOf<allowedTypes>.predicate,
        VectorOfLength<allowedLengths>.predicate]>,
   VectorOf<allowedTypes>.description #
-  VectorOfLength<allowedLengths>.description>;
+  VectorOfLength<allowedLengths>.description,
+  "::mlir::VectorType">;
 
 def AnyVector : VectorOf<[AnyType]>;
 
 // Shaped types.
 
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
+def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
+                                   "::mlir::ShapedType">;
 
 // Tensor types.
 
 // Any tensor type whose element type is from the given `allowedTypes` list
 class TensorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
+  ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
+                      "::mlir::TensorType">;
 
 def AnyTensor : TensorOf<[AnyType]>;
 
 def AnyRankedTensor :
   ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>,
-  "ranked tensor">;
+  "ranked tensor", "::mlir::TensorType">;
 
 // TODO: Have an easy way to add another constraint to a type.
 class StaticShapeTensorOf<list<Type> allowedTypes>
     : Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
-           "statically shaped " # TensorOf<allowedTypes>.description>;
+           "statically shaped " # TensorOf<allowedTypes>.description,
+           "::mlir::TensorType">;
 
 def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 
@@ -612,7 +635,7 @@ def F64Tensor  : TensorOf<[F64]>;
 class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
     Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
          StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
-         TensorOf<allowedTypes>.description>;
+         TensorOf<allowedTypes>.description, "::mlir::TensorType">;
 
 class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
 class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
@@ -623,12 +646,14 @@ class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
 // Unranked Memref type
 def AnyUnrankedMemRef :
     ShapedContainerType<[AnyType],
-                        IsUnrankedMemRefTypePred, "unranked.memref">;
+                        IsUnrankedMemRefTypePred, "unranked.memref",
+                        "::mlir::MemRefType">;
 // Memref type.
 
 // Memrefs are blocks of data with fixed type and rank.
 class MemRefOf<list<Type> allowedTypes> :
-    ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref">;
+    ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
+                        "::mlir::MemRefType">;
 
 def AnyMemRef : MemRefOf<[AnyType]>;
 
@@ -679,7 +704,7 @@ class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
          MemRefOf<allowedTypes>.description>;
 
 // This represents a generic tuple without any constraints on element type.
-def AnyTuple : Type<IsTupleTypePred, "tuple">;
+def AnyTuple : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;
 
 // A container type that has other types embedded in it, but (unlike
 // ContainerType) can hold elements with a mix of types. Requires a call that
@@ -2414,9 +2439,7 @@ def replaceWithValue;
 // the given C++ base class.
 class TypeDef<Dialect dialect, string name,
               string baseCppClass = "::mlir::Type">
-    : DialectType<dialect, CPred<"">> {
-  // The name of the C++ Type class.
-  string cppClassName = name # "Type";
+    : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type"> {
   // The name of the C++ base class to use for this Type.
   string cppBaseClassName = baseCppClass;
 

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ddfee8e3a1c7..11dc4b77b677 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -28,10 +28,6 @@ namespace mlir {
 class Builder;
 class OpBuilder;
 
-namespace OpTrait {
-template <typename ConcreteType> class OneResult;
-}
-
 /// This class represents success/failure for operation parsing. It is
 /// essentially a simple wrapper class around LogicalResult that allows for
 /// explicit conversion to bool. This allows for the parser to chain together
@@ -188,7 +184,8 @@ class OpState {
   void setAttrs(DictionaryAttr newAttrs) { state->setAttrs(newAttrs); }
 
   /// Set the dialect attributes for this operation, and preserve all dependent.
-  template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
+  template <typename DialectAttrs>
+  void setDialectAttrs(DialectAttrs &&attrs) {
     state->setDialectAttrs(std::forward<DialectAttrs>(attrs));
   }
 
@@ -424,7 +421,8 @@ class OneOperand : public TraitBase<ConcreteType, OneOperand> {
 ///
 ///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
 ///
-template <unsigned N> class NOperands {
+template <unsigned N>
+class NOperands {
 public:
   static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
 
@@ -443,7 +441,8 @@ template <unsigned N> class NOperands {
 ///
 ///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
 ///
-template <unsigned N> class AtLeastNOperands {
+template <unsigned N>
+class AtLeastNOperands {
 public:
   template <typename ConcreteType>
   class Impl : public detail::MultiOperandTraitBase<ConcreteType,
@@ -517,7 +516,8 @@ class OneRegion : public TraitBase<ConcreteType, OneRegion> {
 
 /// This class provides the API for ops that are known to have a specified
 /// number of regions.
-template <unsigned N> class NRegions {
+template <unsigned N>
+class NRegions {
 public:
   static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
 
@@ -533,7 +533,8 @@ template <unsigned N> class NRegions {
 
 /// This class provides APIs for ops that are known to have at least a specified
 /// number of regions.
-template <unsigned N> class AtLeastNRegions {
+template <unsigned N>
+class AtLeastNRegions {
 public:
   template <typename ConcreteType>
   class Impl : public detail::MultiRegionTraitBase<ConcreteType,
@@ -582,7 +583,8 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
 
   /// Replace all uses of results of this operation with the provided 'values'.
   /// 'values' may correspond to an existing operation, or a range of 'Value'.
-  template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
+  template <typename ValuesT>
+  void replaceAllUsesWith(ValuesT &&values) {
     this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
   }
 
@@ -610,20 +612,19 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
 } // end namespace detail
 
 /// This class provides return value APIs for ops that are known to have a
-/// single result.
+/// single result.  A typed getType() accessor is provided by OneTypedResult.
 template <typename ConcreteType>
 class OneResult : public TraitBase<ConcreteType, OneResult> {
 public:
   Value getResult() { return this->getOperation()->getResult(0); }
-  Type getType() { return getResult().getType(); }
 
   /// If the operation returns a single value, then the Op can be implicitly
   /// converted to an Value. This yields the value of the only result.
   operator Value() { return getResult(); }
 
-  /// Replace all uses of 'this' value with the new value, updating anything in
-  /// the IR that uses 'this' to use the other value instead.  When this returns
-  /// there are zero uses of 'this'.
+  /// Replace all uses of 'this' value with the new value, updating anything
+  /// in the IR that uses 'this' to use the other value instead.  When this
+  /// returns there are zero uses of 'this'.
   void replaceAllUsesWith(Value newValue) {
     getResult().replaceAllUsesWith(newValue);
   }
@@ -638,12 +639,33 @@ class OneResult : public TraitBase<ConcreteType, OneResult> {
   }
 };
 
+/// This trait is used for return value APIs for ops that are known to have a
+/// specific type other than `Type`.  This allows the "getType()" member to be
+/// more specific for an op.  This should be used in conjunction with OneResult,
+/// and occur in the trait list before OneResult.
+template <typename ResultType>
+class OneTypedResult {
+public:
+  /// This class provides return value APIs for ops that are known to have a
+  /// single result.  ResultType is the concrete type returned by getType().
+  template <typename ConcreteType>
+  class Impl
+      : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
+  public:
+    ResultType getType() {
+      auto resultTy = this->getOperation()->getResult(0).getType();
+      return resultTy.template cast<ResultType>();
+    }
+  };
+};
+
 /// This class provides the API for ops that are known to have a specified
 /// number of results.  This is used as a trait like this:
 ///
 ///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
 ///
-template <unsigned N> class NResults {
+template <unsigned N>
+class NResults {
 public:
   static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
 
@@ -662,7 +684,8 @@ template <unsigned N> class NResults {
 ///
 ///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
 ///
-template <unsigned N> class AtLeastNResults {
+template <unsigned N>
+class AtLeastNResults {
 public:
   template <typename ConcreteType>
   class Impl : public detail::MultiResultTraitBase<ConcreteType,
@@ -1573,7 +1596,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
   using has_fold = decltype(
       std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
                              std::declval<SmallVectorImpl<OpFoldResult> &>()));
-  template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
+  template <typename T>
+  using detect_has_fold = llvm::is_detected<has_fold, T>;
   /// Trait to check if T provides a 'print' method.
   template <typename T, typename... Args>
   using has_print =

diff  --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 2653b90196f7..37024f51bdc9 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -47,6 +47,9 @@ class TypeConstraint : public Constraint {
   // Returns the builder call for this constraint if this is a buildable type,
   // returns None otherwise.
   Optional<StringRef> getBuilderCall() const;
+
+  // Return the C++ class name for this type (which may just be ::mlir::Type).
+  StringRef getCPPClassName() const;
 };
 
 // Wrapper class with helper methods for accessing Types defined in TableGen.

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b315417a420b..e023b6da460b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -612,7 +612,7 @@ class VectorCreateMaskOpConversion
   LogicalResult
   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dstType = op->getResult(0).getType().cast<VectorType>();
+    auto dstType = op.getType();
     int64_t rank = dstType.getRank();
     if (rank == 1) {
       rewriter.replaceOp(
@@ -1091,8 +1091,7 @@ class VectorTypeCastOpConversion
     auto loc = castOp->getLoc();
     MemRefType sourceMemRefType =
         castOp.getOperand().getType().cast<MemRefType>();
-    MemRefType targetMemRefType =
-        castOp.getResult().getType().cast<MemRefType>();
+    MemRefType targetMemRefType = castOp.getType();
 
     // Only static shape casts supported atm.
     if (!sourceMemRefType.hasStaticShape() ||
@@ -1459,7 +1458,7 @@ class VectorExtractStridedSliceOpConversion
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
-    auto dstType = op.getResult().getType().cast<VectorType>();
+    auto dstType = op.getType();
 
     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
 

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 539e00d58dbf..a6a353c97977 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -67,7 +67,7 @@ static MaskFormat get1DMaskFormat(Value mask) {
     ArrayAttr masks = m.mask_dim_sizes();
     assert(masks.size() == 1);
     int64_t i = masks[0].cast<IntegerAttr>().getInt();
-    int64_t u = m.getType().cast<VectorType>().getDimSize(0);
+    int64_t u = m.getType().getDimSize(0);
     if (i >= u)
       return MaskFormat::AllTrue;
     if (i <= 0)
@@ -849,7 +849,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
     return Value();
   // Get the nth dimension size starting from lowest dimension.
   auto getDimReverse = [](VectorType type, int64_t n) {
-    return type.getShape().take_back(n+1).front();
+    return type.getShape().take_back(n + 1).front();
   };
   int64_t destinationRank =
       extractOp.getType().isa<VectorType>()
@@ -1870,9 +1870,8 @@ class StridedSliceConstantFolder final
     auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
     if (!dense)
       return failure();
-    auto newAttr = DenseElementsAttr::get(
-        extractStridedSliceOp.getType().cast<VectorType>(),
-        dense.getSplatValue());
+    auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
+                                          dense.getSplatValue());
     rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
     return success();
   }

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 5ba82b39a5a6..f1708db113d4 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -999,8 +999,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
       return failure();
     auto operandSourceVectorType =
         sourceShapeCastOp.source().getType().cast<VectorType>();
-    auto operandResultVectorType =
-        sourceShapeCastOp.result().getType().cast<VectorType>();
+    auto operandResultVectorType = sourceShapeCastOp.getType();
 
     // Check if shape cast operations invert each other.
     if (operandSourceVectorType != resultVectorType ||
@@ -1397,7 +1396,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto dstType = op.getResult().getType().cast<VectorType>();
+    auto dstType = op.getType();
     auto eltType = dstType.getElementType();
     auto dimSizes = op.mask_dim_sizes();
     int64_t rank = dimSizes.size();

diff  --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index 5fe6bbbd5e83..826821937dca 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -53,6 +53,11 @@ Optional<StringRef> TypeConstraint::getBuilderCall() const {
       .Default([](auto *) { return llvm::None; });
 }
 
+// Return the C++ class name for this type (which may just be ::mlir::Type).
+StringRef TypeConstraint::getCPPClassName() const {
+  return def->getValueAsString("cppClassName");
+}
+
 Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
 
 StringRef Type::getTypeDescription() const {

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 40e1c355daf8..995b4cd05cd5 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2137,11 +2137,18 @@ void OpEmitter::genTraits() {
   unsigned numVariadicRegions = op.getNumVariadicRegions();
   addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
 
-  // Add result size trait.
+  // Add result size traits.
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariableLengthResults();
   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
 
+  // For every single result op, add a OneTypedResult trait that provides
+  // getType(), returning the constraint's C++ class (maybe just ::mlir::Type).
+  if (numResults == 1 && numVariadicResults == 0) {
+    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
+    opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
+  }
+
   // Add successor size trait.
   unsigned numSuccessors = op.getNumSuccessors();
   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();


        


More information about the Mlir-commits mailing list