[Mlir-commits] [mlir] 1a2370b - [MLIR] Fix shape inference in toy tutorial
Mehdi Amini
llvmlistbot at llvm.org
Fri Apr 3 21:34:41 PDT 2020
Author: Frederik Gossen
Date: 2020-04-04T04:34:21Z
New Revision: 1a2370bfb8c17f188f197aa97f99a25fd889f8e7
URL: https://github.com/llvm/llvm-project/commit/1a2370bfb8c17f188f197aa97f99a25fd889f8e7
DIFF: https://github.com/llvm/llvm-project/commit/1a2370bfb8c17f188f197aa97f99a25fd889f8e7.diff
LOG: [MLIR] Fix shape inference in toy tutorial
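The implementation of shape inference in the toy tutorial did not conform to the algorithm the tutorial describes: the worklist picked the next operation by checking whether it returns a dynamically shaped result (`returnsDynamicShape`) instead of whether all of its operands have already been inferred (`allOperandsInferred`).
The result was correct only because all operations happen to be processed in sequence.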
Differential Revision: https://reviews.llvm.org/D77382
Added:
Modified:
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
Removed:
################################################################################
diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
- auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+ auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
}
}
+ /// A utility method that returns if the given operation has all of its
+ /// operands inferred.
+ static bool allOperandsInferred(Operation *op) {
+ return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+ return operandType.isa<RankedTensorType>();
+ });
+ }
+
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
- auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+ auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
}
}
+ /// A utility method that returns if the given operation has all of its
+ /// operands inferred.
+ static bool allOperandsInferred(Operation *op) {
+ return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+ return operandType.isa<RankedTensorType>();
+ });
+ }
+
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
- auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+ auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
}
}
+ /// A utility method that returns if the given operation has all of its
+ /// operands inferred.
+ static bool allOperandsInferred(Operation *op) {
+ return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+ return operandType.isa<RankedTensorType>();
+ });
+ }
+
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
- auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+ auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
}
}
+ /// A utility method that returns if the given operation has all of its
+ /// operands inferred.
+ static bool allOperandsInferred(Operation *op) {
+ return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+ return operandType.isa<RankedTensorType>();
+ });
+ }
+
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
More information about the Mlir-commits mailing list