[Mlir-commits] [mlir] 0a94d35 - [mlir][tosa] Fix tosa-infer-shapes crash (#87234)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 2 16:45:31 PDT 2024


Author: Spenser Bauman
Date: 2024-04-02T19:45:27-04:00
New Revision: 0a94d35bfb81cb0bef60ebe60513d191661da0bd

URL: https://github.com/llvm/llvm-project/commit/0a94d35bfb81cb0bef60ebe60513d191661da0bd
DIFF: https://github.com/llvm/llvm-project/commit/0a94d35bfb81cb0bef60ebe60513d191661da0bd.diff

LOG: [mlir][tosa] Fix tosa-infer-shapes crash (#87234)

The tosa-infer-shapes pass inserts tensor.cast operations to mediate
between refined result types and consumers whose types cannot be refined.
This process interferes with how types are refined in tosa.while_loop body
regions, where types are propagated speculatively (to determine the
types of the tosa.yield terminator) and then reverted.

The newly inserted tensor.cast operations cause a crash because no
original types are recorded for their results, so the reversion process
has no valid type to restore them to.

This change modifies the shape propagation behavior so that the
introduction of tensor.cast operations interacts correctly with this type
reversion process. The new behavior is to introduce tensor.cast
operations only once the newly computed types are committed to the IR.

The following example triggers the crash:

```mlir
func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>

  %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
    %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
    tosa.yield %3 : tensor<*xi1>
  } do {
  ^bb0(%arg1: tensor<*xi32>):
    // Inferrable operation whose type will refine to tensor<i32>
    %3 = tosa.add %arg1, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>

    // Non-inferrable use site, will require the cast:
    //     tensor.cast %3 : tensor<i32> to tensor<*xi32>
    // 
    // The new cast operation has no entry in originalTypeMap, so the C++
    // reset code below assigns its result an invalid (null) type.
    "use"(%3) : (tensor<*xi32>) -> ()
    tosa.yield %3 : tensor<*xi32>
  }

  return %1 : tensor<*xi32>
}
```

The `tensor.cast` operation inserted in the loop body causes a failure
in the code which resets the types after propagation through the loop
body:

```c++
// The types inferred in the block assume the operand types specified for
// this iteration. We need to restore the original types to ensure that
// future iterations only use the already specified types, not possible
// types from previous iterations.
for (auto &block : bodyRegion) {
  for (auto arg : block.getArguments())
    arg.setType(originalTypeMap[arg]);
  for (auto &op : block)
    for (auto result : op.getResults())
      result.setType(originalTypeMap[result]);  // problematic access
}
```

---------

Co-authored-by: Spenser Bauman <sabauma at fastmail>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index ad28c564f7dbdd..8614559e2a6f13 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,14 +18,9 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 namespace tosa {
@@ -39,9 +34,87 @@ using namespace mlir::tosa;
 
 namespace {
 
-void propagateShapesInRegion(Region &region);
+// Check whether this use case is replaceable. We define an op as
+// being replaceable if it is used by a TosaOp, or an op with a
+// type-inference related interface.
+// When a non-replaceable use is encountered, the value is wrapped in a
+// cast back to the original type after inference.
+bool isReplaceableUser(Operation *user) {
+  // Handle unregistered dialects.
+  if (!user->getDialect())
+    return false;
+
+  return user->getDialect()->getNamespace() ==
+             TosaDialect::getDialectNamespace() ||
+         isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+}
+
+// During type propagation, the types of values in the operator graph are
+// updated. For the tosa.while_loop operation, types are speculatively updated
+// within the body region to determine the output type of the while_loop. This
+// process is performed until a fixed point is reached, then the types are
+// reverted.
+//
+// This class encapsulates the state information needed to perform the reversion
+// process or to commit to the final changes.
+class TypeModificationState {
+public:
+  TypeModificationState() = default;
+
+  ~TypeModificationState() {
+    // Ensure the recorded modifications are either committed or reverted.
+    assert(oldTypes.empty() && "unhandled type modifications");
+  }
+
+  // Update the state of the value and record the old type.
+  void setType(Value value, Type type) {
+    if (value.getType() != type) {
+      oldTypes.emplace_back(value, value.getType());
+      value.setType(type);
+    }
+  }
 
-void propagateShapesToTosaIf(Operation &op) {
+  // Revert changes made to the types in the IR by setting all the affected
+  // values to their old types.
+  void revert() {
+    // Otherwise revert the changes.
+    for (auto [value, type] : oldTypes)
+      value.setType(type);
+
+    oldTypes.clear();
+  }
+
+  // Commit the changes to the types in the IR.
+  // This requires inserting tensor.cast operations to mediate the newly
+  // inferred result types with users that do not support type inference.
+  void commit() {
+    // For each use whose type changed, cast the value with the new type back to
+    // the old type.
+    for (auto [value, oldType] : oldTypes) {
+      for (auto &use : value.getUses()) {
+        if (isReplaceableUser(use.getOwner()))
+          continue;
+
+        OpBuilder builder(value.getContext());
+        builder.setInsertionPoint(use.getOwner());
+
+        Location loc = value.getLoc();
+        use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+      }
+    }
+
+    oldTypes.clear();
+  }
+
+private:
+  // A record of each value whose type was updated along with that value's
+  // previous type.
+  llvm::SmallVector<std::pair<Value, Type>> oldTypes;
+};
+
+void propagateShapesInRegion(Region &region, TypeModificationState &state);
+
+void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
   IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {
 
       if (inferredTy.hasRank()) {
         Type newType = oldType.clone(inferredTy.getShape());
-        blockArg.setType(newType);
+        state.setType(blockArg, newType);
       }
     }
 
@@ -71,14 +144,14 @@ void propagateShapesToTosaIf(Operation &op) {
           ValueKnowledge::join(operandKnowledge, blockKnowledge);
       if (!joinedKnowledge)
         continue;
-      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
+      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
     }
 
-    propagateShapesInRegion(region);
+    propagateShapesInRegion(region, state);
   }
 }
 
-void propagateShapesToTosaWhile(Operation &op) {
+void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
   WhileOp whileOp = dyn_cast<WhileOp>(op);
   if (!whileOp)
     return;
@@ -86,49 +159,29 @@ void propagateShapesToTosaWhile(Operation &op) {
   // Determine what the expected argument types are to the cond/body blocks.
   // The expected arguments should be compatible with ever iteration of the
   // loop body / condition for tosa.while.
-  llvm::SmallVector<Type> argTypes;
-  for (auto operand : op.getOperands()) {
-    auto operandTy = cast<ShapedType>(operand.getType());
-    if (operandTy.hasRank()) {
-      auto newTy = operandTy.clone(operandTy.getShape());
-      argTypes.push_back(newTy);
-    } else {
-      argTypes.push_back(operand.getType());
-    }
-  }
-
-  // Save out the type information so we can restore at the end.
-  llvm::DenseMap<Value, Type> originalTypeMap;
-  for (auto &block : op.getRegion(1)) {
-    for (auto arg : block.getArguments())
-      originalTypeMap[arg] = arg.getType();
-    for (auto &op : block)
-      for (auto result : op.getResults())
-        originalTypeMap[result] = result.getType();
-  }
+  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
 
   bool hasNewTypes = true;
   while (hasNewTypes) {
+    TypeModificationState localState;
 
     // Set types on the block args.
     Region &bodyRegion = op.getRegion(1);
     Block &block = bodyRegion.front();
     for (int i = 0, s = argTypes.size(); i < s; i++) {
-      block.getArgument(i).setType(argTypes[i]);
+      localState.setType(block.getArgument(i), argTypes[i]);
     }
 
     // Propagate to the end.
-    propagateShapesInRegion(bodyRegion);
+    propagateShapesInRegion(bodyRegion, localState);
 
-    // Find all the tosa yield types and verify there is atleast one.
+    // Find all the tosa yield types and verify there is a single one.
     llvm::SmallVector<YieldOp> yieldOps;
     for (auto &block : bodyRegion)
       if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
         yieldOps.push_back(yieldOp);
 
-    if (yieldOps.empty())
-      return;
-
+    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
     // Using the new tosa.yield operand types, infer the new subtypes.
     llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
     for (auto ty : argTypes) {
@@ -158,17 +211,8 @@ void propagateShapesToTosaWhile(Operation &op) {
       argTypes[i] = newType;
     }
 
-    // The types inferred in the block assume the operand types specified for
-    // this iteration. We need to restore the original types to ensure that
-    // future iterations only use the already specified types, not possible
-    // types from previous iterations.
-    for (auto &block : bodyRegion) {
-      for (auto arg : block.getArguments())
-        arg.setType(originalTypeMap[arg]);
-      for (auto &op : block)
-        for (auto result : op.getResults())
-          result.setType(originalTypeMap[result]);
-    }
+    // Revert all changes made during the speculative part of the algorithm.
+    localState.revert();
   }
 
   // We now set the block arguments according to the most recent shape
@@ -176,41 +220,22 @@ void propagateShapesToTosaWhile(Operation &op) {
   // iteration.
   for (auto &region : op.getRegions()) {
     for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
-      region.front().getArgument(i).setType(argTypes[i]);
+      state.setType(region.front().getArgument(i), argTypes[i]);
     }
 
-    propagateShapesInRegion(region);
+    propagateShapesInRegion(region, state);
   }
 }
 
-// Track the old type for each operand whose type was updated
-// during inference. This information is used to introduce casts
-// back to the type expected by the operand after inference.
-struct TypeRewriteInfo {
-  OpOperand *operand;
-  Type oldType;
-};
-
-void propagateShapesInRegion(Region &region) {
-  // Check whether this use case is replaceable. We define an op as
-  // being replaceable if it is used by a TosaOp, or an op with a
-  // type-inference related interface.
-  // When a non-replaceable use is encountered, the value is wrapped in a
-  // cast back to the original type after inference.
-  auto isReplaceableUser = [](Operation *user) -> bool {
-    return user->getDialect()->getNamespace() ==
-               TosaDialect::getDialectNamespace() ||
-           isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
-  };
-
-  llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
+void propagateShapesInRegion(Region &region, TypeModificationState &state) {
   for (auto &block : region) {
     for (Operation &op : block) {
-      if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+      if (!op.getDialect() ||
+          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
         continue;
 
-      propagateShapesToTosaIf(op);
-      propagateShapesToTosaWhile(op);
+      propagateShapesToTosaIf(op, state);
+      propagateShapesToTosaWhile(op, state);
 
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region &region) {
             continue;
 
           // Set new type
-          result.setType(newKnowledge.getType());
-
-          // Collect all uses of the operation which require update.
-          for (auto &user : result.getUses()) {
-            if (!isReplaceableUser(user.getOwner()))
-              requiresUpdate.push_back({&user, resultTy});
-          }
+          state.setType(result, newKnowledge.getType());
         }
       }
     }
   }
-
-  // For each use whose type changed, cast the value with the new type back to
-  // the old type.
-  IRRewriter rewriter(region.getContext());
-  for (auto [operand, oldType] : requiresUpdate) {
-    rewriter.setInsertionPoint(operand->getOwner());
-
-    auto oldValue = operand->get();
-
-    auto loc = oldValue.getLoc();
-    auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
-    operand->set(castOp);
-  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
@@ -285,7 +291,9 @@ struct TosaInferShapes
 public:
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-    propagateShapesInRegion(func.getBody());
+    TypeModificationState state;
+    propagateShapesInRegion(func.getBody(), state);
+    state.commit();
   }
 };
 } // namespace

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1f0cfaf92c5c74..2be120439ed68e 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s
 
 // CHECK-LABEL: @test_return
 func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
@@ -1177,6 +1177,97 @@ func.func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
 
 // -----
 
+// This test locks down a fix for a crash in the type inference process.
+// The relevant pattern is a while loop whose body contains a TOSA operation which is
+// consumed by a non-inferrable user in the same body.
+// Previously, this would trigger a crash due to how types are cached and then
+// reapplied to the operations in the loops body.
+
+// CHECK-LABEL: @while_dont_crash
+func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK:      tosa.while_loop
+  // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+  %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
+    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+    // CHECK:       tosa.greater_equal
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+    tosa.yield %3 : tensor<*xi1>
+  } do {
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg1: tensor<*xi32>):
+    // CHECK:     tosa.add
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = tosa.add %arg1, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+    // CHECK: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+    // CHECK: "use"(%[[CAST]]) : (tensor<*xi32>) -> ()
+    "use"(%3) : (tensor<*xi32>) -> ()
+    tosa.yield %3 : tensor<*xi32>
+  }
+  // CHECK: tensor.cast
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+// This test locks down a fix for a crash in the type inference process.
+// The relevant pattern is a while loop whose body contains a TOSA operation which is
+// consumed by a non-inferrable user in the same body.
+
+// CHECK-LABEL: @while_dont_crash_nested
+func.func @while_dont_crash_nested(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK:      tosa.while_loop
+  // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+  %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
+    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+    // CHECK:       tosa.greater_equal
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+    // CHECK:      tosa.yield
+    // CHECK-SAME: tensor<i1>
+    tosa.yield %3 : tensor<*xi1>
+  } do {
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg1: tensor<*xi32>):
+    // CHECK:      tosa.while_loop
+    // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+    %1 = tosa.while_loop (%arg2 = %arg1) : (tensor<*xi32>) -> tensor<*xi32> {
+      %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+      // CHECK:       tosa.greater_equal
+      // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+      %4 = tosa.greater_equal %2, %arg2 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+      // CHECK:      tosa.yield
+      // CHECK-SAME: tensor<i1>
+      tosa.yield %4 : tensor<*xi1>
+    } do {
+    // CHECK:      ^bb0
+    // CHECK-SAME: tensor<i32>
+    ^bb0(%arg2: tensor<*xi32>):
+      // CHECK:     tosa.add
+      // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %4 = tosa.add %arg2, %arg2 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+      // CHECK: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+      // CHECK: "use"(%[[CAST]]) : (tensor<*xi32>) -> ()
+      "use"(%4) : (tensor<*xi32>) -> ()
+      // CHECK:      tosa.yield
+      // CHECK-SAME: tensor<i32>
+      tosa.yield %4 : tensor<*xi32>
+    }
+    // CHECK:      tosa.yield
+    // CHECK-SAME: tensor<i32>
+    tosa.yield %1 : tensor<*xi32>
+  }
+
+  // CHECK: tensor.cast
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_static_rfft2d
 func.func @test_static_rfft2d(%arg0: tensor<5x2x8xf32>) -> () {
   // CHECK: -> (tensor<5x2x5xf32>, tensor<5x2x5xf32>)


        


More information about the Mlir-commits mailing list