[Mlir-commits] [mlir] c828030 - [mlir][linalg] Refactor convolution to img2col conversion to use gather semantics
Quinn Dawkins
llvmlistbot at llvm.org
Thu Mar 23 16:40:43 PDT 2023
Author: Quinn Dawkins
Date: 2023-03-23T19:38:53-04:00
New Revision: c82803097f6a89edc49577e5bb4f7309e053efcc
URL: https://github.com/llvm/llvm-project/commit/c82803097f6a89edc49577e5bb4f7309e053efcc
DIFF: https://github.com/llvm/llvm-project/commit/c82803097f6a89edc49577e5bb4f7309e053efcc.diff
LOG: [mlir][linalg] Refactor convolution to img2col conversion to use gather semantics
Following up on the comments in https://reviews.llvm.org/D144108, this
patch refactors the im2col conversion patterns for `linalg.conv_2d_nhwc_hwcf`
and `linalg.conv_2d_nchw_fchw` convolutions to use gather semantics for the
im2col packing `linalg.generic`.
Follow-up work can include a similar pattern for depthwise convolutions,
as well as a generalization of the patterns here to work with any
`LinalgOp`.
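For illustration, the packing emitted for the static NHWC case in the tests
below (input tensor<1x16x16x4xf32>, 3x3 filter, stride 1) now looks roughly
like the following hand-simplified sketch; the SSA names are illustrative,
not taken from the actual output:

  %col = tensor.empty() : tensor<1x196x36xf32>
  %packed = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      outs(%col : tensor<1x196x36xf32>) {
  ^bb0(%out: f32):
    %b = linalg.index 0 : index
    %m = linalg.index 1 : index
    %k = linalg.index 2 : index
    // Delinearize m into (oh, ow) and k into (fh, fw, ic).
    %c4 = arith.constant 4 : index
    %c12 = arith.constant 12 : index
    %c14 = arith.constant 14 : index
    %ow = arith.remui %m, %c14 : index
    %oh = arith.divui %m, %c14 : index
    %ic = arith.remui %k, %c4 : index
    %fwrem = arith.remui %k, %c12 : index
    %fw = arith.divui %fwrem, %c4 : index
    %fh = arith.divui %k, %c12 : index
    // Gather from the convolution input (stride 1, so no multiplies).
    %h = arith.addi %oh, %fh : index
    %w = arith.addi %ow, %fw : index
    %in = tensor.extract %input[%b, %h, %w, %ic] : tensor<1x16x16x4xf32>
    linalg.yield %in : f32
  } -> tensor<1x196x36xf32>

The packing is thus a pure gather: a rank-3 `linalg.generic` with no inputs
that delinearizes its own iteration indices and reads the source tensor with
`tensor.extract`.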
Differential Revision: https://reviews.llvm.org/D144678
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 14bff411ef8c1..58a23e2be54d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -41,6 +41,49 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
return builder.create<arith::MulFOp>(loc, x, y);
}
+// Unrolls the given composite `index` into a set of subindices with maximum
+// iteration ranges specified by `factors` according to the following
+// assumptions:
+// 1. The iteration range for `index` is [0, f1 * f2 * ... * fn), i.e. the
+// product of the given list of factors.
+// 2. The iterators corresponding to the entries in `factors` are ordered from
+// slowest to fastest varying.
+// Each subindex is then computed as:
+// subindex[i] = floor( (index % (fi * ... * fn)) / (f(i+1) * ... * fn) )
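+// For example, with factors = {4, 5} the composite index 13 unrolls to
+// subindices {13 / 5, 13 % 5} = {2, 3}.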
+static SmallVector<Value, 3> unrollIndex(OpBuilder &b, Location loc,
+ Value index,
+ ArrayRef<int64_t> factors) {
+ assert(factors.size() >= 1 && "empty factor list");
+ SmallVector<Value, 3> indices(factors.size());
+ int64_t runningProd = 1;
+ for (int i = factors.size() - 1, end = 0; i >= end; i--) {
+ Value unrolledIndex = index;
+ if (i > 0) {
+ Value modBase = b.create<arith::ConstantOp>(
+ loc, b.getIndexAttr(runningProd * factors[i]));
+ unrolledIndex = b.create<arith::RemUIOp>(loc, unrolledIndex, modBase);
+ }
+ if (runningProd > 1) {
+ Value divDenom =
+ b.create<arith::ConstantOp>(loc, b.getIndexAttr(runningProd));
+ unrolledIndex = b.create<arith::DivUIOp>(loc, unrolledIndex, divDenom);
+ }
+ runningProd *= factors[i];
+ indices[i] = unrolledIndex;
+ }
+ return indices;
+}
+
+// Given indices corresponding to iterators in the output (oIndex) and filter
+// (fIndex) for a convolution, compute the convolved index for the
+// input as `oIndex * stride + fIndex`.
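+// For example, with stride 2, oIndex = 3 and fIndex = 1 give input index
+// 3 * 2 + 1 = 7.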
+static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
+ Value fIndex, int64_t stride) {
+ Value strideVal = b.create<arith::ConstantOp>(loc, b.getIndexAttr(stride));
+ Value convIndex = b.create<arith::MulIOp>(loc, oIndex, strideVal);
+ return b.create<arith::AddIOp>(loc, convIndex, fIndex);
+}
+
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
@@ -68,32 +111,34 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
ArrayRef<int64_t> filterShape = filterType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
- int n = outputShape[0];
- int oh = outputShape[1];
- int ow = outputShape[2];
- int oc = outputShape[3];
- int fh = filterShape[0];
- int fw = filterShape[1];
- int ic = filterShape[2];
+ int64_t n = outputShape[0];
+ int64_t oh = outputShape[1];
+ int64_t ow = outputShape[2];
+ int64_t oc = outputShape[3];
+ int64_t fh = filterShape[0];
+ int64_t fw = filterShape[1];
+ int64_t ic = filterShape[2];
Location loc = convOp.getLoc();
- SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+ // Reshape the filter and output to the RHS and result of a (B)MNK matmul.
+ SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+ auto reshapedFilterType =
+ RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+ Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedFilterType, filter, filterReassocIndices);
+
+ SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+ RankedTensorType reshapedOutputType =
+ RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+ Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedOutputType, output, outputReassocIndices);
+ SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
Value colTensor = rewriter.create<tensor::EmptyOp>(
loc, colTensorShape, inputType.getElementType());
- AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
- bindDims(context, nDim, ohDim, owDim, khDim, kwDim, icDim);
-
- AffineExpr shSym = rewriter.getAffineConstantExpr(
- convOp.getStrides().getValues<int64_t>()[0]);
- AffineExpr swSym = rewriter.getAffineConstantExpr(
- convOp.getStrides().getValues<int64_t>()[1]);
-
- SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
- owDim * swSym + kwDim, icDim};
-
+ // Convert the input to a (BMK) column tensor.
auto nloops = colTensorShape.size();
auto parallel = utils::IteratorType::parallel;
@@ -101,85 +146,68 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::get(nloops, 0, inputExprs, context),
AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = rewriter.create<linalg::GenericOp>(
loc, colTensor.getType(),
- /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ // Get the iteration indices, named after the matmul dimensions (batch, m, k).
+ Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+ Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+ Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+ // Recover the original iteration indices from the problem/input sizes.
+ SmallVector<Value, 3> mIndices = unrollIndex(
+ nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+ auto ohIndex = mIndices[0];
+ auto owIndex = mIndices[1];
+
+ SmallVector<Value, 3> kIndices = unrollIndex(
+ nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+ auto fhIndex = kIndices[0];
+ auto fwIndex = kIndices[1];
+ auto icIndex = kIndices[2];
+
+ // Extract the input element corresponding to the expanded indices.
+ Value hIndex =
+ getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+ convOp.getStrides().getValues<int64_t>()[0]);
+ Value wIndex =
+ getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+ convOp.getStrides().getValues<int64_t>()[1]);
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+ Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+ loc, input, extractionIndices);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
});
- SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
- SmallVector<ReassociationIndices> outputReassocIndices;
- RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
- if (n == 1) {
- img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
- outputReassocIndices = {{0, 1, 2}, {3}};
-
- reshapedImg2ColTensorType = RankedTensorType::get(
- {oh * ow, fh * fw * ic}, inputType.getElementType());
- reshapedOutputType =
- RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
- } else {
- img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
- outputReassocIndices = {{0}, {1, 2}, {3}};
-
- reshapedImg2ColTensorType = RankedTensorType::get(
- {n, oh * ow, fh * fw * ic}, inputType.getElementType());
- reshapedOutputType =
- RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
- }
-
- SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
- auto reshapedFilterType =
- RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
-
- Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
- img2ColTensorReassocIndices);
-
- Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedFilterType, filter, filterReassocIndices);
-
- Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedOutputType, output, outputReassocIndices);
-
- Value result;
- if (n == 1) {
- auto matmulOp = rewriter.create<linalg::MatmulOp>(
- loc, reshapedOutputType,
- ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
- ArrayRef<Value>{reshapedOutput});
- result = matmulOp.getResults().front();
- } else {
- // For cases where batch is not 1, we need to keep the batch dimension
- // separate. Because the filter does not share the same batch dimension,
- // the batch dimension is only used in indexing the input and output. Thus
- // we cannot use existing linalg named ops like linalg.batch_matmul.
- // i.e. (B x) M x K * K x N = (B x) M x N
- AffineExpr bDim, mDim, nDim, kDim;
- bindDims(context, bDim, mDim, nDim, kDim);
- auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
- auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
- auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
- SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
- parallel, reduction};
-
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, reshapedOutputType,
- /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
- /*outputs=*/ValueRange{reshapedOutput},
- ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
- [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- Value mul = createMul(loc, args[0], args[1], nestedBuilder);
- Value add = createAdd(loc, mul, args[2], nestedBuilder);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
- });
- result = genericOp.getResults().front();
- }
+ // Because the filter does not share the same batch dimension,
+ // the batch dimension is only used in indexing the input and output. Thus
+ // we cannot use existing linalg named ops like linalg.batch_matmul.
+ // i.e. (B x) M x K * K x N = (B x) M x N
+ AffineExpr bDim, mDim, nDim, kDim;
+ bindDims(context, bDim, mDim, nDim, kDim);
+ auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+ auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
+ auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+ SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+ parallel, reduction};
+
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, reshapedOutputType,
+ /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+ /*outputs=*/ValueRange{reshapedOutput},
+ ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+ Value add = createAdd(loc, mul, args[2], nestedBuilder);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+ });
+ Value result = genericOp.getResults().front();
auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
loc, outputType, result, outputReassocIndices);
@@ -367,33 +395,33 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto filterShape = filterType.getShape();
auto outputShape = outputType.getShape();
- int n = outputShape[0];
- int oc = outputShape[1];
- int oh = outputShape[2];
- int ow = outputShape[3];
- int ic = filterShape[1];
- int fh = filterShape[2];
- int fw = filterShape[3];
+ int64_t n = outputShape[0];
+ int64_t oc = outputShape[1];
+ int64_t oh = outputShape[2];
+ int64_t ow = outputShape[3];
+ int64_t ic = filterShape[1];
+ int64_t fh = filterShape[2];
+ int64_t fw = filterShape[3];
auto loc = convOp.getLoc();
-
- SmallVector<int64_t, 4> colTensorShape = {n, ic, fh, fw, oh, ow};
-
- Value colTensor = rewriter.create<tensor::EmptyOp>(
- loc, colTensorShape, inputType.getElementType());
-
MLIRContext *context = rewriter.getContext();
- AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
- bindDims(context, nDim, icDim, khDim, kwDim, ohDim, owDim);
+ SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+ auto reshapedFilterType =
+ RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+ Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedFilterType, filter, filterReassocIndices);
- auto shSym = rewriter.getAffineConstantExpr(
- convOp.getStrides().getValues<int64_t>()[0]);
- auto swSym = rewriter.getAffineConstantExpr(
- convOp.getStrides().getValues<int64_t>()[1]);
+ SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
+ auto reshapedOutputType =
+ RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+ Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+ loc, reshapedOutputType, output, outputReassocIndices);
- SmallVector<AffineExpr, 4> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
- owDim * swSym + kwDim};
+ // Convert the input to a (BKN) column tensor.
+ SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
+ Value colTensor = rewriter.create<tensor::EmptyOp>(
+ loc, colTensorShape, inputType.getElementType());
auto nloops = colTensorShape.size();
@@ -402,83 +430,67 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
SmallVector<AffineMap, 4> img2colIndexingMaps = {
- AffineMap::get(nloops, 0, inputExprs, context),
AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = rewriter.create<linalg::GenericOp>(
loc, colTensor.getType(),
- /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ // Get the iteration indices, named after the matmul dimensions (batch, k, n).
+ Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+ Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+ Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+ // Recover the original iteration indices from the problem/input sizes.
+ SmallVector<Value, 3> kIndices = unrollIndex(
+ nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
+ auto icIndex = kIndices[0];
+ auto fhIndex = kIndices[1];
+ auto fwIndex = kIndices[2];
+
+ SmallVector<Value, 3> nIndices = unrollIndex(
+ nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
+ auto ohIndex = nIndices[0];
+ auto owIndex = nIndices[1];
+
+ // Extract the input element corresponding to the expanded indices.
+ Value hIndex =
+ getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+ convOp.getStrides().getValues<int64_t>()[0]);
+ Value wIndex =
+ getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+ convOp.getStrides().getValues<int64_t>()[1]);
+
+ // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+ SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
+ Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+ loc, input, extractionIndices);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
});
- SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
- auto reshapedFilterType =
- RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
- Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedFilterType, filter, filterReassocIndices);
-
- SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
- SmallVector<ReassociationIndices> outputReassocIndices;
- RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
- if (n == 1) {
- img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
- outputReassocIndices = {{0, 1}, {2, 3}};
-
- reshapedImg2ColTensorType = RankedTensorType::get(
- {fh * fw * ic, oh * ow}, inputType.getElementType());
- reshapedOutputType =
- RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
- } else {
- img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
- outputReassocIndices = {{0}, {1}, {2, 3}};
-
- reshapedImg2ColTensorType = RankedTensorType::get(
- {n, fh * fw * ic, oh * ow}, inputType.getElementType());
- reshapedOutputType =
- RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
- }
-
- Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
- img2ColTensorReassocIndices);
-
- Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedOutputType, output, outputReassocIndices);
-
- Value result;
- if (n == 1) {
- auto matmulOp = rewriter.create<linalg::MatmulOp>(
- loc, reshapedOutputType,
- ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
- ArrayRef<Value>{reshapedOutput});
- result = matmulOp.getResults().front();
- } else {
- // For cases where batch is not 1, we need to keep the batch dimension
- // separate. Because the filter does not share the same batch dimension,
- // the batch dimension is only used in indexing the input and output. Thus
- // we cannot use existing linalg named ops like linalg.batch_matmul.
- // i.e. M x K * (B x) K x N = (B x) M x N
- AffineExpr bDim, mDim, nDim, kDim;
- bindDims(context, bDim, mDim, nDim, kDim);
- auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
- auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
- auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
- SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
- parallel, reduction};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, reshapedOutputType,
- /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
- /*outputs=*/ValueRange{reshapedOutput},
- ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
- [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- Value mul = createMul(loc, args[0], args[1], nestedBuilder);
- Value add = createAdd(loc, mul, args[2], nestedBuilder);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
- });
- result = genericOp.getResults().front();
- }
+ // Because the filter does not share the same batch dimension,
+ // the batch dimension is only used in indexing the input and output. Thus
+ // we cannot use existing linalg named ops like linalg.batch_matmul.
+ // i.e. M x K * (B x) K x N = (B x) M x N
+ AffineExpr bDim, mDim, nDim, kDim;
+ bindDims(context, bDim, mDim, nDim, kDim);
+ auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
+ auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
+ auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+ SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+ parallel, reduction};
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, reshapedOutputType,
+ /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
+ /*outputs=*/ValueRange{reshapedOutput},
+ ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+ Value add = createAdd(loc, mul, args[2], nestedBuilder);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+ });
+ Value result = genericOp.getResults().front();
auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
loc, outputType, result, outputReassocIndices);
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index e33e51ddababb..ffcba1086f3f6 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -29,36 +29,71 @@ transform.sequence failures(propagate) {
// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>]
-// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
-// CHECK: linalg.yield %[[IN_DATA]] : f32
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
+
+// Unrolled output shape indices.
+// CHECK: %[[C14:.+]] = arith.constant 14 : index
+// CHECK: %[[OWINDEX:.+]] = arith.remui %[[MINDEX]], %[[C14]] : index
+// CHECK: %[[C14_1:.+]] = arith.constant 14 : index
+// CHECK: %[[OHINDEX:.+]] = arith.divui %[[MINDEX]], %[[C14_1]] : index
+
+// Unrolled filter shape indices.
+// CHECK: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[ICINDEX:.+]] = arith.remui %[[KINDEX]], %[[C4]] : index
+// CHECK: %[[C12:.+]] = arith.constant 12 : index
+// CHECK: %[[FWREM:.+]] = arith.remui %[[KINDEX]], %[[C12]] : index
+// CHECK: %[[C4_2:.+]] = arith.constant 4 : index
+// CHECK: %[[FWINDEX:.+]] = arith.divui %[[FWREM]], %[[C4_2]] : index
+// CHECK: %[[C12_3:.+]] = arith.constant 12 : index
+// CHECK: %[[FHINDEX:.+]] = arith.divui %[[KINDEX]], %[[C12_3]] : index
+
+// Compute input indices.
+// CHECK: %[[SH:.+]] = arith.constant 1 : index
+// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
+// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
+// CHECK: %[[SW:.+]] = arith.constant 1 : index
+// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
+// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
+// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
+// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
// CHECK: IR printer: transformed
-// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_16433136
-// CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
-// CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
-// CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
-// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: #[[MAP0]]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %{{.+}} : f32
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
// CHECK-SAME: #[[MAP1]]
-// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
-// CHECK: linalg.yield %[[IN_DATA]] : f32
-// CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
-// CHECK-SAME: [0, 1, 2], [3, 4, 5]
-// CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
-// CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
-// CHECK-SAME: [0, 1, 2], [3]
-// CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
-// CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
-// CHECK-SAME: [0, 1, 2], [3]
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK-SAME: #[[MAP2]]
+// CHECK-SAME: #[[MAP3]]
+// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<1x196x16xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: return %[[RESULT]]
func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
@@ -156,27 +191,24 @@ transform.sequence failures(propagate) {
// -----
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: func.func @batch_nhwc_conv
// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
-// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32>
+// CHECK-DAG: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
-// CHECK-SAME: outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>)
-// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32>
-// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
-// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: outs(%[[IT]] : tensor<8x196x36xf32>)
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME: ins(%[[IMG2COL]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
@@ -201,27 +233,55 @@ transform.sequence failures(propagate) {
// -----
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: func.func @batch_nchw_conv
// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
-// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32>
+// CHECK-DAG: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
+// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
-// CHECK-SAME: outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>)
-// CHECK: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
-// CHECK: %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32>
-// CHECK: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: outs(%[[IT]] : tensor<8x36x196xf32>)
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index
+
+// Unrolled filter shape indices.
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[FWINDEX:.+]] = arith.remui %[[KINDEX]], %[[C3]] : index
+// CHECK: %[[C9:.+]] = arith.constant 9 : index
+// CHECK: %[[FHREM:.+]] = arith.remui %[[KINDEX]], %[[C9]] : index
+// CHECK: %[[C3_1:.+]] = arith.constant 3 : index
+// CHECK: %[[FHINDEX:.+]] = arith.divui %[[FHREM]], %[[C3_1]] : index
+// CHECK: %[[C9_2:.+]] = arith.constant 9 : index
+// CHECK: %[[ICINDEX:.+]] = arith.divui %[[KINDEX]], %[[C9_2]] : index
+
+// Unrolled output shape indices.
+// CHECK: %[[C14:.+]] = arith.constant 14 : index
+// CHECK: %[[OWINDEX:.+]] = arith.remui %[[NINDEX]], %[[C14]] : index
+// CHECK: %[[C14_3:.+]] = arith.constant 14 : index
+// CHECK: %[[OHINDEX:.+]] = arith.divui %[[NINDEX]], %[[C14_3]] : index
+
+// Compute input indices.
+// CHECK: %[[SH:.+]] = arith.constant 1 : index
+// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
+// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
+// CHECK: %[[SW:.+]] = arith.constant 1 : index
+// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
+// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
+// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
+// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
+// CHECK-SAME: ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32