[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 13 05:35:55 PST 2023


================
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
         .getOperation();
   }
 
+  /// Generate a vector implementation for ("flatten channel dim"):
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw)
+  ///    Iters: ({Par(), Par(), Par(), Red()})
+  ///   Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  /// ```
+  /// c of the input/output is collapsed with w. kw is always unrolled and
+  /// broadcast to match w.
+  ///
+  /// TODO: Add support for non-unit stride/dilation
+  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
+  FailureOr<Operation *> depthwiseConvFlatten() {
+    if (!valid)
+      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+    int64_t nSize, iSize, wSize, cSize, kwSize;
+    // kernel{kw, c}
+    bindShapeDims(rhsShapedType, kwSize, cSize);
+    // out{n, w, c}
+    bindShapeDims(resShapedType, nSize, wSize);
+    // in{n, w, c}
+    bindShapeDims(lhsShapedType, nSize, iSize);
+
+    vector::TransferWriteOp write;
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+    if (strideW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit strides are not supported yet");
+    if (dilationW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit dilations are not supported yet");
+
+    Type lhsEltType = lhsShapedType.getElementType();
----------------
nicolasvasilache wrote:

The code below (minus the reassociations) is very similar to the existing implementation.
There should be a way to refactor this so that the two paths share more code.
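To make the similarity concrete, here is a minimal scalar C++ model of a depthwise 1-D convolution with unit stride and dilation. It is not MLIR code and not the PR's implementation, and the function names depthwiseConvNWC and depthwiseConvFlattened are hypothetical. The convolution is written twice: once over {n, w, c}, and once with w and c collapsed into a single dimension. It assumes dense row-major buffers. Both versions compute identical results and differ only in their index arithmetic. That difference is roughly the role the reassociations (collapsing w with c) play in the vectorized version.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference model: depthwise 1-D conv, NWC input, {kw, c}
// filter, unit stride and dilation, so in[n][w + k][c] feeds out[n][w][c].
// Buffers are dense row-major: in is n x iw x c, filt is kw x c,
// out is n x ow x c with ow = iw - kw + 1.
std::vector<float> depthwiseConvNWC(const std::vector<float>& in,
                                    const std::vector<float>& filt,
                                    std::size_t n, std::size_t iw,
                                    std::size_t c, std::size_t kw) {
  assert(kw > 0 && iw >= kw && c > 0);
  const std::size_t ow = iw - kw + 1;
  std::vector<float> out(n * ow * c, 0.0f);
  for (std::size_t b = 0; b < n; ++b)
    for (std::size_t k = 0; k < kw; ++k)
      for (std::size_t w = 0; w < ow; ++w)
        for (std::size_t ch = 0; ch < c; ++ch)
          out[(b * ow + w) * c + ch] +=
              in[(b * iw + w + k) * c + ch] * filt[k * c + ch];
  return out;
}

// Hypothetical flattened model: {w, c} is collapsed into one dimension of
// length ow * c. Each k is handled separately (kw is "unrolled"). The input
// slice for k starts at offset k * c in the flattened row and is contiguous.
// Filter row k ({c}) is tiled ow times so it lines up with that slice.
std::vector<float> depthwiseConvFlattened(const std::vector<float>& in,
                                          const std::vector<float>& filt,
                                          std::size_t n, std::size_t iw,
                                          std::size_t c, std::size_t kw) {
  assert(kw > 0 && iw >= kw && c > 0);
  const std::size_t ow = iw - kw + 1;
  const std::size_t len = ow * c;  // flattened output row length
  std::vector<float> out(n * len, 0.0f);
  for (std::size_t k = 0; k < kw; ++k) {
    // "Broadcast" filter row k across w: tiled[i] = filt[k][i % c].
    std::vector<float> tiled(len);
    for (std::size_t i = 0; i < len; ++i) tiled[i] = filt[k * c + i % c];
    for (std::size_t b = 0; b < n; ++b) {
      const float* inRow = &in[b * iw * c + k * c];  // contiguous slice
      float* outRow = &out[b * len];
      for (std::size_t i = 0; i < len; ++i) outRow[i] += inRow[i] * tiled[i];
    }
  }
  return out;
}
```

The flattened model relies on unit stride, which matches the TODO in the doc comment. With strideW > 1, the input positions strideW * w + kw covered by a flattened slice are no longer contiguous, so a single collapsed {w * c} read no longer describes them.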

https://github.com/llvm/llvm-project/pull/71918

