[Mlir-commits] [mlir] 1a2370b - [MLIR] Fix shape inference in toy tutorial

Mehdi Amini llvmlistbot at llvm.org
Fri Apr 3 21:34:41 PDT 2020


Author: Frederik Gossen
Date: 2020-04-04T04:34:21Z
New Revision: 1a2370bfb8c17f188f197aa97f99a25fd889f8e7

URL: https://github.com/llvm/llvm-project/commit/1a2370bfb8c17f188f197aa97f99a25fd889f8e7
DIFF: https://github.com/llvm/llvm-project/commit/1a2370bfb8c17f188f197aa97f99a25fd889f8e7.diff

LOG: [MLIR] Fix shape inference in toy tutorial

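The implementation of shape inference in the toy tutorial did not conform to the algorithm it describes: the worklist selected the next operation by checking whether it returns a dynamically shaped result, rather than whether all of its operands have already been inferred.
The result was only correct because the operations happened to be processed in program order.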

Differential Revision: https://reviews.llvm.org/D77382

Added: 
    

Modified: 
    mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
    mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
    mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
    mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     }
   }
 
+  /// Returns true if every operand type of `op` is a RankedTensorType (i.e.
+  /// all operands are already inferred); vacuously true with no operands.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {

diff  --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     }
   }
 
+  /// Returns true if every operand type of `op` is a RankedTensorType (i.e.
+  /// all operands are already inferred); vacuously true with no operands.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {

diff  --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     }
   }
 
+  /// Returns true if every operand type of `op` is a RankedTensorType (i.e.
+  /// all operands are already inferred); vacuously true with no operands.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {

diff  --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
index 107c80460ee6..296bec094624 100644
--- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     }
   }
 
+  /// Returns true if every operand type of `op` is a RankedTensorType (i.e.
+  /// all operands are already inferred); vacuously true with no operands.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {


        


More information about the Mlir-commits mailing list