[Mlir-commits] [mlir] [mlir][tosa] Improve tosa-infer-shapes for ops consumed by non-TOSA operators (PR #72715)

Spenser Bauman llvmlistbot at llvm.org
Fri Nov 17 15:40:24 PST 2023


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/72715

>From 1f5e0634e91525a464590f403c2d2904805d5f10 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Fri, 17 Nov 2023 17:10:23 -0500
Subject: [PATCH] [mlir][tosa] Improve tosa-infer-shapes for ops consumed by
 non-TOSA operators

TOSA operators consumed by non-TOSA ops generally do not have their
types inferred, as that would alter the types expected by their
consumers. This prevents type refinement on many TOSA operators when the
IR contains a mix of dialects.

This change modifies tosa-infer-shapes to update the types of all TOSA
operators during inference. When a consumer of that TOSA op is not safe
to update, a tensor.cast is inserted back to the original type. This
behavior is similar to how TOSA ops consumed by func.return are handled.

This allows for more type refinement of TOSA ops, and the additional
tensor.cast operators may be removed by later canonicalizations.
---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 75 ++++++++-----------
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 12 +++
 2 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 3cc16a91edce747..ad28c564f7dbddb 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -183,17 +183,27 @@ void propagateShapesToTosaWhile(Operation &op) {
   }
 }
 
+// Track the old type for each operand whose type was updated
+// during inference. This information is used to introduce casts
+// back to the type expected by the operand after inference.
+struct TypeRewriteInfo {
+  OpOperand *operand;
+  Type oldType;
+};
+
 void propagateShapesInRegion(Region &region) {
   // Check whether this use case is replaceable. We define an op as
-  // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
+  // being replaceable if it is used by a TosaOp, or an op with a
   // type-inference related interface.
+  // When a non-replaceable use is encountered, the value is wrapped in a
+  // cast back to the original type after inference.
   auto isReplaceableUser = [](Operation *user) -> bool {
-    return isa<func::ReturnOp>(user) ||
-           user->getDialect()->getNamespace() ==
+    return user->getDialect()->getNamespace() ==
                TosaDialect::getDialectNamespace() ||
            isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
   };
 
+  llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
   for (auto &block : region) {
     for (Operation &op : block) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -219,9 +229,6 @@ void propagateShapesInRegion(Region &region) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
 
-          if (!llvm::all_of(result.getUsers(), isReplaceableUser))
-            continue;
-
           // Determine the knowledge based on the output type.
           // TODO: should also query WIP type probably
           Type resultTy = result.getType();
@@ -246,10 +253,29 @@ void propagateShapesInRegion(Region &region) {
 
           // Set new type
           result.setType(newKnowledge.getType());
+
+          // Collect all uses of the operation which require update.
+          for (auto &user : result.getUses()) {
+            if (!isReplaceableUser(user.getOwner()))
+              requiresUpdate.push_back({&user, resultTy});
+          }
         }
       }
     }
   }
+
+  // For each use whose type changed, cast the value with the new type back to
+  // the old type.
+  IRRewriter rewriter(region.getContext());
+  for (auto [operand, oldType] : requiresUpdate) {
+    rewriter.setInsertionPoint(operand->getOwner());
+
+    auto oldValue = operand->get();
+
+    auto loc = oldValue.getLoc();
+    auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
+    operand->set(castOp);
+  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
@@ -259,44 +285,7 @@ struct TosaInferShapes
 public:
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-
-    IRRewriter rewriter(func.getContext());
-
     propagateShapesInRegion(func.getBody());
-
-    // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
-    // the FuncOp type.
-    func.walk([&](func::ReturnOp op) {
-      func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp());
-      if (!parent)
-        return;
-
-      rewriter.setInsertionPoint(op);
-      FunctionType funcTy = func.getFunctionType();
-      auto resultTys = funcTy.getResults();
-
-      bool castAdded = false;
-      SmallVector<Value> castedValues;
-      for (auto it : llvm::zip(op->getOperands(), resultTys)) {
-        auto operand = std::get<0>(it);
-        auto currentTy = operand.getType();
-        auto castTy = std::get<1>(it);
-        if (currentTy == castTy) {
-          castedValues.push_back(operand);
-          continue;
-        }
-
-        castedValues.push_back(
-            rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
-                .getResult());
-
-        castAdded = true;
-      }
-
-      if (castAdded) {
-        rewriter.replaceOpWithNewOp<func::ReturnOp>(op, castedValues);
-      }
-    });
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 7af66ae1dbc90f0..f057431a841b591 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1262,6 +1262,17 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
 
 // -----
 
+// CHECK-LABEL: test_non_tosa_consumer_still_propagates
+func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
+  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  %1 = arith.constant dense<[1, 1]> : tensor<2xindex>
+  %2 = tensor.reshape %0(%1) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_tosa_use_def_chain
 func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
   // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
@@ -1298,3 +1309,4 @@ func.func @test_large_constant_permutation() {
   %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
   return
 }
+



More information about the Mlir-commits mailing list