[Mlir-commits] [mlir] 98aad40 - [tosa] Improve inferred shapes of TOSA operations

Eric Kunze llvmlistbot at llvm.org
Thu May 25 16:28:05 PDT 2023


Author: Spenser Bauman
Date: 2023-05-25T16:27:13-07:00
New Revision: 98aad40806f18cd1d8f248d15efc75fe26840bd4

URL: https://github.com/llvm/llvm-project/commit/98aad40806f18cd1d8f248d15efc75fe26840bd4
DIFF: https://github.com/llvm/llvm-project/commit/98aad40806f18cd1d8f248d15efc75fe26840bd4.diff

LOG: [tosa] Improve inferred shapes of TOSA operations

The TosaInferShapes pass does not update the result shapes of TOSA operations
whose consumers are not TOSA operations (or func.return), which limits the
effectiveness of TosaInferShapes when the IR mixes TOSA and other dialects.
This change also updates the result shapes when every consumer implements a
type/shape inference interface (InferTypeOpInterface or InferShapedTypeOpInterface).

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D151228

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 65b66d29d6f81..9c49cd5578857 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -201,6 +202,16 @@ void propagateShapesInRegion(Region &region) {
     return it->second;
   };
 
+  // Check whether this use case is replaceable. A user is replaceable if it is
+  // a func::ReturnOp, a TOSA op (matched by name: getDialect() is null for
+  // unregistered ops), or an op implementing a type-inference interface.
+  auto isReplaceableUser = [](Operation *user) -> bool {
+    return isa<func::ReturnOp>(user) ||
+           user->getName().getDialectNamespace() ==
+               TosaDialect::getDialectNamespace() ||
+           isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+  };
+
   for (auto &block : region) {
     for (Operation &op : block) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -227,18 +238,8 @@ void propagateShapesInRegion(Region &region) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
 
-          // Check whether this use case is replaceable. We define an op as
-          // being replaceable if it is used by a ReturnOp or a TosaOp.
-          bool replaceable = true;
-          for (auto *user : result.getUsers()) {
-            if (isa<func::ReturnOp>(user))
-              continue;
-            if (user->getDialect()->getNamespace() ==
-                TosaDialect::getDialectNamespace())
-              continue;
-
-            replaceable = false;
-          }
+          if (!llvm::all_of(result.getUsers(), isReplaceableUser))
+            continue;
 
           // Determine the knowledge based on the output type.
           // TODO: should also query WIP type probably
@@ -256,9 +257,6 @@ void propagateShapesInRegion(Region &region) {
             }
           }
 
-          if (!replaceable)
-            continue;
-
           // Compute the new type based on the joined version.
           auto newKnowledge =
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 5bbb6e1fb4a6d..bf913363039d7 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1237,3 +1237,33 @@ func.func @test_unranked_equal(%arg0 : tensor<*xf32>, %arg1 : tensor<f32>) -> ()
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape
+func.func @test_non_tosa_consumer_shape(%arg0: tensor<4x4xf32>) -> !shape.shape {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+  return %1 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape2
+func.func @test_non_tosa_consumer_shape2(%arg0: tensor<4x4xf32>) -> tensor<?xindex> {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_extract
+func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index) -> f32 {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<?x?xf32>
+  %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
+  return %1 : f32
+}


        


More information about the Mlir-commits mailing list