[Mlir-commits] [mlir] [mlir][TOSA] Fix shape inference when operand was inferred (PR #66906)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 20 15:35:45 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
Commit https://github.com/llvm/llvm-project/commit/057fc8e7d8a3593f98930b8b91f80b9dd9b5fd4a introduced a bug in `TosaInferShapesPass` that occurs when an operand's type has already been inferred. The code at https://github.com/llvm/llvm-project/blob/f7bfa583b7a5ff0e9954d2810006b7a71123be88/mlir/include/mlir/Interfaces/InferTypeOpInterface.td#L248 interprets the `ValueShapeRange` as a plain `ValueRange` and therefore loses the inferred shape information.
This PR slightly changes the shape-inference logic. Previously, inferred shapes were recorded in a `DenseMap` and applied only after the whole region had been analyzed. Now each result type is updated immediately, in the same iteration in which it is inferred. As a result, later operations always see the inferred types of their operands.
---
Full diff: https://github.com/llvm/llvm-project/pull/66906.diff
3 Files Affected:
- (modified) mlir/include/mlir/Interfaces/InferTypeOpInterface.td (+1-1)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+15-41)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+13)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 54c1c13fd029dbc..c5eeeaf58a7b4f8 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -223,7 +223,7 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
>;
// Convenient trait to define a wrapper to inferReturnTypeComponents that passes
-// in the Op Adaptor directly
+// in the Op Adaptor directly. Only uses the current types of the operands.
class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
[
// Op implements infer type op interface.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 9c49cd55788571b..3cc16a91edce747 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -41,8 +41,7 @@ namespace {
void propagateShapesInRegion(Region &region);
-void propagateShapesToTosaIf(
- Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaIf(Operation &op) {
IfOp ifOp = dyn_cast<IfOp>(op);
if (!ifOp)
return;
@@ -53,12 +52,12 @@ void propagateShapesToTosaIf(
return;
for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
- auto inferredTy = shapesStorage[op.getOperand(i)];
+ auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
auto blockArg = frontBlock.getArgument(i - 1);
auto oldType = cast<ShapedType>(blockArg.getType());
if (inferredTy.hasRank()) {
- Type newType = oldType.clone(inferredTy.getDims());
+ Type newType = oldType.clone(inferredTy.getShape());
blockArg.setType(newType);
}
}
@@ -79,8 +78,7 @@ void propagateShapesToTosaIf(
}
}
-void propagateShapesToTosaWhile(
- Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaWhile(Operation &op) {
WhileOp whileOp = dyn_cast<WhileOp>(op);
if (!whileOp)
return;
@@ -91,9 +89,8 @@ void propagateShapesToTosaWhile(
llvm::SmallVector<Type> argTypes;
for (auto operand : op.getOperands()) {
auto operandTy = cast<ShapedType>(operand.getType());
- auto shapedTypeComponent = shapesStorage[operand];
- if (shapedTypeComponent.hasRank()) {
- auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+ if (operandTy.hasRank()) {
+ auto newTy = operandTy.clone(operandTy.getShape());
argTypes.push_back(newTy);
} else {
argTypes.push_back(operand.getType());
@@ -187,21 +184,6 @@ void propagateShapesToTosaWhile(
}
void propagateShapesInRegion(Region &region) {
- DenseMap<Value, ShapedTypeComponents> shapesStorage;
- auto setShapes = [&](Value val, Type t) {
- if (auto st = dyn_cast<ShapedType>(t))
- shapesStorage[val] = st;
- else
- shapesStorage[val] = t;
- };
- auto operandShape = [&](Value val) -> ShapeAdaptor {
- // Query the WIP mapping rather than the type if set.
- auto it = shapesStorage.find(val);
- if (it == shapesStorage.end())
- return nullptr;
- return it->second;
- };
-
// Check whether this use case is replaceable. We define an op as
// being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
// type-inference related interface.
@@ -217,8 +199,8 @@ void propagateShapesInRegion(Region &region) {
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
continue;
- propagateShapesToTosaIf(op, shapesStorage);
- propagateShapesToTosaWhile(op, shapesStorage);
+ propagateShapesToTosaIf(op);
+ propagateShapesToTosaWhile(op);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
@@ -227,12 +209,11 @@ void propagateShapesInRegion(Region ®ion) {
SmallVector<ShapedTypeComponents> returnedShapes;
- ValueShapeRange range(op.getOperands(), operandShape);
if (shapeInterface
- .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
- op.getDiscardableAttrDictionary(),
- op.getPropertiesStorage(),
- op.getRegions(), returnedShapes)
+ .inferReturnTypeComponents(
+ op.getContext(), op.getLoc(), op.getOperands(),
+ op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
+ op.getRegions(), returnedShapes)
.succeeded()) {
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);
@@ -262,20 +243,13 @@ void propagateShapesInRegion(Region ®ion) {
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
if (!newKnowledge)
continue;
- setShapes(result, newKnowledge.getType());
+
+ // Set new type
+ result.setType(newKnowledge.getType());
}
}
}
}
-
- // Actually update types with updated shape knowledge.
- for (auto it : shapesStorage) {
- auto result = it.second;
- if (result.hasRank()) {
- Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
- it.first.setType(t);
- }
- }
}
/// Pass that performs shape propagation across TOSA operations. This includes
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index cb96b2a8a0d193b..d468ba582483cbe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1259,3 +1259,16 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
%1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
return %1 : f32
}
+
+// -----
+
+// CHECK-LABEL: test_tosa_use_def_chain
+func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
+ // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
+ // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
+ // CHECK: tosa.max_pool2d [[CONV]]
+ // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
+ %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
+ return %1 : tensor<?x16x16x16xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/66906
More information about the Mlir-commits
mailing list