[Mlir-commits] [mlir] TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops. (PR #70482)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Fri Oct 27 10:09:10 PDT 2023
    
    
  
https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/70482
Switching to FHWC happened in https://github.com/llvm/llvm-project/pull/68304 and is fine in itself, but it caused a downstream performance regression (https://github.com/openxla/iree/pull/15296#issuecomment-1782994965), so this PR adds an option to prefer the HWCF kernel layout for Conv2D ops instead.
>From bc3a83fbc15eae0b80314cf5c63abf62e69bda2c Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 27 Oct 2023 11:26:54 -0400
Subject: [PATCH] HWCF
---
 mlir/include/mlir/Conversion/Passes.td        |  6 +++
 .../Conversion/TosaToLinalg/TosaToLinalg.h    | 11 ++++-
 .../TosaToLinalg/TosaToLinalgNamed.cpp        | 44 ++++++++++++++++++-
 .../TosaToLinalg/TosaToLinalgNamedPass.cpp    | 14 ++++--
 .../TosaToLinalg/TosaToLinalgPass.cpp         |  6 ++-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  7 +++
 6 files changed, 80 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f05e5a8ae667dab..336f0d3af951b9a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
     Linalg named operations.
   }];
 
+  let options = [
+      Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
+           "bool", /*default=*/"false",
+           "Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
+  ];
+
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
 
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index b4c4eb8651a6f00..7497a716e048d95 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -26,7 +26,8 @@ namespace mlir {
 namespace tosa {
 
 std::unique_ptr<Pass> createTosaToLinalg();
-std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgNamed(
+    const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());
 
 /// Populates passes to convert from TOSA to Linalg on buffers. At the end of
 /// the pass, the function will only contain linalg ops or standard ops if the
@@ -34,6 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 /// benchmarking performance improvements from the canonicalizations.
 void addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
+        TosaToLinalgNamedOptions(),
     // Note: Default to 'none' level unless otherwise specified.
     tosa::TosaValidationOptions const &validationOptions = {
         tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
@@ -45,8 +48,12 @@ void registerTosaToLinalgPipelines();
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 
+enum class Conv2DKernelLayout { FHWC, HWCF };
+
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
-void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgNamedConversionPatterns(
+    RewritePatternSet *patterns,
+    Conv2DKernelLayout conv2DKernelLayout = Conv2DKernelLayout::FHWC);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ee8f52deadbd152..ae0b58acfd295b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
+    if (4 == inputTy.getRank()) {
+      // For 2D convolutions, we need to check if the target convolution op
+      // wants a HWCF kernel layout.
+      bool wantHwcf =
+          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+      if (wantHwcf) {
+        // Transpose the kernel to match dimension ordering of the linalg
+        // convolution operation.
+        // TODO(suderman): See if this can be efficiently folded - check whether
+        // the input is used anywhere else, if not fold the constant.
+        SmallVector<int64_t> weightPerm;
+        for (int i = 1; i < resultTy.getRank(); i++)
+          weightPerm.push_back(i);
+        weightPerm.push_back(0);
+
+        SmallVector<int64_t> newWeightShape;
+        for (auto dim : weightPerm)
+          newWeightShape.push_back(weightShape[dim]);
+        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+        Value weightPermValue =
+            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+        Type newWeightTy =
+            RankedTensorType::get(newWeightShape, weightTy.getElementType());
+        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                    weightPermValue);
+      }
+    }
+
     // For Conv3D transpose the kernel to match dimension ordering of the linalg
     // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
     // map directly and then transpose later if desired.
@@ -977,10 +1007,20 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns) {
+    RewritePatternSet *patterns, Conv2DKernelLayout conv2DKernelLayout) {
+  if (conv2DKernelLayout == Conv2DKernelLayout::FHWC) {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+                                linalg::Conv2DNhwcFhwcQOp>>(
+        patterns->getContext());
+  } else if (conv2DKernelLayout == Conv2DKernelLayout::HWCF) {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
+                                linalg::Conv2DNhwcHwcfQOp>>(
+        patterns->getContext());
+  } else {
+    assert(false);
+  }
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 4c941a109ed845e..e330c9cff141e40 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -37,6 +37,9 @@ namespace {
 struct TosaToLinalgNamed
     : public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
 public:
+  TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
+      : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -61,13 +64,18 @@ struct TosaToLinalgNamed
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
     FunctionOpInterface func = getOperation();
-    mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
+    tosa::Conv2DKernelLayout conv2DKernelLayout =
+        preferConv2DKernelLayoutHWCF ? tosa::Conv2DKernelLayout::HWCF
+                                     : tosa::Conv2DKernelLayout::FHWC;
+    tosa::populateTosaToLinalgNamedConversionPatterns(&patterns,
+                                                      conv2DKernelLayout);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
-  return std::make_unique<TosaToLinalgNamed>();
+std::unique_ptr<Pass>
+mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
+  return std::make_unique<TosaToLinalgNamed>(options);
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a486e28c50c7129..687477810030d4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 
 void mlir::tosa::addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
     tosa::TosaValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!options.disableTosaDecompositions)
@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(
 
   pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
+  pm.addNestedPass<func::FuncOp>(
+      tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
       "named operations.",
       [](OpPassManager &pm) {
         TosaToLinalgOptions tosaToLinalgOptions;
+        TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+                                    tosaToLinalgNamedOptions,
                                     /* validationOptions = */
                                     {tosa::TosaProfileEnum::BaseInference,
                                      /* StrictOperationSpecAlignment = */ true,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f280..1cf7c8dee606899 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK:   arith.extsi
   // CHECK:   arith.addi
@@ -383,11 +387,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK:   arith.addf
   // CHECK:   linalg.yield
    
    
More information about the Mlir-commits
mailing list