[Mlir-commits] [mlir] 4ada6c2 - [mlir][tosa] Adds a canonicalization to the transpose op if the perms are a no op

Rob Suderman llvmlistbot at llvm.org
Mon Oct 18 16:33:48 PDT 2021


Author: not-jenni
Date: 2021-10-18T16:30:53-07:00
New Revision: 4ada6c2aafffd90c87900cab0adbb4d43c874b9b

URL: https://github.com/llvm/llvm-project/commit/4ada6c2aafffd90c87900cab0adbb4d43c874b9b
DIFF: https://github.com/llvm/llvm-project/commit/4ada6c2aafffd90c87900cab0adbb4d43c874b9b.diff

LOG: [mlir][tosa] Adds a canonicalization to the transpose op if the perms are a no op

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D112037

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9c8d4ac54a848..2ad14f5bf1809 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -222,9 +222,37 @@ struct ConstantTransposeOptimization
   }
 };
 
+/// Folds away a tosa.transpose whose constant `perms` is the identity
+/// permutation, but only when the result type equals the input type.
+struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    // Forwarding the input is only valid IR if the types agree exactly
+    // (e.g. not when the result is tensor<?x?xf32> for a static input).
+    if (op.input1().getType() != op.getType())
+      return failure();
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.perms(), m_Constant(&permAttr)))
+      return failure();
+
+    // The transpose is a no-op iff perms[i] == i for every dimension.
+    int64_t dim = 0;
+    for (const APInt &val : permAttr.getValues<APInt>())
+      if (val.getSExtValue() != dim++)
+        return failure();
+
+    rewriter.replaceOp(op, op.input1());
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
   results.insert<ConstantTransposeOptimization>(context);
+  results.insert<NoOpOptimization>(context); // Removes identity transposes.
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5fe5bd44e1495..983ce58c0c652 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -233,7 +233,7 @@ func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
 // CHECK-LABEL: @transpose_nofold_shape
 func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
   // CHECK: "tosa.transpose"
-  %0 = arith.constant dense<[0, 1]> : tensor<2xi32>
+  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
   %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -325,3 +325,14 @@ func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-1
   %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
   return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
 }
+
+// -----
+
+// CHECK-LABEL: @transpose_no_op
+func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
+  // CHECK-NOT: tosa.transpose
+  // CHECK: return %arg0
+  %perms = "tosa.const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tosa.transpose"(%arg0, %perms) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x4x5x6xf32>
+  return %1 : tensor<3x4x5x6xf32>
+}


        


More information about the Mlir-commits mailing list