[Mlir-commits] [mlir] [mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes (PR #87660)
Spenser Bauman
llvmlistbot at llvm.org
Tue Apr 9 08:36:36 PDT 2024
https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/87660
>From 32a051d685a02c2843c3fab2a001c91768d5bd85 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Wed, 3 Apr 2024 14:35:46 -0400
Subject: [PATCH 1/2] Address review feedback from jpienaar
This change addresses some of the additional review feedback on
https://github.com/llvm/llvm-project/pull/87234.
A summary of the changes:
1. Cleaned up the language to use 'roll back' rather than 'revert' to
reduce the chance of confusion. Improved some function names as well.
2. Eliminated string comparisons on dialect names.
3. Prevented the introduction of redundant tensor.cast operations for the
same value.
---
.../Tosa/Transforms/TosaInferShapes.cpp | 44 +++++++++----------
1 file changed, 22 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 8614559e2a6f13..d01891a04d2aac 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -39,13 +40,9 @@ namespace {
// type-inference related interface.
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
-bool isReplaceableUser(Operation *user) {
- // Handle unregistered dialects.
- if (!user->getDialect())
- return false;
-
- return user->getDialect()->getNamespace() ==
- TosaDialect::getDialectNamespace() ||
+bool canBeRefined(Operation *user) {
+ Dialect *tosaDialect = user->getContext()->getLoadedDialect<TosaDialect>();
+ return user->getDialect() == tosaDialect ||
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}
@@ -53,16 +50,16 @@ bool isReplaceableUser(Operation *user) {
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is performed until a fixed point is reached, then the types are
-// reverted.
+// rolled back.
//
-// This class encapsulates the state information needed to perform the reversion
+// This class encapsulates the state information needed to perform the roll back
// process or to commit to the final changes.
class TypeModificationState {
public:
TypeModificationState() = default;
~TypeModificationState() {
- // Ensure the recorded modifications are either committed or reverted.
+ // Ensure the recorded modifications are either committed or rolled back.
assert(oldTypes.empty() && "unhandled type modifications");
}
@@ -74,10 +71,9 @@ class TypeModificationState {
}
}
- // Revert changes made to the types in the IR by setting all the affected
+ // Roll back changes made to the types in the IR by setting all the affected
// values to their old types.
- void revert() {
- // Otherwise revert the changes.
+ void rollBack() {
for (auto [value, type] : oldTypes)
value.setType(type);
@@ -91,15 +87,18 @@ class TypeModificationState {
// For each use whose type changed, cast the value with the new type back to
// the old type.
for (auto [value, oldType] : oldTypes) {
+ tensor::CastOp castedValue;
for (auto &use : value.getUses()) {
- if (isReplaceableUser(use.getOwner()))
+ if (canBeRefined(use.getOwner()))
continue;
- OpBuilder builder(value.getContext());
- builder.setInsertionPoint(use.getOwner());
+ // Cache the cast to avoid generating duplicates
+ if (!castedValue) {
+ ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
+ castedValue = builder.create<tensor::CastOp>(oldType, value);
+ }
- Location loc = value.getLoc();
- use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+ use.set(castedValue);
}
}
@@ -211,8 +210,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
argTypes[i] = newType;
}
- // Revert all changes made during the speculative part of the algorithm.
- localState.revert();
+ // Roll back all changes made during the speculative part of the algorithm.
+ localState.rollBack();
}
// We now set the block arguments according to the most recent shape
@@ -228,10 +227,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
}
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+ Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+
for (auto &block : region) {
for (Operation &op : block) {
- if (!op.getDialect() ||
- op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+ if (op.getDialect() != tosaDialect)
continue;
propagateShapesToTosaIf(op, state);
>From 8ae0ffbf49e93f5e9d315eae2919260c2f62cf47 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Tue, 9 Apr 2024 11:35:02 -0400
Subject: [PATCH 2/2] Fix dialect check to avoid the cost of dialect lookup
---
mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index d01891a04d2aac..b1d5720541846f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -41,8 +41,9 @@ namespace {
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
bool canBeRefined(Operation *user) {
- Dialect *tosaDialect = user->getContext()->getLoadedDialect<TosaDialect>();
- return user->getDialect() == tosaDialect ||
+ if (!user->getDialect())
+ return false;
+ return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}
More information about the Mlir-commits
mailing list