[Mlir-commits] [mlir] [mlir] Bug fixes in TOSA shape inference pass (PR #104146)
Rafael Ubal
llvmlistbot at llvm.org
Wed Aug 14 11:36:28 PDT 2024
https://github.com/rafaelubalmw created https://github.com/llvm/llvm-project/pull/104146
This PR fixes two bugs in the TOSA shape inference pass (`--tosa-infer-shapes`). The included unit test describes both issues in detail.
- Input IR
```
func.func @main(%arg0: tensor<1x2x8xf32>) {
  %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>
  // Two consumers of '%0' that do not support shape inference. Source dim 1
  // (size 2) is expanded into result dims 1 and 2 (1x2), matching the
  // result type.
  %expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
  %expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
  return
}
```
- Output IR
```
module {
  func.func @main(%arg0: tensor<1x2x8xf32>) {
    %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
    // This cast was previously inserted between the two 'tensor.expand_shape' ops.
    %cast = tensor.cast %0 : tensor<1x2x8xf32> to tensor<?x2x8xf32>
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %0, %c0 : tensor<1x2x8xf32>
    // Previously, the operand of the first 'tensor.expand_shape' op was not
    // updated from '%0' to '%cast', because updating an operand invalidated the
    // iterator traversing the use list of the 'tosa.cast' op.
    %expanded_0 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
    %expanded_1 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
    return
  }
}
```
>From b0bec6e5be64d7a2571bc86ae20ef9361bf7ece2 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 14 Aug 2024 14:28:16 -0400
Subject: [PATCH] Fixed 2 bugs in tosa infer shapes
---
.../Tosa/Transforms/TosaInferShapes.cpp | 34 ++++++++---
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 56 ++++++++++++++++++-
2 files changed, 81 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index b1d5720541846f..90c40c09d734f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -88,18 +88,36 @@ class TypeModificationState {
// For each use whose type changed, cast the value with the new type back to
// the old type.
for (auto [value, oldType] : oldTypes) {
- tensor::CastOp castedValue;
- for (auto &use : value.getUses()) {
- if (canBeRefined(use.getOwner()))
+ // The call to 'use->set()' in the body of the loop below invalidates the
+ // iterator used to traverse op uses, so it is important to make a copy of
+ // these first.
+ llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
+ value.getUses(),
+ [](OpOperand &use) -> OpOperand * {
+ return &use;
+ });
+
+ // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
+ // cached and reused by all consumers.
+ tensor::CastOp castValue;
+
+ // Traverse all uses
+ for (OpOperand* use : uses) {
+ if (canBeRefined(use->getOwner()))
continue;
- // Cache the cast to avoid generating duplicates
- if (!castedValue) {
- ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
- castedValue = builder.create<tensor::CastOp>(oldType, value);
+ if (!castValue) {
+ // Set the insertion point as far back as possible, since new
+ // consumers of the 'tensor.cast' op generated in future iterations
+ // are likely to be further up in the code due to the order in which
+ // they appear in the use list.
+ OpBuilder builder{value.getContext()};
+ builder.setInsertionPointAfter(value.getDefiningOp());
+ castValue = builder.create<tensor::CastOp>(
+ value.getLoc(), oldType, value);
}
- use.set(castedValue);
+ use->set(castValue);
}
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 3224f88968f3d2..d46de740800e93 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1373,4 +1373,58 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
// CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
%1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
return %1 : tensor<?x16x16x16xf32>
-}
\ No newline at end of file
+}
+
+// -----
+
+// This test locks in two bug fixes, both exercised by the code below.
+//
+// 1. Context
+//
+// When shape propagation hits an operation that does not support shape
+// inference (here 'tensor.expand_shape'), the shape-inferred value consumed by
+// that operation must be reverted to its original type to avoid potential op
+// verification errors. This type reversal is done through an additional
+// 'tensor.cast' op.
+//
+//
+// 2. Preserving the list of non-inferrable consumers
+//
+// When multiple non-inferrable consumers of a shape-inferred value are found
+// (here, the 2 occurrences of 'tensor.expand_shape' consuming the output of
+// 'tosa.cast'), their input operand ('%0') must be replaced with the output of
+// the new 'tensor.cast' op. While these replacements occur, the use list of
+// the producer ('tosa.cast') is also implicitly altered, invalidating any
+// iterators associated with it. It is therefore necessary to create a copy of
+// this use list ahead of time. Before this bug fix, the first
+// 'tensor.expand_shape' op below was not updated correctly.
+//
+// 3. Guaranteeing def-use order
+//
+// When emitting the 'tensor.cast' op, it is important to guarantee that its
+// output value is defined before all of its consumers (here, both of the
+// 'tensor.expand_shape' ops). In a previous version of the code, this
+// insertion occurred right before the first encountered consumer. Since use
+// lists are stored in reverse order, the 'tensor.cast' op was inserted before
+// the second 'tensor.expand_shape' op, leading to a def-use order violation
+// when the first 'tensor.expand_shape' op was later updated. The current
+// implementation sets the insertion point right after the producer of the
+// shape-inferred value (here 'tosa.cast'), which guarantees correct def-use
+// order for all future operand updates.
+
+// CHECK-LABEL: test_multiple_non_inferrable_consumers
+// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x8xf32>
+func.func @test_multiple_non_inferrable_consumers(%arg0: tensor<1x2x8xf32>) {
+ // CHECK: %[[TOSA_CAST:.*]] = tosa.cast %[[ARG]] : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+ // CHECK: %[[TENSOR_CAST:.*]] = tensor.cast %[[TOSA_CAST]] : tensor<1x2x8xf32> to tensor<?x2x8xf32>
+ %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>
+
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>
+
+ // CHECK: tensor.expand_shape %[[TENSOR_CAST]]
+ // CHECK: tensor.expand_shape %[[TENSOR_CAST]]
+ %expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
+ %expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 2, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
+ return
+}
More information about the Mlir-commits
mailing list