[Mlir-commits] [mlir] [mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes (PR #87660)

Spenser Bauman llvmlistbot at llvm.org
Tue Apr 9 08:35:14 PDT 2024


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/87660

>From 32a051d685a02c2843c3fab2a001c91768d5bd85 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Wed, 3 Apr 2024 14:35:46 -0400
Subject: [PATCH 1/2] Address review feedback from jpienaar

This change addresses some of the additional review feedback on
https://github.com/llvm/llvm-project/pull/87234.

A summary of the changes:

1. Cleaned up the language to use 'roll back' rather than 'revert' to
   reduce the chance of confusion. Improved some function names as well
   (e.g. isReplaceableUser -> canBeRefined, revert -> rollBack).
2. Eliminated string comparisons on dialect names; operations are now
   compared against the TOSA dialect directly.
3. Prevented the introduction of redundant tensor.cast operations for the
   same value: at most one cast is created per value and it is reused by
   every use that cannot be refined.
---
 .../Tosa/Transforms/TosaInferShapes.cpp       | 44 +++++++++----------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 8614559e2a6f13..d01891a04d2aac 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -39,13 +40,9 @@ namespace {
 // type-inference related interface.
 // When a non-replaceable use is encountered, the value is wrapped in a
 // cast back to the original type after inference.
-bool isReplaceableUser(Operation *user) {
-  // Handle unregistered dialects.
-  if (!user->getDialect())
-    return false;
-
-  return user->getDialect()->getNamespace() ==
-             TosaDialect::getDialectNamespace() ||
+bool canBeRefined(Operation *user) {
+  Dialect *tosaDialect = user->getContext()->getLoadedDialect<TosaDialect>();
+  return user->getDialect() == tosaDialect ||
          isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
 }
 
@@ -53,16 +50,16 @@ bool isReplaceableUser(Operation *user) {
 // updated. For the tosa.while_loop operation, types are speculatively updated
 // within the body region to determine the output type of the while_loop. This
 // process is performed until a fixed point is reached, then the types are
-// reverted.
+// rolled back.
 //
-// This class encapsulates the state information needed to perform the reversion
+// This class encapsulates the state information needed to perform the roll back
 // process or to commit to the final changes.
 class TypeModificationState {
 public:
   TypeModificationState() = default;
 
   ~TypeModificationState() {
-    // Ensure the recorded modifications are either committed or reverted.
+    // Ensure the recorded modifications are either committed or rolled back.
     assert(oldTypes.empty() && "unhandled type modifications");
   }
 
@@ -74,10 +71,9 @@ class TypeModificationState {
     }
   }
 
-  // Revert changes made to the types in the IR by setting all the affected
+  // Roll back changes made to the types in the IR by setting all the affected
   // values to their old types.
-  void revert() {
-    // Otherwise revert the changes.
+  void rollBack() {
     for (auto [value, type] : oldTypes)
       value.setType(type);
 
@@ -91,15 +87,18 @@ class TypeModificationState {
     // For each use whose type changed, cast the value with the new type back to
     // the old type.
     for (auto [value, oldType] : oldTypes) {
+      tensor::CastOp castedValue;
       for (auto &use : value.getUses()) {
-        if (isReplaceableUser(use.getOwner()))
+        if (canBeRefined(use.getOwner()))
           continue;
 
-        OpBuilder builder(value.getContext());
-        builder.setInsertionPoint(use.getOwner());
+        // Cache the cast to avoid generating duplicates
+        if (!castedValue) {
+          ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
+          castedValue = builder.create<tensor::CastOp>(oldType, value);
+        }
 
-        Location loc = value.getLoc();
-        use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+        use.set(castedValue);
       }
     }
 
@@ -211,8 +210,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
       argTypes[i] = newType;
     }
 
-    // Revert all changes made during the speculative part of the algorithm.
-    localState.revert();
+    // Roll back all changes made during the speculative part of the algorithm.
+    localState.rollBack();
   }
 
   // We now set the block arguments according to the most recent shape
@@ -228,10 +227,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
 }
 
 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+
   for (auto &block : region) {
     for (Operation &op : block) {
-      if (!op.getDialect() ||
-          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+      if (op.getDialect() != tosaDialect)
         continue;
 
       propagateShapesToTosaIf(op, state);

>From 4319af9ed7bfd44b3673f9f9f66ac84721966717 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Tue, 9 Apr 2024 11:35:02 -0400
Subject: [PATCH 2/2] Fix dialect check to avoid the cost of dialect lookup

---
 mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index d01891a04d2aac..44fd988d1d00fb 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -41,8 +41,10 @@ namespace {
 // When a non-replaceable use is encountered, the value is wrapped in a
 // cast back to the original type after inference.
 bool canBeRefined(Operation *user) {
-  Dialect *tosaDialect = user->getContext()->getLoadedDialect<TosaDialect>();
-  return user->getDialect() == tosaDialect ||
+  //Dialect *tosaDialect = user->getContext()->getLoadedDialect<TosaDialect>();
+  if (!user->getDialect())
+    return false;
+  return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
          isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
 }
 



More information about the Mlir-commits mailing list