[Mlir-commits] [mlir] [mlir][linalg] Add NHWC + FHWC Img2Col (PR #68708)
Quinn Dawkins
llvmlistbot at llvm.org
Wed Oct 11 10:56:04 PDT 2023
================
@@ -494,6 +494,140 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
reshapedResult.getOperation());
}
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+ auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+
+ if (!filterType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected a static shape for the filter");
+
+ if (!inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(convOp,
+ "expected a static shape for the input");
+
+ // TODO: Support dilation.
+ if (!hasAllOneValues(convOp.getDilations()))
+ return rewriter.notifyMatchFailure(convOp,
+ "expected all ones for dilations");
+
+ MLIRContext *context = rewriter.getContext();
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+
+ ArrayRef<int64_t> filterShape = filterType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+
+ int64_t n = outputShape[0];
+ int64_t oh = outputShape[1];
+ int64_t ow = outputShape[2];
+ int64_t oc = outputShape[3];
+ int64_t fh = filterShape[1];
+ int64_t fw = filterShape[2];
+ int64_t ic = filterShape[3];
+
+ Location loc = convOp.getLoc();
+
+ // Reshape output and filter to the LHS and result of a (B)MNK matmul.
+ SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+ auto reshapedFilterType =
+ RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+ Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedFilterType, filter, filterReassocIndices);
+
+ SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+ RankedTensorType reshapedOutputType =
+ RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+ Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedOutputType, output, outputReassocIndices);
+
+ SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
+ Value colTensor = rewriter.create<tensor::EmptyOp>(
+ loc, colTensorShape, inputType.getElementType());
+
+ // Convert the input to a (BMK) column tensor.
+ auto nloops = colTensorShape.size();
+
+ auto parallel = utils::IteratorType::parallel;
+ auto reduction = utils::IteratorType::reduction;
+ SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+ SmallVector<AffineMap> img2colIndexingMaps = {
+ AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+ auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+ loc, colTensor.getType(),
+ /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ img2colIterators,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ // Get the iterators named based on the matmul (batch, m, k).
+ Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+ Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+ Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+ // Recover the original iteration indices from the problem/input sizes.
+ SmallVector<Value> mIndices = unrollIndex(
+ nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+ auto ohIndex = mIndices[0];
+ auto owIndex = mIndices[1];
+
+ SmallVector<Value> kIndices = unrollIndex(
+ nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+ auto fhIndex = kIndices[0];
+ auto fwIndex = kIndices[1];
+ auto icIndex = kIndices[2];
+
+ // Extract the input element corresponding to the expanded indices.
+ Value hIndex =
+ getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+ convOp.getStrides().getValues<int64_t>()[0]);
+ Value wIndex =
+ getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+ convOp.getStrides().getValues<int64_t>()[1]);
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+ Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+ loc, input, extractionIndices);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ });
+
+  // Because we didn't transpose the filters, we don't actually have a batched
+  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
----------------
qedawkins wrote:
nit: row-wise
https://github.com/llvm/llvm-project/pull/68708
More information about the Mlir-commits
mailing list