[Mlir-commits] [mlir] [TOSA] Fix shape inference when operand was inferred (PR #66906)

llvmlistbot at llvm.org
Wed Sep 20 06:44:29 PDT 2023


https://github.com/maxbartel created https://github.com/llvm/llvm-project/pull/66906

https://github.com/llvm/llvm-project/commit/057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a introduced a bug in the `TosaInferShapesPass` when an operand type was already inferred: https://github.com/llvm/llvm-project/blob/f7bfa583b7a5ff0e9954d2810006b7a71123be88/mlir/include/mlir/Interfaces/InferTypeOpInterface.td#L248 treats the `ValueShapeRange` as a plain `ValueRange`, losing the shape information produced by earlier inference steps.
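A minimal sketch of the failure mode, mirroring the lookup code this PR removes (see the deleted lines in the patch below); the wrapper behavior described in the comments is paraphrased, not the actual TableGen-generated code:

```cpp
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

static void buildRangeWithWipShapes(
    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
  // Prefer the freshly inferred components over the operand's current IR type.
  auto operandShape = [&](Value val) -> ShapeAdaptor {
    auto it = shapesStorage.find(val);
    if (it == shapesStorage.end())
      return nullptr;
    return it->second;
  };

  // The range carries that lookup, so range.getShape(i) can return a shape
  // that is more precise than op.getOperand(i).getType().
  ValueShapeRange range(op.getOperands(), operandShape);
  ShapeAdaptor fresh = range.getShape(0);
  (void)fresh;

  // The InferShapedTypeOpAdaptor wrapper, however, builds the op adaptor from
  // the operand Values alone, so the op's inference logic only ever sees
  // op.getOperand(i).getType() and the contents of shapesStorage are lost.
}
```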

This PR changes the shape inference logic: instead of recording inferred types in a `DenseMap` and only applying them after the whole region has been analyzed, the pass now updates types directly in each iteration. That way the operands of every op visited later already carry their inferred types, as illustrated below.
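Concretely, for the use-def chain in the new test at the end of the patch, the old behavior (reconstructed from my reading of the bug, not taken from a test log) left the pool result partly dynamic, while with this PR the refined conv2d type is visible when max_pool2d is inferred:

```mlir
// Presumed output with the bug: max_pool2d was inferred from the stale
// tensor<?x32x32x16xf32> operand type, so its result kept the dynamic batch:
//   %1 = tosa.max_pool2d %0 {...} : (...) -> tensor<?x16x16x16xf32>
//
// Output with this PR (what the new CHECK lines expect): %0 already carries
// its inferred type tensor<1x32x32x16xf32> when %1 is visited:
//   %1 = tosa.max_pool2d %0 {...} : (...) -> tensor<1x16x16x16xf32>
```

The new case runs with the rest of the file; assuming the file's usual RUN line, that is `mlir-opt --split-input-file --tosa-infer-shapes mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | FileCheck` against the same file.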

From 13b2dca39131aff22edabd6797344e256c58f336 Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <max.bartel97 at gmail.com>
Date: Wed, 20 Sep 2023 15:22:12 +0200
Subject: [PATCH] fix: fix tosa-infer-shapes pass for use def chains

---
 .../mlir/Interfaces/InferTypeOpInterface.td   |  2 +-
 .../Tosa/Transforms/TosaInferShapes.cpp       | 56 +++++--------------
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 13 +++++
 3 files changed, 29 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 54c1c13fd029dbc..c5eeeaf58a7b4f8 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -223,7 +223,7 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
 >;
 
 // Convenient trait to define a wrapper to inferReturnTypeComponents that passes
-// in the Op Adaptor directly
+// in the Op Adaptor directly. Only uses the current types of the operands.
 class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
   [
     // Op implements infer type op interface.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 9c49cd55788571b..3cc16a91edce747 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -41,8 +41,7 @@ namespace {
 
 void propagateShapesInRegion(Region &region);
 
-void propagateShapesToTosaIf(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaIf(Operation &op) {
   IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;
@@ -53,12 +52,12 @@ void propagateShapesToTosaIf(
       return;
 
     for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
-      auto inferredTy = shapesStorage[op.getOperand(i)];
+      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
       auto blockArg = frontBlock.getArgument(i - 1);
       auto oldType = cast<ShapedType>(blockArg.getType());
 
       if (inferredTy.hasRank()) {
-        Type newType = oldType.clone(inferredTy.getDims());
+        Type newType = oldType.clone(inferredTy.getShape());
         blockArg.setType(newType);
       }
     }
@@ -79,8 +78,7 @@ void propagateShapesToTosaIf(
   }
 }
 
-void propagateShapesToTosaWhile(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaWhile(Operation &op) {
   WhileOp whileOp = dyn_cast<WhileOp>(op);
   if (!whileOp)
     return;
@@ -91,9 +89,8 @@ void propagateShapesToTosaWhile(
   llvm::SmallVector<Type> argTypes;
   for (auto operand : op.getOperands()) {
     auto operandTy = cast<ShapedType>(operand.getType());
-    auto shapedTypeComponent = shapesStorage[operand];
-    if (shapedTypeComponent.hasRank()) {
-      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+    if (operandTy.hasRank()) {
+      auto newTy = operandTy.clone(operandTy.getShape());
       argTypes.push_back(newTy);
     } else {
       argTypes.push_back(operand.getType());
@@ -187,21 +184,6 @@ void propagateShapesToTosaWhile(
 }
 
 void propagateShapesInRegion(Region &region) {
-  DenseMap<Value, ShapedTypeComponents> shapesStorage;
-  auto setShapes = [&](Value val, Type t) {
-    if (auto st = dyn_cast<ShapedType>(t))
-      shapesStorage[val] = st;
-    else
-      shapesStorage[val] = t;
-  };
-  auto operandShape = [&](Value val) -> ShapeAdaptor {
-    // Query the WIP mapping rather than the type if set.
-    auto it = shapesStorage.find(val);
-    if (it == shapesStorage.end())
-      return nullptr;
-    return it->second;
-  };
-
   // Check whether this use case is replaceable. We define an op as
   // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
   // type-inference related interface.
@@ -217,8 +199,8 @@ void propagateShapesInRegion(Region &region) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
         continue;
 
-      propagateShapesToTosaIf(op, shapesStorage);
-      propagateShapesToTosaWhile(op, shapesStorage);
+      propagateShapesToTosaIf(op);
+      propagateShapesToTosaWhile(op);
 
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
@@ -227,12 +209,11 @@ void propagateShapesInRegion(Region &region) {
 
       SmallVector<ShapedTypeComponents> returnedShapes;
 
-      ValueShapeRange range(op.getOperands(), operandShape);
       if (shapeInterface
-              .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
-                                         op.getDiscardableAttrDictionary(),
-                                         op.getPropertiesStorage(),
-                                         op.getRegions(), returnedShapes)
+              .inferReturnTypeComponents(
+                  op.getContext(), op.getLoc(), op.getOperands(),
+                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
+                  op.getRegions(), returnedShapes)
               .succeeded()) {
         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
@@ -262,20 +243,13 @@ void propagateShapesInRegion(Region &region) {
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
           if (!newKnowledge)
             continue;
-          setShapes(result, newKnowledge.getType());
+
+          // Set new type
+          result.setType(newKnowledge.getType());
         }
       }
     }
   }
-
-  // Actually update types with updated shape knowledge.
-  for (auto it : shapesStorage) {
-    auto result = it.second;
-    if (result.hasRank()) {
-      Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
-      it.first.setType(t);
-    }
-  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index cb96b2a8a0d193b..d468ba582483cbe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1259,3 +1259,16 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
   %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
   return %1 : f32
 }
+
+// -----
+
+// CHECK-LABEL: test_tosa_use_def_chain
+func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
+  // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
+  // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
+  // CHECK: tosa.max_pool2d [[CONV]]
+  // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
+  %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
+  return %1 : tensor<?x16x16x16xf32>
+}


