[Mlir-commits] [mlir] [mlir][linalg] Add NHWC + FHWC Img2Col (PR #68708)

Quinn Dawkins llvmlistbot at llvm.org
Wed Oct 11 10:56:04 PDT 2023


================
@@ -494,6 +494,140 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
                         reshapedResult.getOperation());
 }
 
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[1];
+  int64_t fw = filterShape[2];
+  int64_t ic = filterShape[3];
+
+  Location loc = convOp.getLoc();
+
+  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
----------------
qedawkins wrote:

nit: this comment should say `matmul_transpose_b` (or otherwise indicate that this isn't a normal matmul).

https://github.com/llvm/llvm-project/pull/68708


More information about the Mlir-commits mailing list