[Mlir-commits] [mlir] 3bcaf2e - [mlir][tosa] Moves constant folding operations out of the Canonicalizer

Robert Suderman llvmlistbot at llvm.org
Mon Jun 6 15:42:57 PDT 2022


Author: Georgios Pinitas
Date: 2022-06-06T22:10:22Z
New Revision: 3bcaf2eb9337f1832fa45a095aa2e8862dcb84cd

URL: https://github.com/llvm/llvm-project/commit/3bcaf2eb9337f1832fa45a095aa2e8862dcb84cd
DIFF: https://github.com/llvm/llvm-project/commit/3bcaf2eb9337f1832fa45a095aa2e8862dcb84cd.diff

LOG: [mlir][tosa] Moves constant folding operations out of the Canonicalizer

Transpose operations on constant data were being folded during
canonicalization, which has a compile-time cost proportional to the
size of the constant. Move this folding into a separate pass so that it
is optional and users have flexibility in how such cases are handled.

Reviewed By: rsuderman, jpienaar, stellaraccident

Differential Revision: https://reviews.llvm.org/D124685

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a5b42990625af..711c101b15a52 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,6 +34,17 @@ namespace tosa {
 } // namespace tosa
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace tosa {
+/// Appends the canonicalization patterns of every TOSA op to `patterns`.
+void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
+                                             RewritePatternSet &patterns);
+} // namespace tosa
+} // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
 

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 1bdfc2f43bf3b..9ffccfc948824 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -26,7 +26,10 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                     RewritePatternSet &patterns);
+void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
+                                               RewritePatternSet &patterns);
 
+std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index c3180ec14a325..46bd7a4780e00 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,6 +15,15 @@
 
 include "mlir/Pass/PassBase.td"
 
+def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
+  let summary = "Fold layerwise operations on constant tensors";
+  let description = [{
+    Pass that enables folding of full-layer operations on constant tensors.
+  }];
+
+  let constructor = "createTosaLayerwiseConstantFoldPass()";
+}
+
 def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
   let summary = "Propagate shapes across TOSA operations";
   let description = [{

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a8c610c05a7bc..18f7efe36f503 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,6 +76,8 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  // TODO: Remove pass that operates on const tensor and enable optionality
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1cf2a8808a07f..4de0c0f1a9ed8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -94,6 +94,20 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+template <typename... Args>
+static void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &set) {
+  (void)std::initializer_list<int>{
+      0, (Args::getCanonicalizationPatterns(set, ctx), 0)...};
+}
+
+void mlir::tosa::populateTosaOpsCanonicalizationPatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  addOpsCanonicalizations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
+      >(ctx, patterns);
+}
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -189,70 +203,6 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   return success();
 }
 
-struct ConstantTransposeOptimization
-    : public OpRewritePattern<tosa::TransposeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    auto outputType = op.getType().cast<ShapedType>();
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    // TOSA supports quantized types.
-    if (!outputType.getElementType().isIntOrIndexOrFloat())
-      return failure();
-
-    DenseElementsAttr inputValues;
-    if (!matchPattern(op.input1(), m_Constant(&inputValues)))
-      return failure();
-    // Make sure the input is a constant that has a single user.
-    if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
-      return failure();
-
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.perms(), m_Constant(&permAttr)))
-      return failure();
-    auto permValues = llvm::to_vector<6>(llvm::map_range(
-        // TOSA allows both 32- and 64-bit integer tensors here.
-        permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getZExtValue(); }));
-
-    auto inputType = op.input1().getType().cast<ShapedType>();
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    int64_t numElements = inputType.getNumElements();
-
-    SmallVector<Attribute, 4> outputValues;
-    outputValues.resize(numElements);
-
-    // Transpose the input constant. Because we don't know its rank in advance,
-    // we need to loop over the range [0, element count) and delinearize the
-    // index.
-    auto attrValues = inputValues.getValues<Attribute>();
-    for (int srcLinearIndex = 0; srcLinearIndex < numElements;
-         ++srcLinearIndex) {
-      SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
-      int totalCount = srcLinearIndex;
-      for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
-        srcIndices[dim] = totalCount % inputShape[dim];
-        totalCount /= inputShape[dim];
-      }
-
-      SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
-      for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
-        dstIndices[dim] = srcIndices[permValues[dim]];
-
-      uint64_t dstLinearIndex = dstIndices.front();
-      for (int dim = 1; dim < outputType.getRank(); ++dim)
-        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
-
-      outputValues[dstLinearIndex] = attrValues[srcIndices];
-    }
-
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
-        op, outputType, DenseElementsAttr::get(outputType, outputValues));
-    return success();
-  }
-};
-
 struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -282,7 +232,6 @@ struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
 
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ConstantTransposeOptimization>(context);
   results.add<NoOpOptimization>(context);
 }
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index e98d3dfe26a70..79979eee9077d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,7 +2,9 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
+  TosaFoldConstantTranspose.cpp
   TosaInferShapes.cpp
+  TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp
   TosaOptionalDecompositions.cpp
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index ac8583f7c03e2..ef94e55c855d3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -1,4 +1,4 @@
-//===- TosaDecomposeConv2D.cpp ------------------------------------------===//
+//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 2ce9f24e6d9c9..b4bac42029e49 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -1,5 +1,4 @@
-//===- TosaDecomposeDepthwise.cpp
-//------------------------------------------===//
+//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index d6ffa463f31bd..fa6ec91bb2416 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -1,5 +1,4 @@
-//===- TosaDecomposeTransposeConv.cpp
-//------------------------------------------===//
+//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
new file mode 100644
index 0000000000000..5f14cf68321af
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -0,0 +1,137 @@
+//===- TosaFoldConstantTranspose.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold TOSA Transpose operation on constant data
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Folds a `tosa.transpose` whose input tensor and permutation are both
+/// constants into a single `tosa.const` holding the transposed data.
+///
+/// The pattern leaves the IR unchanged when:
+///  - the result element type is not an integer, index or float type (for
+///    example one of the quantized types TOSA allows),
+///  - the result shape is not fully static,
+///  - the input constant has other users (folding would then duplicate the
+///    constant data instead of replacing it),
+///  - the input and result element types differ, or
+///  - `perms` is not a permutation consistent with the input/result shapes.
+///
+/// Splat inputs are folded without visiting individual elements; otherwise
+/// every input element is copied, so the cost grows with the constant size.
+struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto outputType = op.getType().cast<ShapedType>();
+    // TOSA supports quantized types; only integer/index/float data is folded.
+    // The new constant also needs a fully static shape.
+    if (!outputType.getElementType().isIntOrIndexOrFloat() ||
+        !outputType.hasStaticShape())
+      return failure();
+
+    DenseElementsAttr inputValues;
+    if (!matchPattern(op.input1(), m_Constant(&inputValues)))
+      return failure();
+    // Make sure the input is a constant that has a single user.
+    if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
+      return failure();
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.perms(), m_Constant(&permAttr)))
+      return failure();
+    auto permValues = llvm::to_vector<6>(llvm::map_range(
+        // TOSA allows both 32- and 64-bit integer tensors here.
+        permAttr.getValues<APInt>(),
+        [](const APInt &val) { return val.getZExtValue(); }));
+
+    auto inputType = op.input1().getType().cast<ShapedType>();
+    if (!inputType.hasStaticShape() ||
+        inputType.getElementType() != outputType.getElementType())
+      return failure();
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    int64_t rank = outputType.getRank();
+    int64_t numElements = inputType.getNumElements();
+
+    // Reject perms that are not a permutation of [0, rank), or that disagree
+    // with the declared result shape; the copy loop below relies on both.
+    if (inputType.getRank() != rank ||
+        static_cast<int64_t>(permValues.size()) != rank)
+      return failure();
+    SmallVector<bool, 6> seen(rank, false);
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      uint64_t perm = permValues[dim];
+      if (perm >= static_cast<uint64_t>(rank) || seen[perm] ||
+          outputShape[dim] != inputShape[perm])
+        return failure();
+      seen[perm] = true;
+    }
+
+    // Every element of a splat is identical, so any permutation of it is the
+    // same splat: reuse the data with the result type instead of copying it.
+    if (inputValues.isSplat()) {
+      rewriter.replaceOpWithNewOp<tosa::ConstOp>(
+          op, outputType, inputValues.reshape(outputType));
+      return success();
+    }
+
+    SmallVector<Attribute, 4> outputValues;
+    outputValues.resize(numElements);
+
+    // Transpose the input constant. Because we don't know its rank in advance,
+    // we need to loop over the range [0, element count) and delinearize the
+    // index. 64-bit counters keep this correct for very large constants.
+    auto attrValues = inputValues.getValues<Attribute>();
+    for (int64_t srcLinearIndex = 0; srcLinearIndex < numElements;
+         ++srcLinearIndex) {
+      SmallVector<uint64_t, 6> srcIndices(rank, 0);
+      int64_t remaining = srcLinearIndex;
+      for (int64_t dim = rank - 1; dim >= 0; --dim) {
+        srcIndices[dim] = remaining % inputShape[dim];
+        remaining /= inputShape[dim];
+      }
+
+      // Output dimension `dim` walks input dimension `permValues[dim]`.
+      SmallVector<uint64_t, 6> dstIndices(rank, 0);
+      for (int64_t dim = 0; dim < rank; ++dim)
+        dstIndices[dim] = srcIndices[permValues[dim]];
+
+      // Linearize from 0 rather than dstIndices.front() so that rank-0
+      // tensors, whose index vectors are empty, are handled too.
+      uint64_t dstLinearIndex = 0;
+      for (int64_t dim = 0; dim < rank; ++dim)
+        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
+
+      outputValues[dstLinearIndex] = attrValues[srcIndices];
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
+        op, outputType, DenseElementsAttr::get(outputType, outputValues));
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaFoldConstantTransposePatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.add<TosaFoldConstantTranspose>(ctx);
+}

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index fc55e44a7d373..e75399b7bdd24 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -1,4 +1,4 @@
-//===- TosaInferShapes.cpp ------------------------------------------===//
+//===- TosaInferShapes.cpp ------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
new file mode 100644
index 0000000000000..7cf7ff14eb9ac
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -0,0 +1,43 @@
+//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements constant folding transformations on TOSA operations
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+/// Implementation of the `tosa-layerwise-constant-fold` pass.
+struct TosaLayerwiseConstantFoldPass
+    : public TosaLayerwiseConstantFoldPassBase<TosaLayerwiseConstantFoldPass> {
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    auto func = getOperation();
+    // Constant-transpose folding plus every TOSA op canonicalization pattern.
+    mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
+    mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns);
+    // A failure reported by the greedy rewrite driver fails the pass.
+    if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
+  return std::make_unique<TosaLayerwiseConstantFoldPass>();
+}

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
index 0bf9eed621107..78b8cb3084afd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
@@ -1,5 +1,4 @@
-//===- TosaOptionalDecompositions.cpp
-//------------------------------------------===//
+//===- TosaOptionalDecompositions.cpp -------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 934ca583f330d..62f1adb1e77ac 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -391,104 +391,6 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   return %0 : tensor<3x8xf32>
 }
 
-// CHECK-LABEL: @transpose_fold
-func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
-  // CHECK: return %arg0
-  %0 = arith.constant dense<[0, 1]> : tensor<2xi32>
-  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
-  return %1 : tensor<3x4xf32>
-}
-
-// CHECK-LABEL: @transpose_nofold
-func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
-  // CHECK: "tosa.transpose"
-  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
-  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
-  return %1 : tensor<3x3xf32>
-}
-
-// CHECK-LABEL: @transpose_nofold_shape
-func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
-  // CHECK: "tosa.transpose"
-  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
-  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: @transpose_fold_splat
-func.func @transpose_fold_splat() -> tensor<3x2xf32> {
-  %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
-  //               CHECK: %[[CST:.+]] = "tosa.const"()
-  // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
-  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
-  // CHECK: return %[[CST]]
-  return %1 : tensor<3x2xf32>
-}
-
-// CHECK-LABEL: @transpose_fold_2d_float
-func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
-  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
-  //               CHECK: %[[CST:.+]] = "tosa.const"()
-  // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
-  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
-  // CHECK: return %[[CST]]
-  return %1 : tensor<3x2xf32>
-}
-
-// CHECK-LABEL: @transpose_fold_4d_int
-func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
-  %input = "tosa.const"() {value = dense<[[
-    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
-    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
-  ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
-  %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
-  //               CHECK: %[[CST:.+]] = "tosa.const"()
-  // CHECK-SAME{LITERAL}: value = dense<[
-  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
-  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
-  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
-  // CHECK-SAME{LITERAL}: ]>
-  %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
-  // CHECK: return %[[CST]]
-  return %1 : tensor<3x1x4x2xi32>
-}
-
-// CHECK-LABEL: @transpose_nofold_non_cst_input
-func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
-  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK: tosa.transpose
-  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
-  return %1 : tensor<3x2xf32>
-}
-
-// CHECK-LABEL: @transpose_nofold_non_cst_perms
-func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
-  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  // CHECK: tosa.transpose
-  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
-  return %1 : tensor<3x2xf32>
-}
-
-// CHECK-LABEL: @transpose_nofold_multi_users
-func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
-  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK: tosa.transpose
-  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
-  return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
-}
-
-// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
-  %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
-  // CHECK: tosa.transpose
-  %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
-  return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
-}
-
 // CHECK-LABEL: @transpose_no_op
 func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
   // CHECK: return %arg0

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
new file mode 100644
index 0000000000000..09f8245e771a7
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// CHECK-LABEL: @transpose_fold
+func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK: return %arg0
+  %0 = arith.constant dense<[0, 1]> : tensor<2xi32>
+  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
+
+// CHECK-LABEL: @transpose_nofold
+func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
+  // CHECK: "tosa.transpose"
+  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
+  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+  return %1 : tensor<3x3xf32>
+}
+
+// CHECK-LABEL: @transpose_nofold_shape
+func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
+  // CHECK: "tosa.transpose"
+  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
+  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @transpose_fold_splat
+func.func @transpose_fold_splat() -> tensor<3x2xf32> {
+  %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  //               CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: @transpose_fold_2d_float
+func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
+  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  //               CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: @transpose_fold_4d_int
+func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
+  %input = "tosa.const"() {value = dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
+  %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+  //               CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<[
+  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi32>
+}
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: tosa.transpose
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: @transpose_nofold_non_cst_perms
+func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
+  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  // CHECK: tosa.transpose
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: @transpose_nofold_multi_users
+func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
+  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: tosa.transpose
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
+}
+
+// CHECK-LABEL: @transpose_nofold_quantized_types
+func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
+  %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
+  // CHECK: tosa.transpose
+  %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
+  return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
+}


        


More information about the Mlir-commits mailing list