[Mlir-commits] [mlir] [mlir] Bug fixes in TOSA shape inference pass (PR #104146)

Rafael Ubal llvmlistbot at llvm.org
Wed Aug 14 11:36:28 PDT 2024


https://github.com/rafaelubalmw created https://github.com/llvm/llvm-project/pull/104146

This PR addresses 2 bugs in the TOSA shape inference pass (`--tosa-infer-shapes`). The included unit test contains a detailed description of the issues.

- Input IR

```
func.func @main(%arg0: tensor<1x2x8xf32>) {
  %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>

  %c0 = arith.constant 0 : index
  %dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>

  %expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
  %expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
  return
}
```

- Output IR

```
module {
  func.func @main(%arg0: tensor<1x2x8xf32>) {
    %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>

    // This cast was previously inserted between both 'tensor.expand_shape' ops.
    %cast = tensor.cast %0 : tensor<1x2x8xf32> to tensor<?x2x8xf32>

    %c0 = arith.constant 0 : index
    %dim = tensor.dim %0, %c0 : tensor<1x2x8xf32>

    // The operand of the first 'tensor.expand_shape' op was not previously updated
    // from '%0' to '%cast' due to an invalidation of the iterator traversing the
    // use list of the 'tosa.cast' op.
    %expanded_0 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
    %expanded_1 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
    return
  }
}
```


>From b0bec6e5be64d7a2571bc86ae20ef9361bf7ece2 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 14 Aug 2024 14:28:16 -0400
Subject: [PATCH] Fixed 2 bugs in tosa infer shapes

---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 34 ++++++++---
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 56 ++++++++++++++++++-
 2 files changed, 81 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index b1d5720541846f..90c40c09d734f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -88,18 +88,36 @@ class TypeModificationState {
     // For each use whose type changed, cast the value with the new type back to
     // the old type.
     for (auto [value, oldType] : oldTypes) {
-      tensor::CastOp castedValue;
-      for (auto &use : value.getUses()) {
-        if (canBeRefined(use.getOwner()))
+      // Calling 'use->set()' below unlinks the operand from the use list of
+      // 'value', which invalidates any iterator over that list. Snapshot the
+      // uses first so that every one of them is visited.
+      llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
+          value.getUses(), [](OpOperand &use) -> OpOperand * { return &use; });
+
+      // A 'tensor.cast' op is emitted lazily, only if some use cannot be
+      // refined. Once emitted, it is cached and shared by all such uses so
+      // that no duplicate casts are generated.
+      tensor::CastOp castValue;
+
+      // Rewire each use whose owner cannot be refined to the cast, so that
+      // the owner again sees a value of the old type. Owners that can be
+      // refined keep consuming 'value' with its new type.
+      for (OpOperand *use : uses) {
+        if (canBeRefined(use->getOwner()))
           continue;

-        // Cache the cast to avoid generating duplicates
-        if (!castedValue) {
-          ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
-          castedValue = builder.create<tensor::CastOp>(oldType, value);
+        if (!castValue) {
+          // Insert the cast right after the definition of 'value' (at the
+          // start of its block if 'value' is a block argument, for which
+          // getDefiningOp() is null). That point dominates every use of
+          // 'value', so the cast is defined before all rewired consumers.
+          OpBuilder builder{value.getContext()};
+          builder.setInsertionPointAfterValue(value);
+          castValue = builder.create<tensor::CastOp>(
+              value.getLoc(), oldType, value);
         }

-        use.set(castedValue);
+        use->set(castValue);
       }
     }

diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 3224f88968f3d2..d46de740800e93 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1373,4 +1373,58 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
   // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
   %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
   return %1 : tensor<?x16x16x16xf32>
-}
\ No newline at end of file
+}
+
+// -----
+
+// This test locks in two bug fixes, both exercised by the code below.
+//
+// 1. Context
+//
+// When shape propagation reaches a consumer that does not support shape
+// inference (here 'tensor.expand_shape'), the refined value it consumes must
+// be converted back to the type that consumer originally expected, to avoid
+// potential op verification errors. This conversion is done through an
+// additional 'tensor.cast' op.
+//
+//
+// 2. Preserving the list of non-inferrable consumers
+//
+// When multiple non-inferrable consumers of a shape-inferred value are found
+// (here, the 2 occurrences of 'tensor.expand_shape' consuming the output of
+// 'tosa.cast'), their input operand ('%0') must be changed to consume the
+// output of the new 'tensor.cast' op. Each replacement also removes an entry
+// from the use list of the producer ('tosa.cast'), invalidating any iterator
+// over it, so a copy of this use list must be made ahead of time. Before this
+// bug fix, only the second 'tensor.expand_shape' op (the first use visited)
+// was rewired, and the first one kept consuming '%0'.
+//
+// 3. Guaranteeing def-use order
+//
+// When emitting the 'tensor.cast' op, it is important to guarantee that its
+// output value is defined before all of its consumers (here, both
+// 'tensor.expand_shape' ops). In a previous version of the code, the cast was
+// inserted right before the first encountered consumer. Since use lists are
+// stored in reverse order, the 'tensor.cast' op was inserted before the second
+// 'tensor.expand_shape' op, leading to a def-use order violation once the
+// first 'tensor.expand_shape' op was also rewired. The current implementation
+// inserts the cast right after the definition of the shape-inferred value
+// (here, the result of 'tosa.cast'), which dominates every use of that value
+// and thus guarantees correct def-use order for all rewired operands.
+
+// CHECK-LABEL: test_multiple_non_inferrable_consumers
+// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x8xf32>
+func.func @test_multiple_non_inferrable_consumers(%arg0: tensor<1x2x8xf32>) {
+  // CHECK: %[[TOSA_CAST:.*]] = tosa.cast %[[ARG]] : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+  // CHECK: %[[TENSOR_CAST:.*]] = tensor.cast %[[TOSA_CAST]] : tensor<1x2x8xf32> to tensor<?x2x8xf32>
+  %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>
+
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>
+  // Static sizes in 'output_shape' must agree with the result type ?x1x2x8.
+  // CHECK: tensor.expand_shape %[[TENSOR_CAST]]
+  // CHECK: tensor.expand_shape %[[TENSOR_CAST]]
+  %expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
+  %expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
+  return
+}



More information about the Mlir-commits mailing list