[Mlir-commits] [mlir] c3dcf39 - [mlir] Restrict to requiring traits when using InferTensorType trait.
Jacques Pienaar
llvmlistbot at llvm.org
Mon Oct 11 14:56:36 PDT 2021
Author: Jacques Pienaar
Date: 2021-10-11T14:56:28-07:00
New Revision: c3dcf39554dbea780d6cb7e12239451ba47a2668
URL: https://github.com/llvm/llvm-project/commit/c3dcf39554dbea780d6cb7e12239451ba47a2668
DIFF: https://github.com/llvm/llvm-project/commit/c3dcf39554dbea780d6cb7e12239451ba47a2668.diff
LOG: [mlir] Restrict to requiring traits when using InferTensorType trait.
Avoids running into segfaults accidentally.
Differential Revision: https://reviews.llvm.org/D110297
Added:
Modified:
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index e153e66bd9d97..c4f8f2d905e56 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -245,11 +245,25 @@ LogicalResult inferReturnTensorTypes(
LogicalResult verifyInferredResultTypes(Operation *op);
} // namespace detail
+namespace OpTrait {
+template <typename ConcreteType>
+class InferTensorType;
+} // namespace OpTrait
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
+
+namespace mlir {
namespace OpTrait {
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
-/// Requires: Op implements functions of InferShapedTypeOpInterface.
+/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
+/// A less strict requirement is possible (e.g., the op implements
+/// inferReturnTypeComponents and that always populates all element types and
+/// shapes or fails), but this trait is currently only used where both
+/// interfaces are implemented, so keep it restricted for now.
template <typename ConcreteType>
class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
public:
@@ -258,6 +272,12 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
+ static_assert(
+ ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
+ "requires InferShapedTypeOpInterface to ensure succesful invocation");
+ static_assert(
+ ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
+ "requires InferTypeOpInterface to ensure succesful invocation");
return ::mlir::detail::inferReturnTensorTypes(
ConcreteType::inferReturnTypeComponents, context, location, operands,
attributes, regions, inferredReturnTypes);
@@ -267,7 +287,4 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
} // namespace OpTrait
} // namespace mlir
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
-
#endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
More information about the Mlir-commits
mailing list