[Mlir-commits] [mlir] 9739ef6 - [MLIR][Tosa] Turn reshape(const()) from canonicalization into fold; fix dynamic shape case
Matthias Gehre
llvmlistbot at llvm.org
Thu Jul 6 11:07:30 PDT 2023
Author: Liam Fitzpatrick
Date: 2023-07-06T20:07:24+02:00
New Revision: 9739ef67a8f79ed88d4e28710c6d5c67d5566425
URL: https://github.com/llvm/llvm-project/commit/9739ef67a8f79ed88d4e28710c6d5c67d5566425
DIFF: https://github.com/llvm/llvm-project/commit/9739ef67a8f79ed88d4e28710c6d5c67d5566425.diff
LOG: [MLIR][Tosa] Turn reshape(const()) from canonicalization into fold; fix dynamic shape case
1) Turns the canonicalization into a fold, so it can clean up IR within other passes.
2) When the output of the reshape is a dynamically shaped tensor,
we cannot apply the fold to the constant, because constants are required to have
a static shape.
Differential Revision: https://reviews.llvm.org/D154615
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5af10ec80a5c90..8cefa64bc4c1fe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -82,40 +82,9 @@ struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
}
};
-struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
- using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::ReshapeOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.getInput1();
- ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
- ShapedType resultTy = llvm::cast<ShapedType>(op.getType());
-
- if (inputTy.getElementType() != resultTy.getElementType())
- return rewriter.notifyMatchFailure(op, "element type does not match.");
-
- // Check if input is constant
- DenseElementsAttr inputAttr;
- if (!matchPattern(input, m_Constant(&inputAttr)))
- return rewriter.notifyMatchFailure(op, "Non-constant input.");
-
- // Check if has >1 consumer and is not splat
- if (!input.hasOneUse() && !inputAttr.isSplat())
- return rewriter.notifyMatchFailure(op,
- "Used more than once or not-splat");
-
- // Build new const op with correct output shape
- DenseElementsAttr outputAttr = inputAttr.reshape(
- llvm::cast<ShapedType>(inputAttr.getType()).clone(op.getNewShape()));
- rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
- return success();
- }
-};
-
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ReshapeReshapeOptimization>(context);
- results.add<ReshapeConstOptimization>(context);
}
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
@@ -851,12 +820,25 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (inputTy == outputTy)
return getInput1();
+ // Constants must have static shape.
+ if (!outputTy.hasStaticShape())
+ return {};
+
auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
+ if (!operand)
+ return {};
+
+ // Okay to duplicate splat constants.
+ if (operand.isSplat()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
- return {};
+ // Don't duplicate other constants.
+ if (!getInput1().hasOneUse())
+ return {};
+
+ return operand.reshape(
+ llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
}
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 3f6d0e0b014112..1f9ce1f4c7b929 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -375,16 +375,24 @@ func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf3
}
// CHECK-LABEL: @reshape_canonicalize_const
-func.func @reshape_canonicalize_const() -> tensor<1x10xi32> {
- // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>}
+func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
+ // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1, 2, 3, 4]]> : tensor<1x5xi32>}
// CHECK: return %[[VAR0]]
- %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
- %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 10>} : (tensor<10xi32>) -> tensor<1x10xi32>
- return %1 : tensor<1x10xi32>
+ %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+ %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x5xi32>
+ return %1 : tensor<1x5xi32>
+}
+
+// CHECK-LABEL: @reshape_canonicalize_const_dynamic
+func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
+ // CHECK: tosa.reshape
+ %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+ %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x?xi32>
+ return %1 : tensor<1x?xi32>
}
-// CHECK-LABEL: @reshape_canonicalize_const_spat
-func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) {
+// CHECK-LABEL: @reshape_canonicalize_const_splat
+func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi32>) {
// CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<10xi32>}
// CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>}
// CHECK: return %[[VAR0]], %[[VAR1]]
More information about the Mlir-commits
mailing list