[Mlir-commits] [mlir] 1ac2d19 - [mlir][linalg] Add canonicalizers for depthwise conv
Rob Suderman
llvmlistbot at llvm.org
Wed Sep 15 14:13:58 PDT 2021
Author: Rob Suderman
Date: 2021-09-15T14:09:15-07:00
New Revision: 1ac2d195ecb5d4c549c11b9c1df00179f5fea7ed
URL: https://github.com/llvm/llvm-project/commit/1ac2d195ecb5d4c549c11b9c1df00179f5fea7ed
DIFF: https://github.com/llvm/llvm-project/commit/1ac2d195ecb5d4c549c11b9c1df00179f5fea7ed.diff
LOG: [mlir][linalg] Add canonicalizers for depthwise conv
There are two main versions of depthwise conv, depending on whether the channel
multiplier is 1 or not. When the multiplier is 1 we should use the version
without the multiplier dimension, as it allows for greater optimization.
Add canonicalizations for the quantized/float versions that have a multiplier of
one, rewriting them to the multiplier-free ops.
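For illustration, the rewrite looks roughly like the following (a hand-written
sketch with made-up static shapes; the ops and reassociation indices mirror the
pattern added below, but the concrete shapes are illustrative only):

  // Before: unit multiplier dimension in the kernel and output.
  %0 = linalg.depthwise_conv2D_nhwc
         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
         ins(%input, %kernel : tensor<1x10x10x4xf32>, tensor<3x3x4x1xf32>)
         outs(%init : tensor<1x8x8x4x1xf32>) -> tensor<1x8x8x4x1xf32>

  // After: collapse away the unit multiplier dim, use the multiplier-free op,
  // then expand back to the original result type.
  %k = linalg.tensor_collapse_shape %kernel [[0], [1], [2, 3]]
         : tensor<3x3x4x1xf32> into tensor<3x3x4xf32>
  %i = linalg.tensor_collapse_shape %init [[0], [1], [2], [3, 4]]
         : tensor<1x8x8x4x1xf32> into tensor<1x8x8x4xf32>
  %c = linalg.depthwise_conv2D_nhw
         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
         ins(%input, %k : tensor<1x10x10x4xf32>, tensor<3x3x4xf32>)
         outs(%i : tensor<1x8x8x4xf32>) -> tensor<1x8x8x4xf32>
  %r = linalg.tensor_expand_shape %c [[0], [1], [2], [3, 4]]
         : tensor<1x8x8x4xf32> into tensor<1x8x8x4x1xf32>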
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D108959
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a0e02400f7ce..b3eeaabc780e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3045,6 +3045,119 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
return success();
}
};
+
+static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
+ return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
+}
+
+LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
+ Value kernel, Value iZp, Value kZp,
+ Value init, Attribute stride,
+ Attribute dilation,
+ PatternRewriter &rewriter) {
+ Location loc = operation->getLoc();
+ auto linalgOp = dyn_cast<LinalgOp>(operation);
+ // Exit out on the memref version of this operation.
+ if (!linalgOp || !linalgOp.hasTensorSemantics())
+ return failure();
+
+ auto result = operation->getResult(0);
+
+ auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
+ auto initTy = init.getType().dyn_cast<RankedTensorType>();
+ auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
+ if (!kernelTy || !initTy || !resultTy)
+ return failure();
+
+ if (kernelTy.getDimSize(3) != 1)
+ return failure();
+
+ // Collapse kernel dims.
+ SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
+ getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
+ auto newKernelTy = RankedTensorType::get(
+ {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
+ kernelTy.getElementType());
+ auto collapsedKernel = rewriter.create<linalg::TensorCollapseShapeOp>(
+ loc, newKernelTy, kernel, collapsedKernelDims);
+
+ // Collapse init dims.
+ SmallVector<ReassociationIndices, 4> collapsedInitDims = {
+ getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
+ getIndicesVector(3, 5)};
+ auto newInitTy =
+ RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
+ initTy.getDimSize(2), initTy.getDimSize(3)},
+ initTy.getElementType());
+ auto collapsedInit = rewriter.create<linalg::TensorCollapseShapeOp>(
+ loc, newInitTy, init, collapsedInitDims);
+
+ Value newConv;
+ if (isa<DepthwiseConv2DNhwcOp>(operation)) {
+ newConv = rewriter
+ .create<DepthwiseConv2DNhwOp>(
+ loc, newInitTy, ValueRange{input, collapsedKernel},
+ ValueRange{collapsedInit}, stride, dilation)
+ .getResult(0);
+ } else if (isa<DepthwiseConv2DNhwcQOp>(operation)) {
+ newConv =
+ rewriter
+ .create<DepthwiseConv2DNhwQOp>(
+ loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
+ ValueRange{collapsedInit}, stride, dilation)
+ .getResult(0);
+ }
+
+ if (!newConv)
+ return failure();
+
+ // Expand dimensions back out to match the original output shape.
+ rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+ operation, resultTy, newConv, collapsedInitDims);
+ return success();
+}
+
+struct SimplifyDepthwiseConvOp
+ : public OpRewritePattern<DepthwiseConv2DNhwcOp> {
+ using OpRewritePattern<DepthwiseConv2DNhwcOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DepthwiseConv2DNhwcOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *operation = op.getOperation();
+ Value input = op.getInputOperand(0)->get();
+ Value kernel = op.getInputOperand(1)->get();
+ Value init = op.getOutputOperand(0)->get();
+
+ auto stride = op.strides();
+ auto dilation = op.dilations();
+
+ return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
+ nullptr, init, stride, dilation,
+ rewriter);
+ }
+};
+
+struct SimplifyDepthwiseConvQOp
+ : public OpRewritePattern<DepthwiseConv2DNhwcQOp> {
+ using OpRewritePattern<DepthwiseConv2DNhwcQOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DepthwiseConv2DNhwcQOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *operation = op.getOperation();
+ Value input = op.getInputOperand(0)->get();
+ Value kernel = op.getInputOperand(1)->get();
+ Value iZp = op.getInputOperand(2)->get();
+ Value kZp = op.getInputOperand(3)->get();
+ Value init = op.getOutputOperand(0)->get();
+
+ auto stride = op.strides();
+ auto dilation = op.dilations();
+
+ return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
+ init, stride, dilation, rewriter);
+ }
+};
+
} // namespace
#define LINALGOP_FOLDERS(XXX) \
@@ -3070,5 +3183,6 @@ LINALGOP_FOLDERS(GenericOp)
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
- results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
+ results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
+ SimplifyDepthwiseConvQOp>(getContext());
}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db915f10e7dd..3d434c2d6ebc 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1004,3 +1004,27 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
return %r2 : index
}
+// -----
+
+// CHECK-LABEL: @depthwise_conv
+func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
+ // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+ // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+ // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+ // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+ %0 = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+ return %0 : tensor<?x?x?x?x1xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_q
+func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
+ // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+ // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+ // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+ // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+ %0 = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+ return %0 : tensor<?x?x?x?x1xi32>
+}
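These patterns are registered as Linalg canonicalizations, so they run under the
standard canonicalizer pass; for example (assuming the IR above is saved to a
local file such as depthwise.mlir, a hypothetical name):

  mlir-opt depthwise.mlir -canonicalize

should collapse the unit multiplier dimension as verified by the CHECK lines.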