[Mlir-commits] [mlir] [mlir][linalg] Fix inferConvolutionDimsImpl (depthwise convs) (PR #90057)

llvmlistbot at llvm.org
Thu Apr 25 07:04:51 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

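At the moment, `inferConvolutionDimsImpl` removes the "unconvolved" dims
from the set of output channel dims. However, that is incorrect for
depthwise convolutions: there, the channel dimension is itself one of the
"unconvolved" dims (it indexes the input directly rather than as part of a
convolved expression like `oh + kh`), so it gets dropped from the output
channel dims.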
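To make the set arithmetic described above concrete, here is a minimal C++ sketch of the channel-related part of the inference, applied to `linalg.depthwise_conv_2d_nhwc_hwc`. It is not the MLIR implementation: `std::set` and the hand-written `setIntersect`/`setSubtract` helpers stand in for `llvm::SmallDenseSet`, `llvm::set_intersect` and `llvm::set_subtract`. The dim sets are derived by hand from the op's indexing maps, assuming the loop order `(n, oh, ow, c, kh, kw)`. The names `inferChannelDims`, `ChannelDims` and `checkDepthwiseExample` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Stand-in for llvm::SmallDenseSet<int64_t> (assumption: any ordered set works here).
using DimSet = std::set<std::int64_t>;

// Stand-in for llvm::set_intersect: keep only the elements of `a` also in `b`.
static void setIntersect(DimSet &a, const DimSet &b) {
  for (auto it = a.begin(); it != a.end();) {
    if (b.count(*it) == 0)
      it = a.erase(it);
    else
      ++it;
  }
}

// Stand-in for llvm::set_subtract: remove every element of `b` from `a`.
static void setSubtract(DimSet &a, const DimSet &b) {
  for (std::int64_t d : b)
    a.erase(d);
}

// Hypothetical result type holding the two channel-related classifications.
struct ChannelDims {
  DimSet outputChannel; // corresponds to "output channel dims"
  DimSet depth;         // corresponds to "depth dims"
};

// Channel-related part of the inference. `subtractUnConvolved` selects the
// pre-fix behaviour, i.e. the set_subtract line removed by this PR.
static ChannelDims inferChannelDims(const DimSet &filterDims,
                                    const DimSet &outputDims,
                                    const DimSet &unConvolvedDims,
                                    bool subtractUnConvolved) {
  ChannelDims result;
  // filterDims & outputDims are the output channel iterators.
  result.outputChannel = filterDims;
  setIntersect(result.outputChannel, outputDims);
  if (subtractUnConvolved)
    setSubtract(result.outputChannel, unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  result.depth = filterDims;
  setIntersect(result.depth, outputDims);
  setIntersect(result.depth, unConvolvedDims);
  return result;
}

// Assumed loop order for linalg.depthwise_conv_2d_nhwc_hwc:
//   d0 = n, d1 = oh, d2 = ow, d3 = c, d4 = kh, d5 = kw
// Input map: (d0, d1 + d4, d2 + d5, d3); filter: (d4, d5, d3);
// output: (d0, d1, d2, d3).
static const DimSet kFilterDims = {3};          // parallel dims indexing the filter
static const DimSet kOutputDims = {0, 1, 2, 3}; // parallel dims indexing the output
static const DimSet kUnConvolvedDims = {0, 3};  // input dims used as a plain d_i

// Checks the sketch against the remarks expected by the new test.
static void checkDepthwiseExample() {
  ChannelDims before = inferChannelDims(kFilterDims, kOutputDims,
                                        kUnConvolvedDims,
                                        /*subtractUnConvolved=*/true);
  assert(before.outputChannel.empty()); // d3 is lost as an output channel

  ChannelDims after = inferChannelDims(kFilterDims, kOutputDims,
                                       kUnConvolvedDims,
                                       /*subtractUnConvolved=*/false);
  assert((after.outputChannel == DimSet{3})); // "output channel dims 3"
  assert((after.depth == DimSet{3}));         // "depth dims 3"
}
```

With the subtraction, `d3` is classified only as a depth dim and the output channel set comes out empty. Without it, `d3` is reported as both an output channel dim and a depth dim, which is what the new test expects.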


---
Full diff: https://github.com/llvm/llvm-project/pull/90057.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (-1) 
- (modified) mlir/test/Dialect/Linalg/match-ops-interpreter.mlir (+22) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda3..3b92da5ceccd39 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -556,7 +556,6 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
   // filterDims & outputDims - unConvolvedDims are the output channel iterators.
   llvm::SmallDenseSet<int64_t> oc = filterDims;
   llvm::set_intersect(oc, outputDims);
-  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
 
   // filterDims & outputDims & unConvolvedDims are the depth iterators.
   llvm::SmallDenseSet<int64_t> depth = filterDims;
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 24c7bdd9e1050e..c637e1df7efd3e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -1062,6 +1062,28 @@ module attributes { transform.target_tag = "start_here" } {
     return %result : tensor<10x18x15xf64>
   }
 
+  func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+    %cst = arith.constant 0.0 : f32 
+    %empty = tensor.empty() : tensor<1x10x191x48xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+    // expected-remark @below {{convolution}}
+    // expected-remark @below {{batch dims 0}}
+    // expected-remark @below {{output image dims 1 : i64, 2 : i64}}
+    // expected-remark @below {{output channel dims 3}}
+    // expected-remark @below {{filter loop dims 4 : i64, 5 : i64}}
+    // expected-remark @below {{input channel dims}}
+    // expected-remark @below {{depth dims 3}}
+    // expected-remark @below {{strides 1 : i64, 1 : i64}}
+    // expected-remark @below {{dilations 1 : i64, 1 : i64}}
+    %result = linalg.depthwise_conv_2d_nhwc_hwc {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<1> : tensor<2xi64>}
+      ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
+      outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+    return %result : tensor<1x10x191x48xf32>
+  }
+
   func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> {
     %cst = arith.constant 0.0 : f32
     %empty = tensor.empty() : tensor<8x32x32x16xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/90057

