[Mlir-commits] [mlir] TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops. (PR #70482)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 27 11:56:50 PDT 2023
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/70482
>From bc3a83fbc15eae0b80314cf5c63abf62e69bda2c Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 27 Oct 2023 11:26:54 -0400
Subject: [PATCH 1/2] TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops
---
mlir/include/mlir/Conversion/Passes.td | 6 +++
.../Conversion/TosaToLinalg/TosaToLinalg.h | 11 ++++-
.../TosaToLinalg/TosaToLinalgNamed.cpp | 44 ++++++++++++++++++-
.../TosaToLinalg/TosaToLinalgNamedPass.cpp | 14 ++++--
.../TosaToLinalg/TosaToLinalgPass.cpp | 6 ++-
.../TosaToLinalg/tosa-to-linalg-named.mlir | 7 +++
6 files changed, 80 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f05e5a8ae667dab..336f0d3af951b9a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
Linalg named operations.
}];
+ let options = [
+ Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
+ "bool", /*default=*/"false",
+ "Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
+ ];
+
let constructor = "tosa::createTosaToLinalgNamed()";
}
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index b4c4eb8651a6f00..7497a716e048d95 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -26,7 +26,8 @@ namespace mlir {
namespace tosa {
std::unique_ptr<Pass> createTosaToLinalg();
-std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgNamed(
+ const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());
/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
/// the pass, the function will only contain linalg ops or standard ops if the
@@ -34,6 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
/// benchmarking performance improvements from the canonicalizations.
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
+ const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
+ TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
tosa::TosaValidationOptions const &validationOptions = {
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
@@ -45,8 +48,12 @@ void registerTosaToLinalgPipelines();
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
+enum class Conv2DKernelLayout { FHWC, HWCF };
+
/// Populates conversion passes from TOSA dialect to Linalg named operations.
-void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgNamedConversionPatterns(
+ RewritePatternSet *patterns,
+ Conv2DKernelLayout conv2DKernelLayout = Conv2DKernelLayout::FHWC);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ee8f52deadbd152..ae0b58acfd295b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <numeric>
+#include <type_traits>
using namespace mlir;
using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);
+ if (4 == inputTy.getRank()) {
+ // For 2D convolutions, we need to check if the target convolution op
+ // wants a HWCF kernel layout.
+ bool wantHwcf =
+ isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+ if (wantHwcf) {
+ // Transpose the kernel to match dimension ordering of the linalg
+ // convolution operation.
+ // TODO(suderman): See if this can be efficiently folded - check whether
+ // the input is used anywhere else, if not fold the constant.
+ SmallVector<int64_t> weightPerm;
+ for (int i = 1; i < resultTy.getRank(); i++)
+ weightPerm.push_back(i);
+ weightPerm.push_back(0);
+
+ SmallVector<int64_t> newWeightShape;
+ for (auto dim : weightPerm)
+ newWeightShape.push_back(weightShape[dim]);
+ auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ Value weightPermValue =
+ rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ Type newWeightTy =
+ RankedTensorType::get(newWeightShape, weightTy.getElementType());
+ weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+ weightPermValue);
+ }
+ }
+
// For Conv3D transpose the kernel to match dimension ordering of the linalg
// convolution operation. Conv2D has a 1-1 mapping in linalg so better to
// map directly and then transpose later if desired.
@@ -977,10 +1007,20 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns) {
+ RewritePatternSet *patterns, Conv2DKernelLayout conv2DKernelLayout) {
+ if (conv2DKernelLayout == Conv2DKernelLayout::FHWC) {
+ patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+ linalg::Conv2DNhwcFhwcQOp>>(
+ patterns->getContext());
+ } else if (conv2DKernelLayout == Conv2DKernelLayout::HWCF) {
+ patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
+ linalg::Conv2DNhwcHwcfQOp>>(
+ patterns->getContext());
+ } else {
+ assert(false);
+ }
patterns->add<
// clang-format off
- ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 4c941a109ed845e..e330c9cff141e40 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -37,6 +37,9 @@ namespace {
struct TosaToLinalgNamed
: public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
public:
+ TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
+ : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
+
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -61,13 +64,18 @@ struct TosaToLinalgNamed
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
FunctionOpInterface func = getOperation();
- mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
+ tosa::Conv2DKernelLayout conv2DKernelLayout =
+ preferConv2DKernelLayoutHWCF ? tosa::Conv2DKernelLayout::HWCF
+ : tosa::Conv2DKernelLayout::FHWC;
+ tosa::populateTosaToLinalgNamedConversionPatterns(&patterns,
+ conv2DKernelLayout);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
-std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
- return std::make_unique<TosaToLinalgNamed>();
+std::unique_ptr<Pass>
+mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
+ return std::make_unique<TosaToLinalgNamed>(options);
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a486e28c50c7129..687477810030d4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
+ const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
tosa::TosaValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
- pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
+ pm.addNestedPass<func::FuncOp>(
+ tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO: Remove pass that operates on const tensor and enable optionality
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
"named operations.",
[](OpPassManager &pm) {
TosaToLinalgOptions tosaToLinalgOptions;
+ TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+ tosaToLinalgNamedOptions,
/* validationOptions = */
{tosa::TosaProfileEnum::BaseInference,
/* StrictOperationSpecAlignment = */ true,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f280..1cf7c8dee606899 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
// CHECK: arith.extsi
// CHECK: arith.addi
@@ -383,11 +387,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: arith.addf
// CHECK: linalg.yield
>From f51c61e03dacb756682c9b49260e850ea9692642 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 27 Oct 2023 14:56:37 -0400
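Subject: [PATCH 2/2] Address review comment: pass TosaToLinalgNamedOptions to populateTosaToLinalgNamedConversionPatterns instead of a Conv2DKernelLayout enum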
---
.../mlir/Conversion/TosaToLinalg/TosaToLinalg.h | 5 +----
.../Conversion/TosaToLinalg/TosaToLinalgNamed.cpp | 12 +++++-------
.../TosaToLinalg/TosaToLinalgNamedPass.cpp | 8 +++-----
3 files changed, 9 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 7497a716e048d95..5fd77c8a0211a6d 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -48,12 +48,9 @@ void registerTosaToLinalgPipelines();
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
-enum class Conv2DKernelLayout { FHWC, HWCF };
-
/// Populates conversion passes from TOSA dialect to Linalg named operations.
void populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns,
- Conv2DKernelLayout conv2DKernelLayout = Conv2DKernelLayout::FHWC);
+ RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ae0b58acfd295b2..99a65f63038a43f 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1007,17 +1007,15 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns, Conv2DKernelLayout conv2DKernelLayout) {
- if (conv2DKernelLayout == Conv2DKernelLayout::FHWC) {
- patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
- linalg::Conv2DNhwcFhwcQOp>>(
- patterns->getContext());
- } else if (conv2DKernelLayout == Conv2DKernelLayout::HWCF) {
+ RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
+ if (options.preferConv2DKernelLayoutHWCF) {
patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
linalg::Conv2DNhwcHwcfQOp>>(
patterns->getContext());
} else {
- assert(false);
+ patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+ linalg::Conv2DNhwcFhwcQOp>>(
+ patterns->getContext());
}
patterns->add<
// clang-format off
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index e330c9cff141e40..5312dc164c26c5e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -64,11 +64,9 @@ struct TosaToLinalgNamed
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
FunctionOpInterface func = getOperation();
- tosa::Conv2DKernelLayout conv2DKernelLayout =
- preferConv2DKernelLayoutHWCF ? tosa::Conv2DKernelLayout::HWCF
- : tosa::Conv2DKernelLayout::FHWC;
- tosa::populateTosaToLinalgNamedConversionPatterns(&patterns,
- conv2DKernelLayout);
+ TosaToLinalgNamedOptions options;
+ options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
+ tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
More information about the Mlir-commits
mailing list