[Mlir-commits] [mlir] 06c440f - [mlir][tosa] Canonicalize tosa.transpose to tosa.reshape
Rob Suderman
llvmlistbot at llvm.org
Tue Jan 3 11:20:52 PST 2023
Author: Rob Suderman
Date: 2023-01-03T11:19:55-08:00
New Revision: 06c440f2dac2598308d777e704cc4866471af561
URL: https://github.com/llvm/llvm-project/commit/06c440f2dac2598308d777e704cc4866471af561
DIFF: https://github.com/llvm/llvm-project/commit/06c440f2dac2598308d777e704cc4866471af561.diff
LOG: [mlir][tosa] Canonicalize tosa.transpose to tosa.reshape
Added a tosa.transpose canonicalization for the case where a tosa.transpose is
equivalent to a tosa.reshape. This occurs when the permutation does not
reorder the non-unit (size != 1) dimensions.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D140356
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Tosa/constant-op-fold.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 215a4cc1df50b..1f91d43ba2e9d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,16 +88,22 @@ struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
LogicalResult matchAndRewrite(tosa::ReshapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput1();
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType resultTy = op.getType().cast<ShapedType>();
ArrayAttr newShape = op.getNewShape();
+ if (inputTy.getElementType() != resultTy.getElementType())
+ return rewriter.notifyMatchFailure(op, "element type does not match.");
+
// Check if input is constant
DenseElementsAttr inputAttr;
if (!matchPattern(input, m_Constant(&inputAttr)))
- return failure();
+ return rewriter.notifyMatchFailure(op, "Non-constant input.");
// Check if has >1 consumer and is not splat
if (!input.hasOneUse() && !inputAttr.isSplat())
- return failure();
+ return rewriter.notifyMatchFailure(op,
+ "Used more than once or not-splat");
// Grab the new shape
SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
@@ -132,7 +138,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
return success();
}
-struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
+struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
@@ -159,9 +165,60 @@ struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
}
};
+// Rewrites a tosa.transpose as a tosa.reshape when the permutation leaves the
+// relative order of all non-unit (size != 1) dimensions unchanged, i.e. it
+// only moves unit dimensions around.
+//
+// The pattern reports a match failure when the permutation is not a constant,
+// the input is unranked, the input has more than one dynamic dimension, the
+// permutation is malformed for the input rank, or the non-unit dimensions are
+// reordered.
+struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
+      return rewriter.notifyMatchFailure(op, "Non-constant permutation");
+
+    auto input = op.getInput1();
+    auto inputTy = input.getType().cast<ShapedType>();
+    if (!inputTy.hasRank())
+      return rewriter.notifyMatchFailure(op, "Unranked input.");
+
+    const int64_t rank = inputTy.getRank();
+
+    // NOTE(review): presumably limited to one dynamic dim because the
+    // replacement reshape can leave at most one extent unknown - confirm.
+    int64_t numDynDims = 0;
+    for (int64_t i = 0; i < rank; ++i)
+      if (inputTy.isDynamicDim(i))
+        numDynDims++;
+
+    if (numDynDims > 1)
+      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
+
+    SmallVector<int64_t> permValues = llvm::to_vector<6>(
+        llvm::map_range(permAttr.getValues<APInt>(),
+                        [](const APInt &val) { return val.getSExtValue(); }));
+
+    // Every permutation entry is used below as a dimension index of the input
+    // and the permutation itself is indexed by rank, so reject anything that
+    // would read out of bounds.
+    if (static_cast<int64_t>(permValues.size()) != rank)
+      return rewriter.notifyMatchFailure(
+          op, "Permutation size does not match input rank.");
+    for (int64_t idx : permValues)
+      if (idx < 0 || idx >= rank)
+        return rewriter.notifyMatchFailure(op,
+                                           "Permutation index out of range.");
+
+    // Collect, in output order, the source indices of the non-unit dims.
+    SmallVector<int64_t> nonUnitPerms;
+    nonUnitPerms.reserve(permValues.size());
+    for (int64_t idx : permValues) {
+      if (inputTy.getDimSize(idx) != 1)
+        nonUnitPerms.push_back(idx);
+    }
+
+    // If any non-unit dim moves ahead of another, this is a real transpose.
+    for (int i = 1, s = nonUnitPerms.size(); i < s; ++i)
+      if (nonUnitPerms[i - 1] > nonUnitPerms[i])
+        return rewriter.notifyMatchFailure(op,
+                                           "Transpose changes memory layout.");
+
+    // The reshape target is the input shape read through the permutation.
+    SmallVector<int64_t> newShape;
+    newShape.reserve(rank);
+    for (int64_t i = 0; i < rank; ++i)
+      newShape.push_back(inputTy.getDimSize(permValues[i]));
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getType(), input, rewriter.getI64ArrayAttr(newShape));
+    return success();
+  }
+};
+
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<NoOpOptimization>(context);
+ results.add<TransposeNoOp, TransposeIsReshape>(context);
}
struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
@@ -958,6 +1015,11 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
if (!operands[1])
return {};
+ auto inputTy = getInput1().getType().cast<ShapedType>();
+ auto resultTy = getType().cast<ShapedType>();
+ if (inputTy.getElementType() != resultTy.getElementType())
+ return {};
+
// Transposing splat values just means reshaping.
if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
if (input.isSplat())
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index ca25100bb5fcd..7eea2323c02bb 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -400,6 +400,14 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
return %1 : tensor<3x4x5x6xf32>
}
+// CHECK-LABEL: @transpose_is_reshape
+func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
+ // CHECK: "tosa.reshape"(%arg0) {new_shape = [1, 4, 1, 5]} : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+ %perms = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = "tosa.transpose"(%arg0, %perms) : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32>
+ return %0 : tensor<1x4x1x5xf32>
+}
+
// CHECK-LABEL: @single_bit_reshape
// https://github.com/llvm/llvm-project/issues/55440
func.func @single_bit_reshape() -> tensor<1xi1> {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 08115787db58a..1ca93fe07cc77 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -90,12 +90,12 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
}
// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
+func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
%perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
- %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
+ %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
// CHECK: tosa.transpose
- %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
- return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
+ %0 = "tosa.transpose"(%input, %perms) : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+ return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
}
// -----
More information about the Mlir-commits
mailing list