[Mlir-commits] [mlir] 4142932 - [mlir][Linalg] Move named op conversions out of canonicalizations.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 20 10:19:21 PST 2021


Author: MaheshRavishankar
Date: 2021-12-20T10:19:05-08:00
New Revision: 4142932a834f0dca9e9ae0c3754f097ffa3fc1ef

URL: https://github.com/llvm/llvm-project/commit/4142932a834f0dca9e9ae0c3754f097ffa3fc1ef
DIFF: https://github.com/llvm/llvm-project/commit/4142932a834f0dca9e9ae0c3754f097ffa3fc1ef.diff

LOG: [mlir][Linalg] Move named op conversions out of canonicalizations.

These conversions are better suited to be applied at the whole-tensor
level. Applying them as canonicalizations ends up triggering them at
all levels of the stack, which might be undesirable. For example, some
of the resulting code patterns won't bufferize in-place and need
additional stack buffers. It is better to be more deliberate about
when these conversions apply.

Differential Revision: https://reviews.llvm.org/D115912

Added: 
    mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
    mlir/test/Dialect/Linalg/namedop_conversion.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index f21252800af89..8ebaaa8f8e4dd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -26,6 +26,8 @@ std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();
 std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
 std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
+std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingPass(
     ArrayRef<int64_t> tileSizes = {},
     linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops,

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 504bc562148f6..5bcc8cc6e33f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -100,6 +100,19 @@ def LinalgFoldReshapeOpsByLinearization :
   let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
 }
 
+def LinalgNamedOpConversion : Pass<"linalg-named-op-conversion"> {
+  let summary = "Convert from one named linalg op to another.";
+  let description = [{
+    Rewrites named Linalg ops into other, equivalent named ops. Unlike
+    canonicalization patterns, these conversions run only when this pass is
+    explicitly scheduled. Currently it rewrites tensor
+    `linalg.depthwise_conv_2d_nhwc_hwcm(_q)` ops whose `m` dimension is a
+    static 1 into `linalg.depthwise_conv_2d_nhwc_hwc(_q)` ops.
+  }];
+  let constructor = "mlir::createLinalgNamedOpConversionPass()";
+  let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+}
+
 def LinalgLowerTiledLoopsToSCF
     : FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
   let summary = "Lower linalg tiled loops to SCF loops and parallel loops";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c14259f7babad..34eef99dc729e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -86,6 +86,10 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
 void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns);
 
+/// Patterns to convert from one named op to another. These can be seen as
+/// canonicalizations of named ops into another named op.
+void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
+
 /// Populates the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(
     bufferization::BufferizeTypeConverter &converter,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6e8b08bcbb8e1..26a0c9277b327 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2665,118 +2665,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
   }
 };
 
-static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
-  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
-}
-
-LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
-                                           Value kernel, Value iZp, Value kZp,
-                                           Value init, Attribute stride,
-                                           Attribute dilation,
-                                           PatternRewriter &rewriter) {
-  Location loc = operation->getLoc();
-  auto linalgOp = dyn_cast<LinalgOp>(operation);
-  // Exit out on the memref version of this operation.
-  if (!linalgOp || !linalgOp.hasTensorSemantics())
-    return failure();
-
-  auto result = operation->getResult(0);
-
-  auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
-  auto initTy = init.getType().dyn_cast<RankedTensorType>();
-  auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
-  if (!kernelTy || !initTy || !resultTy)
-    return failure();
-
-  if (kernelTy.getDimSize(3) != 1)
-    return failure();
-
-  // Collapse kernel dims.
-  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
-      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
-  auto newKernelTy = RankedTensorType::get(
-      {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
-      kernelTy.getElementType());
-  auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
-      loc, newKernelTy, kernel, collapsedKernelDims);
-
-  // Collapse init dims.
-  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
-      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
-      getIndicesVector(3, 5)};
-  auto newInitTy =
-      RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
-                             initTy.getDimSize(2), initTy.getDimSize(3)},
-                            initTy.getElementType());
-  auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
-      loc, newInitTy, init, collapsedInitDims);
-
-  Value newConv;
-  if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
-    newConv = rewriter
-                  .create<DepthwiseConv2DNhwcHwcOp>(
-                      loc, newInitTy, ValueRange{input, collapsedKernel},
-                      ValueRange{collapsedInit}, stride, dilation)
-                  .getResult(0);
-  } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
-    newConv =
-        rewriter
-            .create<DepthwiseConv2DNhwcHwcQOp>(
-                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
-                ValueRange{collapsedInit}, stride, dilation)
-            .getResult(0);
-  }
-
-  if (!newConv)
-    return failure();
-
-  // Expand dimensions back out to
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      operation, resultTy, newConv, collapsedInitDims);
-  return success();
-}
-
-struct SimplifyDepthwiseConvOp
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
-  using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
-                                PatternRewriter &rewriter) const override {
-    Operation *operation = op.getOperation();
-    Value input = op.getInputOperand(0)->get();
-    Value kernel = op.getInputOperand(1)->get();
-    Value init = op.getOutputOperand(0)->get();
-
-    auto stride = op.strides();
-    auto dilation = op.dilations();
-
-    return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
-                                        nullptr, init, stride, dilation,
-                                        rewriter);
-  }
-};
-
-struct SimplifyDepthwiseConvQOp
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
-  using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
-                                PatternRewriter &rewriter) const override {
-    Operation *operation = op.getOperation();
-    Value input = op.getInputOperand(0)->get();
-    Value kernel = op.getInputOperand(1)->get();
-    Value iZp = op.getInputOperand(2)->get();
-    Value kZp = op.getInputOperand(3)->get();
-    Value init = op.getOutputOperand(0)->get();
-
-    auto stride = op.strides();
-    auto dilation = op.dilations();
-
-    return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
-                                        init, stride, dilation, rewriter);
-  }
-};
-
 } // namespace
 
 #define LINALGOP_FOLDERS(XXX)                                                  \
@@ -2798,8 +2686,7 @@ LINALGOP_FOLDERS(GenericOp)
 
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
-              SimplifyDepthwiseConvQOp>(getContext());
+  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
 }
 
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5edede26bb720..5df61c73fcc6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Interchange.cpp
   Loops.cpp
   LinalgStrategyPasses.cpp
+  NamedOpConversions.cpp
   Promotion.cpp
   Tiling.cpp
   Transforms.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
new file mode 100644
index 0000000000000..bb38607d769ac
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -0,0 +1,180 @@
+//===- NamedOpConversions.cpp - Implements conversions between named ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements conversions between named ops that can be seen as
+// canonicalizations of named ops. They are exposed through a dedicated pass
+// rather than as canonicalization patterns so that they only run on request.
+//
+//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns the contiguous index range [start, end) as one reassociation group,
+/// e.g. getIndicesVector(2, 4) returns {2, 3}.
+static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
+  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
+}
+
+/// Replaces a tensor-semantics `depthwise_conv_2d_nhwc_hwcm` (or `_q`) op whose
+/// `m` dimension is a static 1 with the matching `depthwise_conv_2d_nhwc_hwc`
+/// (or `_q`) op. The kernel and init operands are collapsed to drop the unit
+/// `m` dimension, and the new result is expanded back to the original result
+/// type.
+///
+/// `iZp`/`kZp` are the zero-point operands of the `_q` variant and are unused
+/// otherwise. Returns failure, without creating any IR, if `operation` is not a
+/// supported op, lacks tensor semantics, or does not have the required shapes.
+static LogicalResult
+matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
+                             Value iZp, Value kZp, Value init, Attribute stride,
+                             Attribute dilation, PatternRewriter &rewriter) {
+  Location loc = operation->getLoc();
+  auto linalgOp = dyn_cast<LinalgOp>(operation);
+  // Exit out on the memref version of this operation and on any op this helper
+  // cannot rewrite. Every match check runs before IR is created, because a
+  // pattern that returns failure must leave the IR untouched.
+  if (!linalgOp || !linalgOp.hasTensorSemantics() ||
+      !isa<DepthwiseConv2DNhwcHwcmOp, DepthwiseConv2DNhwcHwcmQOp>(operation))
+    return failure();
+
+  auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
+  auto initTy = init.getType().dyn_cast<RankedTensorType>();
+  auto resultTy =
+      operation->getResult(0).getType().dyn_cast<RankedTensorType>();
+  if (!kernelTy || !initTy || !resultTy)
+    return failure();
+
+  // The reshapes below drop the `m` dimension (kernel dim 3, init/result dim 4)
+  // by reusing the sizes of the remaining dims, which is only size-preserving
+  // when the dropped dimension is a static 1 in every one of these types.
+  if (kernelTy.getDimSize(3) != 1 || initTy.getDimSize(4) != 1 ||
+      resultTy.getDimSize(4) != 1)
+    return failure();
+
+  // Collapse kernel dims.
+  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
+      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
+  auto newKernelTy = RankedTensorType::get(
+      {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
+      kernelTy.getElementType());
+  auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
+      loc, newKernelTy, kernel, collapsedKernelDims);
+
+  // Collapse init dims.
+  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
+      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
+      getIndicesVector(3, 5)};
+  auto newInitTy =
+      RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
+                             initTy.getDimSize(2), initTy.getDimSize(3)},
+                            initTy.getElementType());
+  auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
+      loc, newInitTy, init, collapsedInitDims);
+
+  // The op kind was checked above, so exactly one branch below applies.
+  Value newConv;
+  if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
+    newConv = rewriter
+                  .create<DepthwiseConv2DNhwcHwcOp>(
+                      loc, newInitTy, ValueRange{input, collapsedKernel},
+                      ValueRange{collapsedInit}, stride, dilation)
+                  .getResult(0);
+  } else {
+    newConv =
+        rewriter
+            .create<DepthwiseConv2DNhwcHwcQOp>(
+                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
+                ValueRange{collapsedInit}, stride, dilation)
+            .getResult(0);
+  }
+
+  // Expand the result back out to the original (unit `m` dimension) type.
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+      operation, resultTy, newConv, collapsedInitDims);
+  return success();
+}
+
+namespace {
+/// Drops the unit `m` dimension of a `depthwise_conv_2d_nhwc_hwcm` op.
+struct SimplifyDepthwiseConvOp
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
+  using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *operation = op.getOperation();
+    Value input = op.getInputOperand(0)->get();
+    Value kernel = op.getInputOperand(1)->get();
+    Value init = op.getOutputOperand(0)->get();
+
+    auto stride = op.strides();
+    auto dilation = op.dilations();
+
+    // The non-quantized op has no zero-point operands.
+    return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
+                                        nullptr, init, stride, dilation,
+                                        rewriter);
+  }
+};
+
+/// Drops the unit `m` dimension of a `depthwise_conv_2d_nhwc_hwcm_q` op.
+struct SimplifyDepthwiseConvQOp
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
+  using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *operation = op.getOperation();
+    Value input = op.getInputOperand(0)->get();
+    Value kernel = op.getInputOperand(1)->get();
+    Value iZp = op.getInputOperand(2)->get();
+    Value kZp = op.getInputOperand(3)->get();
+    Value init = op.getOutputOperand(0)->get();
+
+    auto stride = op.strides();
+    auto dilation = op.dilations();
+
+    return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
+                                        init, stride, dilation, rewriter);
+  }
+};
+
+/// Applies populateLinalgNamedOpConversionPatterns greedily to the anchor op.
+struct LinalgNamedOpConversionPass
+    : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
+  LinalgNamedOpConversionPass() = default;
+  LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateLinalgNamedOpConversionPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedOpConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
+      patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
+  return std::make_unique<LinalgNamedOpConversionPass>();
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
index b499dbbf0322e..78cb590f0697e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -38,6 +38,10 @@ namespace memref {
 class MemRefDialect;
 } // namespace memref
 
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+
 namespace vector {
 class VectorDialect;
 } // namespace vector

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5465fee05f983..a6913d6f06e2a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -758,28 +758,3 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
   %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
   return %r2 : index
 }
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv
-func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
-  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
-  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
-  return %0 : tensor<?x?x?x?x1xf32>
-}
-
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv_q
-func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
-  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
-  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
-  return %0 : tensor<?x?x?x?x1xi32>
-}

diff  --git a/mlir/test/Dialect/Linalg/namedop_conversion.mlir b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
new file mode 100644
index 0000000000000..5f33f650930e2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @depthwise_conv
+func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
+  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK: return %[[OUT]]
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+  return %0 : tensor<?x?x?x?x1xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_q
+func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
+  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK: return %[[OUT]]
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+  return %0 : tensor<?x?x?x?x1xi32>
+}
+
+// -----
+
+// Not rewritten: the init's `m` dimension is not a static 1.
+// CHECK-LABEL: @depthwise_conv_dynamic_init_m
+func @depthwise_conv_dynamic_init_m(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // CHECK-NOT: tensor.collapse_shape
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}


        


More information about the Mlir-commits mailing list