[Mlir-commits] [mlir] [mlir][linalg] Add NHWC + FHWC Img2Col (PR #68708)
Quinn Dawkins
llvmlistbot at llvm.org
Wed Oct 11 10:56:04 PDT 2023
================
@@ -494,6 +494,140 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
reshapedResult.getOperation());
}
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+ auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+
+ if (!filterType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected a static shape for the filter");
+
+ if (!inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(convOp,
+ "expected a static shape for the input");
+
+ // TODO: Support dilation.
+ if (!hasAllOneValues(convOp.getDilations()))
+ return rewriter.notifyMatchFailure(convOp,
+ "expected all ones for dilations");
+
+ MLIRContext *context = rewriter.getContext();
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+
+ ArrayRef<int64_t> filterShape = filterType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+
+ int64_t n = outputShape[0];
+ int64_t oh = outputShape[1];
+ int64_t ow = outputShape[2];
+ int64_t oc = outputShape[3];
+ int64_t fh = filterShape[1];
+ int64_t fw = filterShape[2];
+ int64_t ic = filterShape[3];
+
+ Location loc = convOp.getLoc();
+
+ // Reshape output and filter to the LHS and result of a (B)MNK matmul.
----------------
qedawkins wrote:
nit: say `matmul_transpose_b` here (or indicate in some other way that this isn't a normal matmul).
https://github.com/llvm/llvm-project/pull/68708
More information about the Mlir-commits mailing list