[Mlir-commits] [mlir] [mlir][tosa] Fix tosa-infer-shapes crash (PR #87234)
Spenser Bauman
llvmlistbot at llvm.org
Mon Apr 1 05:27:04 PDT 2024
https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/87234
The tosa-infer-shapes pass inserts tensor.cast operations to mediate refined result types with consumers whose types cannot be refined. This process interferes with how types are refined in tosa.while_loop body regions, where types are propagated speculatively (to determine the types of the tosa.yield terminator) and then reverted.
The new tensor.cast operations result in a crash because they have no original types recorded for the reversion process.
This change modifies the shape propagation behavior so that the introduction of tensor.cast operations interacts correctly with this type reversion process. The new behavior is to introduce tensor.cast operations only once we wish to commit the newly computed types to the IR.
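To summarize the new flow, here is a rough sketch (illustration only, not part of the patch; the function propagateOneValue and the newType/commitChanges parameters are placeholders, while TypeModificationState and its setType/revert/commit methods are the ones introduced below):

  // Sketch of the record/revert/commit pattern used by the updated pass.
  void propagateOneValue(mlir::Value value, mlir::Type newType,
                         bool commitChanges) {
    TypeModificationState state;

    // Record the previous type and update the value in place. No tensor.cast
    // is created at this point, so speculative updates remain reversible.
    state.setType(value, newType);

    if (!commitChanges) {
      // Speculative pass over a tosa.while_loop body: restore every modified
      // value to the type it had before propagation.
      state.revert();
    } else {
      // Final pass: keep the refined types and insert tensor.cast operations
      // only at uses that cannot accept them.
      state.commit();
    }
  }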
>From 6c27cab4cc8dbc5139337e2ca5a9027683ad204c Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Sat, 30 Mar 2024 15:36:28 -0400
Subject: [PATCH] [mlir][tosa] Fix tosa-infer-shapes crash
The tosa-infer-shapes pass inserts tensor.cast operations to mediate
refined result types with consumers whose types cannot be refined.
This process interferes with how types are refined in tosa.while_loop
body regions, where types are propagated speculatively (to determine the
types of the tosa.yield terminator) and then reverted.
The new tensor.cast operations result in a crash because they have no
original types recorded for the reversion process.
This change modifies the shape propagation behavior so that the
introduction of tensor.cast operations interacts correctly with this type
reversion process. The new behavior is to introduce tensor.cast
operations only once we wish to commit the newly computed types to the IR.
---
.../Tosa/Transforms/TosaInferShapes.cpp | 198 +++++++++---------
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 116 +++++++++-
2 files changed, 218 insertions(+), 96 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index ad28c564f7dbdd..8614559e2a6f13 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,14 +18,9 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -39,9 +34,87 @@ using namespace mlir::tosa;
namespace {
-void propagateShapesInRegion(Region &region);
+// Check whether this use is replaceable. We consider a use replaceable if
+// the user is a TOSA op or an op implementing a type-inference related
+// interface.
+// When a non-replaceable use is encountered, the value is wrapped in a
+// cast back to the original type after inference.
+bool isReplaceableUser(Operation *user) {
+ // Handle unregistered dialects.
+ if (!user->getDialect())
+ return false;
+
+ return user->getDialect()->getNamespace() ==
+ TosaDialect::getDialectNamespace() ||
+ isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+}
+
+// During type propagation, the types of values in the operator graph are
+// updated. For the tosa.while_loop operation, types are speculatively updated
+// within the body region to determine the output type of the while_loop. This
+// process is performed until a fixed point is reached, then the types are
+// reverted.
+//
+// This class encapsulates the state information needed to perform the reversion
+// process or to commit to the final changes.
+class TypeModificationState {
+public:
+ TypeModificationState() = default;
+
+ ~TypeModificationState() {
+ // Ensure the recorded modifications are either committed or reverted.
+ assert(oldTypes.empty() && "unhandled type modifications");
+ }
+
+ // Update the state of the value and record the old type.
+ void setType(Value value, Type type) {
+ if (value.getType() != type) {
+ oldTypes.emplace_back(value, value.getType());
+ value.setType(type);
+ }
+ }
-void propagateShapesToTosaIf(Operation &op) {
+ // Revert changes made to the types in the IR by setting all the affected
+ // values to their old types.
+ void revert() {
+ // Restore each modified value to its previously recorded type.
+ for (auto [value, type] : oldTypes)
+ value.setType(type);
+
+ oldTypes.clear();
+ }
+
+ // Commit the changes to the types in the IR.
+ // This requires inserting tensor.cast operations to mediate the newly
+ // inferred result types with users that do not support type inference.
+ void commit() {
+ // For each value whose type changed, insert a cast back to the old type at
+ // each use that cannot accept the refined type.
+ for (auto [value, oldType] : oldTypes) {
+ for (auto &use : value.getUses()) {
+ if (isReplaceableUser(use.getOwner()))
+ continue;
+
+ OpBuilder builder(value.getContext());
+ builder.setInsertionPoint(use.getOwner());
+
+ Location loc = value.getLoc();
+ use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+ }
+ }
+
+ oldTypes.clear();
+ }
+
+private:
+ // A record of each value whose type was updated along with that value's
+ // previous type.
+ llvm::SmallVector<std::pair<Value, Type>> oldTypes;
+};
+
+void propagateShapesInRegion(Region &region, TypeModificationState &state);
+
+void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
IfOp ifOp = dyn_cast<IfOp>(op);
if (!ifOp)
return;
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {
if (inferredTy.hasRank()) {
Type newType = oldType.clone(inferredTy.getShape());
- blockArg.setType(newType);
+ state.setType(blockArg, newType);
}
}
@@ -71,14 +144,14 @@ void propagateShapesToTosaIf(Operation &op) {
ValueKnowledge::join(operandKnowledge, blockKnowledge);
if (!joinedKnowledge)
continue;
- frontBlock.getArgument(i).setType(joinedKnowledge.getType());
+ state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
}
- propagateShapesInRegion(region);
+ propagateShapesInRegion(region, state);
}
}
-void propagateShapesToTosaWhile(Operation &op) {
+void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
WhileOp whileOp = dyn_cast<WhileOp>(op);
if (!whileOp)
return;
@@ -86,49 +159,29 @@ void propagateShapesToTosaWhile(Operation &op) {
// Determine what the expected argument types are to the cond/body blocks.
// The expected arguments should be compatible with every iteration of the
// loop body / condition for tosa.while.
- llvm::SmallVector<Type> argTypes;
- for (auto operand : op.getOperands()) {
- auto operandTy = cast<ShapedType>(operand.getType());
- if (operandTy.hasRank()) {
- auto newTy = operandTy.clone(operandTy.getShape());
- argTypes.push_back(newTy);
- } else {
- argTypes.push_back(operand.getType());
- }
- }
-
- // Save out the type information so we can restore at the end.
- llvm::DenseMap<Value, Type> originalTypeMap;
- for (auto &block : op.getRegion(1)) {
- for (auto arg : block.getArguments())
- originalTypeMap[arg] = arg.getType();
- for (auto &op : block)
- for (auto result : op.getResults())
- originalTypeMap[result] = result.getType();
- }
+ SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
bool hasNewTypes = true;
while (hasNewTypes) {
+ TypeModificationState localState;
// Set types on the block args.
Region &bodyRegion = op.getRegion(1);
Block &block = bodyRegion.front();
for (int i = 0, s = argTypes.size(); i < s; i++) {
- block.getArgument(i).setType(argTypes[i]);
+ localState.setType(block.getArgument(i), argTypes[i]);
}
// Propagate to the end.
- propagateShapesInRegion(bodyRegion);
+ propagateShapesInRegion(bodyRegion, localState);
- // Find all the tosa yield types and verify there is atleast one.
+ // Find all the tosa yield types and verify there is a single one.
llvm::SmallVector<YieldOp> yieldOps;
for (auto &block : bodyRegion)
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
yieldOps.push_back(yieldOp);
- if (yieldOps.empty())
- return;
-
+ assert(yieldOps.size() == 1 && "missing or non-unique yield op");
// Using the new tosa.yield operand types, infer the new subtypes.
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
for (auto ty : argTypes) {
@@ -158,17 +211,8 @@ void propagateShapesToTosaWhile(Operation &op) {
argTypes[i] = newType;
}
- // The types inferred in the block assume the operand types specified for
- // this iteration. We need to restore the original types to ensure that
- // future iterations only use the already specified types, not possible
- // types from previous iterations.
- for (auto &block : bodyRegion) {
- for (auto arg : block.getArguments())
- arg.setType(originalTypeMap[arg]);
- for (auto &op : block)
- for (auto result : op.getResults())
- result.setType(originalTypeMap[result]);
- }
+ // Revert all changes made during the speculative part of the algorithm.
+ localState.revert();
}
// We now set the block arguments according to the most recent shape
@@ -176,41 +220,22 @@ void propagateShapesToTosaWhile(Operation &op) {
// iteration.
for (auto &region : op.getRegions()) {
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
- region.front().getArgument(i).setType(argTypes[i]);
+ state.setType(region.front().getArgument(i), argTypes[i]);
}
- propagateShapesInRegion(region);
+ propagateShapesInRegion(region, state);
}
}
-// Track the old type for each operand whose type was updated
-// during inference. This information is used to introduce casts
-// back to the type expected by the operand after inference.
-struct TypeRewriteInfo {
- OpOperand *operand;
- Type oldType;
-};
-
-void propagateShapesInRegion(Region &region) {
- // Check whether this use case is replaceable. We define an op as
- // being replaceable if it is used by a TosaOp, or an op with a
- // type-inference related interface.
- // When a non-replaceable use is encountered, the value is wrapped in a
- // cast back to the original type after inference.
- auto isReplaceableUser = [](Operation *user) -> bool {
- return user->getDialect()->getNamespace() ==
- TosaDialect::getDialectNamespace() ||
- isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
- };
-
- llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
+void propagateShapesInRegion(Region &region, TypeModificationState &state) {
for (auto &block : region) {
for (Operation &op : block) {
- if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+ if (!op.getDialect() ||
+ op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
continue;
- propagateShapesToTosaIf(op);
- propagateShapesToTosaWhile(op);
+ propagateShapesToTosaIf(op, state);
+ propagateShapesToTosaWhile(op, state);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region &region) {
continue;
// Set new type
- result.setType(newKnowledge.getType());
-
- // Collect all uses of the operation which require update.
- for (auto &user : result.getUses()) {
- if (!isReplaceableUser(user.getOwner()))
- requiresUpdate.push_back({&user, resultTy});
- }
+ state.setType(result, newKnowledge.getType());
}
}
}
}
-
- // For each use whose type changed, cast the value with the new type back to
- // the old type.
- IRRewriter rewriter(region.getContext());
- for (auto [operand, oldType] : requiresUpdate) {
- rewriter.setInsertionPoint(operand->getOwner());
-
- auto oldValue = operand->get();
-
- auto loc = oldValue.getLoc();
- auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
- operand->set(castOp);
- }
}
/// Pass that performs shape propagation across TOSA operations. This includes
@@ -285,7 +291,9 @@ struct TosaInferShapes
public:
void runOnOperation() override {
func::FuncOp func = getOperation();
- propagateShapesInRegion(func.getBody());
+ TypeModificationState state;
+ propagateShapesInRegion(func.getBody(), state);
+ state.commit();
}
};
} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1f0cfaf92c5c74..781e2ddf76fbd3 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s
// CHECK-LABEL: @test_return
func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
@@ -1177,6 +1177,120 @@ func.func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
// -----
+// This test locks down a fix for a crash in the type inference process.
+// The relevant pattern is a while loop whose body contains a TOSA operation which is
+// consumed by a non-inferrable user in the same body.
+// Previously, this would trigger a crash due to how types are cached and then
+// reapplied to the operations in the loop's body.
+
+// CHECK-LABEL: @while_dont_crash
+func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+ %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+
+ // CHECK: tosa.while_loop
+ // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+ %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
+ %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+
+ // CHECK: tosa.greater_equal
+ // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+
+ tosa.yield %3 : tensor<*xi1>
+
+ } do {
+
+ // CHECK: ^bb0
+ // CHECK-SAME: tensor<i32>
+ ^bb0(%arg1: tensor<*xi32>):
+
+ // CHECK: tosa.add
+ // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %3 = tosa.add %arg1, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+
+ // CHECK: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+ // CHECK: "use"(%[[CAST]]) : (tensor<*xi32>) -> ()
+ "use"(%3) : (tensor<*xi32>) -> ()
+
+ tosa.yield %3 : tensor<*xi32>
+ }
+
+ // CHECK: tensor.cast
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+// This test locks down a fix for a crash in the type inference process.
+// The relevant pattern is a while loop whose body contains a TOSA operation which is
+// consumed by a non-inferrable user in the same body.
+
+// CHECK-LABEL: @while_dont_crash_nested
+func.func @while_dont_crash_nested(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+ %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+
+ // CHECK: tosa.while_loop
+ // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+ %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
+ %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+
+ // CHECK: tosa.greater_equal
+ // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+
+ // CHECK: tosa.yield
+ // CHECK-SAME: tensor<i1>
+ tosa.yield %3 : tensor<*xi1>
+
+ } do {
+
+ // CHECK: ^bb0
+ // CHECK-SAME: tensor<i32>
+ ^bb0(%arg1: tensor<*xi32>):
+
+ // CHECK: tosa.while_loop
+ // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+ %1 = tosa.while_loop (%arg2 = %arg1) : (tensor<*xi32>) -> tensor<*xi32> {
+ %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+
+ // CHECK: tosa.greater_equal
+ // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %4 = tosa.greater_equal %2, %arg2 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+
+ // CHECK: tosa.yield
+ // CHECK-SAME: tensor<i1>
+ tosa.yield %4 : tensor<*xi1>
+
+ } do {
+
+ // CHECK: ^bb0
+ // CHECK-SAME: tensor<i32>
+ ^bb0(%arg2: tensor<*xi32>):
+
+ // CHECK: tosa.add
+ // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %4 = tosa.add %arg2, %arg2 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+
+ // CHECK: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+ // CHECK: "use"(%[[CAST]]) : (tensor<*xi32>) -> ()
+ "use"(%4) : (tensor<*xi32>) -> ()
+
+ // CHECK: tosa.yield
+ // CHECK-SAME: tensor<i32>
+ tosa.yield %4 : tensor<*xi32>
+ }
+
+ // CHECK: tosa.yield
+ // CHECK-SAME: tensor<i32>
+ tosa.yield %1 : tensor<*xi32>
+ }
+
+ // CHECK: tensor.cast
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
// CHECK-LABEL: @test_static_rfft2d
func.func @test_static_rfft2d(%arg0: tensor<5x2x8xf32>) -> () {
// CHECK: -> (tensor<5x2x5xf32>, tensor<5x2x5xf32>)