[Mlir-commits] [mlir] e513f2c - [mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes (#87660)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 10 10:56:10 PDT 2024


Author: Spenser Bauman
Date: 2024-05-10T13:56:06-04:00
New Revision: e513f2c69b13322d0289cbb74c91a84996382baa

URL: https://github.com/llvm/llvm-project/commit/e513f2c69b13322d0289cbb74c91a84996382baa
DIFF: https://github.com/llvm/llvm-project/commit/e513f2c69b13322d0289cbb74c91a84996382baa.diff

LOG: [mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes (#87660)

This change addresses some of the additional review feedback on
https://github.com/llvm/llvm-project/pull/87234.

A summary of the changes:

1. Cleaned up the wording to use 'roll back' rather than 'revert', to reduce
the chance of confusion. Improved some function names as well.
2. Eliminated string comparisons on dialect names.
3. Prevented the introduction of redundant tensor.cast operations for
the same value by creating at most one cast per value and reusing it.

---------

Co-authored-by: Spenser Bauman <sabauma at fastmail>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 8614559e2a6f1..b1d5720541846 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -39,13 +40,10 @@ namespace {
 // type-inference related interface.
 // When a non-replaceable use is encountered, the value is wrapped in a
 // cast back to the original type after inference.
-bool isReplaceableUser(Operation *user) {
-  // Handle unregistered dialects.
+bool canBeRefined(Operation *user) {
   if (!user->getDialect())
     return false;
-
-  return user->getDialect()->getNamespace() ==
-             TosaDialect::getDialectNamespace() ||
+  return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
          isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
 }
 
@@ -53,16 +51,16 @@ bool isReplaceableUser(Operation *user) {
 // updated. For the tosa.while_loop operation, types are speculatively updated
 // within the body region to determine the output type of the while_loop. This
 // process is performed until a fixed point is reached, then the types are
-// reverted.
+// rolled back.
 //
-// This class encapsulates the state information needed to perform the reversion
+// This class encapsulates the state information needed to perform the roll back
 // process or to commit to the final changes.
 class TypeModificationState {
 public:
   TypeModificationState() = default;
 
   ~TypeModificationState() {
-    // Ensure the recorded modifications are either committed or reverted.
+    // Ensure the recorded modifications are either committed or rolled back.
     assert(oldTypes.empty() && "unhandled type modifications");
   }
 
@@ -74,10 +72,9 @@ class TypeModificationState {
     }
   }
 
-  // Revert changes made to the types in the IR by setting all the affected
+  // Roll back changes made to the types in the IR by setting all the affected
   // values to their old types.
-  void revert() {
-    // Otherwise revert the changes.
+  void rollBack() {
     for (auto [value, type] : oldTypes)
       value.setType(type);
 
@@ -91,15 +88,18 @@ class TypeModificationState {
     // For each use whose type changed, cast the value with the new type back to
     // the old type.
     for (auto [value, oldType] : oldTypes) {
+      tensor::CastOp castedValue;
       for (auto &use : value.getUses()) {
-        if (isReplaceableUser(use.getOwner()))
+        if (canBeRefined(use.getOwner()))
           continue;
 
-        OpBuilder builder(value.getContext());
-        builder.setInsertionPoint(use.getOwner());
+        // Cache the cast to avoid generating duplicates
+        if (!castedValue) {
+          ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
+          castedValue = builder.create<tensor::CastOp>(oldType, value);
+        }
 
-        Location loc = value.getLoc();
-        use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+        use.set(castedValue);
       }
     }
 
@@ -211,8 +211,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
       argTypes[i] = newType;
     }
 
-    // Revert all changes made during the speculative part of the algorithm.
-    localState.revert();
+    // Roll back all changes made during the speculative part of the algorithm.
+    localState.rollBack();
   }
 
   // We now set the block arguments according to the most recent shape
@@ -228,10 +228,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
 }
 
 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+
   for (auto &block : region) {
     for (Operation &op : block) {
-      if (!op.getDialect() ||
-          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+      if (op.getDialect() != tosaDialect)
         continue;
 
       propagateShapesToTosaIf(op, state);


        


More information about the Mlir-commits mailing list