[Mlir-commits] [mlir] 2f0b4db - [mlir] Add convenience grouping for tensor type inference
Jacques Pienaar
llvmlistbot at llvm.org
Mon Mar 1 05:21:22 PST 2021
Author: Jacques Pienaar
Date: 2021-03-01T05:21:08-08:00
New Revision: 2f0b4db5ea52148a91c57fcb192856bab567de5a
URL: https://github.com/llvm/llvm-project/commit/2f0b4db5ea52148a91c57fcb192856bab567de5a
DIFF: https://github.com/llvm/llvm-project/commit/2f0b4db5ea52148a91c57fcb192856bab567de5a.diff
LOG: [mlir] Add convenience grouping for tensor type inference
For ops that produce tensor types and implement the shaped type component interface, the type inference interface can be used. Group these together to make them easier to specify. Note that the grouping is not a trait but a list of traits, so it cannot be added as an element of a trait list; it must instead be appended (concatenated) to one.
Differential Revision: https://reviews.llvm.org/D97636
Added:
Modified:
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index ca044f019366..f15d9b3e006a 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -116,4 +116,20 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
];
}
+// Convenience class grouping together type and shaped type op interfaces for
+// ops that have tensor return types.
+class InferTensorType<list<string> overridenMethods = []> {
+ list<OpTrait> traits = [
+ // Op implements infer type op interface.
+ InferTypeOpInterface,
+ // The op will have methods implementing the ShapedType type inference
+ // interface.
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
+ // The op produces tensors and will use the ShapedType type infer interface
+ // along with knowledge that it is producing Tensors to infer the type.
+ NativeOpTrait<"InferTensorType">
+ ];
+}
+defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
+
#endif // MLIR_INFERTYPEOPINTERFACE
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 44df4ab8e10b..4893ac3d8492 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -503,24 +503,10 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
let results = (outs AnyTensor);
}
-def InferTensorType : NativeOpTrait<"InferTensorType">;
def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
- [
- // Op implements infer type op interface.
- InferTypeOpInterface,
- // The op will have methods implementing the ShapedType type infer interface.
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
- // The op produces tensors and will use the ShapedType type infer interface
- // along with knowledge that it is producing Tensors to infer shape.
- InferTensorType
- ]> {
+ InferTensorTypeWithReify.traits> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
-
- let extraClassDeclaration = [{
- LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
- SmallVectorImpl<Value> &shapes);
- }];
}
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
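As the log message says, the grouping is a plain class that holds a list of traits, so it has to be concatenated onto an op's trait list rather than placed inside it. The standalone TableGen sketch below illustrates this. It assumes simplified, hypothetical stand-ins for OpTrait, NativeOpTrait and Op (the real definitions live in mlir/include/mlir/IR/OpBase.td), plus a hypothetical TraitGroup class and ExampleTrait record. Because it needs no MLIR includes, it can be checked directly with llvm-tblgen.

```tablegen
// Hypothetical, simplified stand-ins for MLIR's trait and op classes.
class OpTrait;
class NativeOpTrait<string prop> : OpTrait {
  string trait = prop;
}
class Op<string mnemonic, list<OpTrait> props = []> {
  string opName = mnemonic;
  list<OpTrait> traits = props;
}

// Hypothetical extra trait that an op might want alongside the group.
def ExampleTrait : OpTrait;

// Same shape as InferTensorType: not a trait itself, but a class whose
// `traits` field holds a list of traits.
class TraitGroup<string extraTrait> {
  list<OpTrait> traits = [NativeOpTrait<"InferTypeOpInterface">,
                          NativeOpTrait<extraTrait>];
}
defvar Group = TraitGroup<"InferTensorType">;

// Using the group on its own, as the updated test op does.
def GroupOnlyOp : Op<"group_only", Group.traits>;

// Combining it with other traits: concatenate the lists.
def CombinedOp : Op<"combined", !listconcat([ExampleTrait], Group.traits)>;

// Nesting the group inside a trait list mixes an OpTrait with a
// list<OpTrait>, so tblgen would reject it with a type error:
// def BadOp : Op<"bad", [ExampleTrait, Group.traits]>;
```

Printing the records with llvm-tblgen should show CombinedOp's traits list with ExampleTrait followed by the two grouped traits. In a real dialect the same pattern would presumably read `!listconcat([SomeTrait], InferTensorTypeWithReify.traits)`, where SomeTrait is a placeholder for whatever other traits the op needs.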