[Mlir-commits] [mlir] c3dcf39 - [mlir] Restrict to requiring traits when using InferTensorType trait.

Jacques Pienaar llvmlistbot at llvm.org
Mon Oct 11 14:56:36 PDT 2021


Author: Jacques Pienaar
Date: 2021-10-11T14:56:28-07:00
New Revision: c3dcf39554dbea780d6cb7e12239451ba47a2668

URL: https://github.com/llvm/llvm-project/commit/c3dcf39554dbea780d6cb7e12239451ba47a2668
DIFF: https://github.com/llvm/llvm-project/commit/c3dcf39554dbea780d6cb7e12239451ba47a2668.diff

LOG: [mlir] Restrict to requiring traits when using InferTensorType trait.

Avoids running into segfaults accidentally.

Differential Revision: https://reviews.llvm.org/D110297

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index e153e66bd9d97..c4f8f2d905e56 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -245,11 +245,25 @@ LogicalResult inferReturnTensorTypes(
 LogicalResult verifyInferredResultTypes(Operation *op);
 } // namespace detail
 
+namespace OpTrait {
+template <typename ConcreteType>
+class InferTensorType;
+} // namespace OpTrait
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
+
+namespace mlir {
 namespace OpTrait {
 
 /// Tensor type inference trait that constructs a tensor from the inferred
 /// shape and elemental types.
-/// Requires: Op implements functions of InferShapedTypeOpInterface.
+/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
+///   A less strict requirement is possible (e.g., implementing only
+///   inferReturnTypeComponents, provided it always populates all element
+///   types and shapes or fails), but this trait is currently only used where
+///   both interfaces are implemented, so it is kept restricted for now.
 template <typename ConcreteType>
 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
 public:
@@ -258,6 +272,12 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
                    ValueRange operands, DictionaryAttr attributes,
                    RegionRange regions,
                    SmallVectorImpl<Type> &inferredReturnTypes) {
+    static_assert(
+        ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
+        "requires InferShapedTypeOpInterface to ensure succesful invocation");
+    static_assert(
+        ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
+        "requires InferTypeOpInterface to ensure succesful invocation");
     return ::mlir::detail::inferReturnTensorTypes(
         ConcreteType::inferReturnTypeComponents, context, location, operands,
         attributes, regions, inferredReturnTypes);
@@ -267,7 +287,4 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
 } // namespace OpTrait
 } // namespace mlir
 
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
-
 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_


        


More information about the Mlir-commits mailing list