[Mlir-commits] [mlir] e706533 - [mlir] Add reifyReturnShape to shaped type OpInterface
Jacques Pienaar
llvmlistbot at llvm.org
Fri Feb 28 08:41:45 PST 2020
Author: Jacques Pienaar
Date: 2020-02-28T08:41:18-08:00
New Revision: e706533f0a52b2eb8929d4004d1daccd53feda29
URL: https://github.com/llvm/llvm-project/commit/e706533f0a52b2eb8929d4004d1daccd53feda29
DIFF: https://github.com/llvm/llvm-project/commit/e706533f0a52b2eb8929d4004d1daccd53feda29.diff
LOG: [mlir] Add reifyReturnShape to shaped type OpInterface
The new reifyReturnTypeShapes interface method inserts operations that
dynamically compute the operation's return shapes.
Added:
Modified:
mlir/include/mlir/Analysis/InferTypeOpInterface.h
mlir/include/mlir/Analysis/InferTypeOpInterface.td
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/test/mlir-tblgen/return-types.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h
index 2a64983ad8c2..4c2628512cf4 100644
--- a/mlir/include/mlir/Analysis/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h
@@ -15,6 +15,7 @@
#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td
index 621d586bd3e0..548cd09a14c3 100644
--- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td
@@ -97,6 +97,18 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
"SmallVectorImpl<ShapedTypeComponents>&":
$inferedReturnShapes)
>,
+ InterfaceMethod<
+ /*desc=*/[{Reify the shape computation for the operation.
+
+ Insert operations using the given OpBuilder that compute the result shapes, and return the computed shape values in reifiedReturnShapes. The default implementation returns failure().
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"reifyReturnTypeShapes",
+ /*args=*/(ins "OpBuilder&":$builder,
+ "SmallVectorImpl<Value>&":$reifiedReturnShapes),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{ return failure(); }]
+ >,
];
}
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 330b8041afdc..12ec279c1d67 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
@@ -312,24 +313,24 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<ShapedTypeComponents> &inferedComponents) {
- // Create return type consisting of the first element of each shape of the
- // input operands or unknown for unranked operand.
- std::vector<int64_t> shape;
- shape.reserve(operands.size());
- for (auto operandType : operands.getTypes()) {
- if (auto sval = operandType.dyn_cast<ShapedType>()) {
- if (sval.hasRank())
- shape.push_back(sval.getShape().front());
- else
- shape.push_back(ShapedType::kDynamicSize);
- } else {
- return emitOptionalError(location, "only shaped type operands allowed");
- }
+ SmallVectorImpl<ShapedTypeComponents> &inferedReturnShapes) {
+ // Return a 1-D i17 shape whose only dim is dim 0 of the first operand (dynamic if unranked).
+ auto operandType = *operands.getTypes().begin(); // NOTE(review): assumes at least one operand; no empty check here.
+ auto sval = operandType.dyn_cast<ShapedType>();
+ if (!sval) {
+ return emitOptionalError(location, "only shaped type operands allowed");
}
- inferedComponents.reserve(1);
+ int64_t dim =
+ sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
auto type = IntegerType::get(17, context);
- inferedComponents.emplace_back(shape, type);
+ inferedReturnShapes.push_back(ShapedTypeComponents({dim}, type));
+ return success();
+}
+
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
+ OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
+ shapes = SmallVector<Value, 1>{ // Overwrites any prior contents of `shapes` with one value:
+ builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)}; // dim 0 of operand 0, mirroring inferReturnTypeComponents.
return success();
}
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index f89987610c99..decb5e246a81 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -82,6 +82,19 @@ static void invokeCreateWithInferedReturnType(Operation *op) {
}
}
+static void reifyReturnShape(Operation *op) {
+ OpBuilder b(op);
+
+ // Reify the op's result shapes; emit one remark per computed shape value (nothing if reification fails).
+ auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
+ SmallVector<Value, 2> shapes;
+ if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
+ return;
+ for (auto it : llvm::enumerate(shapes))
+ op->emitRemark() << "value " << it.index() << ": "
+ << it.value().getDefiningOp();
+}
+
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
void runOnFunction() override {
if (getFunction().getName() == "testCreateFunctions") {
@@ -100,6 +113,16 @@ struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
};
return;
}
+ if (getFunction().getName() == "testReifyFunctions") {
+ std::vector<Operation *> ops;
+ // Collect matching ops first so ops inserted during reification are not visited.
+ for (auto &op : getFunction().getBody().front())
+ if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
+ ops.push_back(&op);
+ // Reify return shapes for each collected op (non-matching ops such as the terminator were filtered out above).
+ for (auto *op : ops)
+ reifyReturnShape(op);
+ }
}
};
} // end anonymous namespace
diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir
index 3fcb22331fa1..d0eb364a6a9d 100644
--- a/mlir/test/mlir-tblgen/return-types.mlir
+++ b/mlir/test/mlir-tblgen/return-types.mlir
@@ -7,13 +7,13 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
// CHECK: "test.no_attributes"
%good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17>
// CHECK: "test.op_with_infer_type_if"
// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: "test.op_with_infer_type_if"
@@ -36,3 +36,14 @@ func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<2
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: testReifyFunctions
+func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
+ // expected-remark@+1 {{constant 10}}
+ %0 = "test.op_with_shaped_type_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xi17>
+ // expected-remark@+1 {{constant 20}}
+ %1 = "test.op_with_shaped_type_infer_type_if"(%arg1, %arg0) : (tensor<20xf32>, tensor<10xf32>) -> tensor<20xi17>
+ return
+}
More information about the Mlir-commits
mailing list