[Mlir-commits] [mlir] [mlir][linalg] Fix inferConvolutionDimsImpl (depthwise convs) (PR #90057)
llvmlistbot at llvm.org
Thu Apr 25 07:04:51 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
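At the moment, `inferConvolutionDimsImpl` excludes "unconvolved" dims
(dims that the input indexing map uses directly rather than inside a
convolved expression such as `d1 + d4`) from the output channel dims.
That is incorrect for depthwise convolutions, whose channel dimension is
itself one of those "unconvolved" dims and should therefore be reported
as both an output channel dim and a depth dim.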
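The set arithmetic behind this fix can be sketched with plain `std::set`s. This is a simplified model, not the MLIR implementation: `DimSet`, `outputChannelDims`, `depthDims`, `DepthwiseConvDims` and `checkDepthwiseClassification` are hypothetical names, and the sketch assumes the loop order `d0 = n, d1 = oh, d2 = ow, d3 = c, d4 = kh, d5 = kw` with unit strides/dilations for `linalg.depthwise_conv_2d_nhwc_hwc`. The formulas follow the comments in the diff: output channels are `filterDims & outputDims` (minus `unConvolvedDims` before the fix), and depth is `filterDims & outputDims & unConvolvedDims`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <set>

// Loop-dimension indices (d0, d1, ...). std::set keeps them sorted, which
// std::set_intersection / std::set_difference require.
using DimSet = std::set<std::int64_t>;

// Returns a & b.
DimSet intersect(const DimSet &a, const DimSet &b) {
  DimSet out;
  std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                        std::inserter(out, out.end()));
  return out;
}

// Returns a - b.
DimSet subtract(const DimSet &a, const DimSet &b) {
  DimSet out;
  std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
                      std::inserter(out, out.end()));
  return out;
}

// Output channel dims: filterDims & outputDims. With withFix == false this
// also subtracts unConvolvedDims, mirroring the line the PR deletes.
DimSet outputChannelDims(const DimSet &filterDims, const DimSet &outputDims,
                         const DimSet &unConvolvedDims, bool withFix) {
  DimSet oc = intersect(filterDims, outputDims);
  return withFix ? oc : subtract(oc, unConvolvedDims);
}

// Depth dims: filterDims & outputDims & unConvolvedDims.
DimSet depthDims(const DimSet &filterDims, const DimSet &outputDims,
                 const DimSet &unConvolvedDims) {
  return intersect(intersect(filterDims, outputDims), unConvolvedDims);
}

// Hypothetical model of linalg.depthwise_conv_2d_nhwc_hwc (assumed maps):
//   input  (d0, d1 + d4, d2 + d5, d3)
//   filter (d4, d5, d3)
//   output (d0, d1, d2, d3)
struct DepthwiseConvDims {
  DimSet filterDims{3, 4, 5};
  DimSet outputDims{0, 1, 2, 3};
  // Dims the input map uses directly (not inside an addition).
  DimSet unConvolvedDims{0, 3};
};

// Checks the classification against the remarks in the PR's new test.
void checkDepthwiseClassification() {
  const DepthwiseConvDims m{};
  // Before the fix, the channel dim d3 is dropped from the output channels.
  assert(outputChannelDims(m.filterDims, m.outputDims, m.unConvolvedDims,
                           /*withFix=*/false)
             .empty());
  // After the fix: "output channel dims 3" and "depth dims 3".
  assert((outputChannelDims(m.filterDims, m.outputDims, m.unConvolvedDims,
                            /*withFix=*/true) == DimSet{3}));
  assert((depthDims(m.filterDims, m.outputDims, m.unConvolvedDims) ==
          DimSet{3}));
}
```

Calling `checkDepthwiseClassification()` from a `main` built with assertions enabled reproduces the `output channel dims 3` and `depth dims 3` remarks expected by the new test, while the pre-fix path yields an empty output channel set.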
---
Full diff: https://github.com/llvm/llvm-project/pull/90057.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (-1)
- (modified) mlir/test/Dialect/Linalg/match-ops-interpreter.mlir (+22)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda3..3b92da5ceccd39 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -556,7 +556,6 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
// filterDims & outputDims - unConvolvedDims are the output channel iterators.
llvm::SmallDenseSet<int64_t> oc = filterDims;
llvm::set_intersect(oc, outputDims);
- llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
// filterDims & outputDims & unConvolvedDims are the depth iterators.
llvm::SmallDenseSet<int64_t> depth = filterDims;
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 24c7bdd9e1050e..c637e1df7efd3e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -1062,6 +1062,28 @@ module attributes { transform.target_tag = "start_here" } {
return %result : tensor<10x18x15xf64>
}
+ func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<1x10x191x48xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+ // expected-remark @below {{convolution}}
+ // expected-remark @below {{batch dims 0}}
+ // expected-remark @below {{output image dims 1 : i64, 2 : i64}}
+ // expected-remark @below {{output channel dims 3}}
+ // expected-remark @below {{filter loop dims 4 : i64, 5 : i64}}
+ // expected-remark @below {{input channel dims}}
+ // expected-remark @below {{depth dims 3}}
+ // expected-remark @below {{strides 1 : i64, 1 : i64}}
+ // expected-remark @below {{dilations 1 : i64, 1 : i64}}
+ %result = linalg.depthwise_conv_2d_nhwc_hwc {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
+ outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+ return %result : tensor<1x10x191x48xf32>
+ }
+
func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> {
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<8x32x32x16xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/90057