[Mlir-commits] [mlir] c828030 - [mlir][linalg] Refactor convolution to img2col conversion to use gather semantics

Quinn Dawkins llvmlistbot at llvm.org
Thu Mar 23 16:40:43 PDT 2023


Author: Quinn Dawkins
Date: 2023-03-23T19:38:53-04:00
New Revision: c82803097f6a89edc49577e5bb4f7309e053efcc

URL: https://github.com/llvm/llvm-project/commit/c82803097f6a89edc49577e5bb4f7309e053efcc
DIFF: https://github.com/llvm/llvm-project/commit/c82803097f6a89edc49577e5bb4f7309e053efcc.diff

LOG: [mlir][linalg] Refactor convolution to img2col conversion to use gather semantics

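Following up on the comments in https://reviews.llvm.org/D144108, this
patch refactors the im2col conversion patterns for `linalg.conv_2d_nhwc_hwcf`
and `linalg.conv_2d_nchw_fchw` convolutions to use gather semantics for the
im2col packing `linalg.generic`: instead of reading the input through an
indexing map, the packing op now recovers the input coordinates from
`linalg.index` and reads each element with `tensor.extract`.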

Follow-up work can include a similar pattern for depthwise convolutions,
as well as a generalization of these patterns to work with any `LinalgOp`.

Differential Revision: https://reviews.llvm.org/D144678

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
    mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 14bff411ef8c1..58a23e2be54d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -41,6 +41,49 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
   return builder.create<arith::MulFOp>(loc, x, y);
 }
 
+// Unrolls the composite `index` into one subindex per entry of `factors`,
+// assuming that (1) `index` ranges over [0, f1 * f2 * ... * fn), the product
+// of the given factors, and (2) `factors` is ordered from the slowest to the
+// fastest varying iterator. Each subindex is then computed as
+//    subindex[i] = floor((index % (fi * ... * fn)) / (f(i+1) * ... * fn))
+// where the modulo is omitted for the outermost subindex (by assumption (1) it
+// is a no-op) and the division is omitted whenever its denominator is 1, as
+// is always the case for the innermost subindex.
+static SmallVector<Value, 3> unrollIndex(OpBuilder &b, Location loc,
+                                         Value index,
+                                         ArrayRef<int64_t> factors) {
+  assert(!factors.empty() && "empty factor list");
+  SmallVector<Value, 3> indices(factors.size());
+  // Product of the factors that vary faster than position `i`.
+  int64_t runningProd = 1;
+  for (int64_t i = static_cast<int64_t>(factors.size()) - 1; i >= 0; --i) {
+    Value unrolledIndex = index;
+    if (i > 0) {
+      Value modBase = b.create<arith::ConstantOp>(
+          loc, b.getIndexAttr(runningProd * factors[i]));
+      unrolledIndex = b.create<arith::RemUIOp>(loc, unrolledIndex, modBase);
+    }
+    if (runningProd > 1) {
+      Value divDenom =
+          b.create<arith::ConstantOp>(loc, b.getIndexAttr(runningProd));
+      unrolledIndex = b.create<arith::DivUIOp>(loc, unrolledIndex, divDenom);
+    }
+    runningProd *= factors[i];
+    indices[i] = unrolledIndex;
+  }
+  return indices;
+}
+
+// Returns the convolution-input index read for output iterator `oIndex` and
+// filter iterator `fIndex`, i.e. `oIndex * stride + fIndex`, with the stride
+// materialized as a constant and all ops emitted at `loc` through `b`.
+static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
+                               Value fIndex, int64_t stride) {
+  Value strideCst = b.create<arith::ConstantOp>(loc, b.getIndexAttr(stride));
+  Value scaledOIndex = b.create<arith::MulIOp>(loc, oIndex, strideCst);
+  return b.create<arith::AddIOp>(loc, scaledOIndex, fIndex);
+}
+
 FailureOr<std::pair<Operation *, Operation *>>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
@@ -68,32 +111,34 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   ArrayRef<int64_t> filterShape = filterType.getShape();
   ArrayRef<int64_t> outputShape = outputType.getShape();
 
-  int n = outputShape[0];
-  int oh = outputShape[1];
-  int ow = outputShape[2];
-  int oc = outputShape[3];
-  int fh = filterShape[0];
-  int fw = filterShape[1];
-  int ic = filterShape[2];
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[0];
+  int64_t fw = filterShape[1];
+  int64_t ic = filterShape[2];
 
   Location loc = convOp.getLoc();
 
-  SmallVector<int64_t> colTensorShape = {n, oh, ow, fh, fw, ic};
+  // Reshape the filter and output to the RHS and result of a (B)MNK matmul.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
 
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
   Value colTensor = rewriter.create<tensor::EmptyOp>(
       loc, colTensorShape, inputType.getElementType());
 
-  AffineExpr nDim, ohDim, owDim, khDim, kwDim, icDim;
-  bindDims(context, nDim, ohDim, owDim, khDim, kwDim, icDim);
-
-  AffineExpr shSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[0]);
-  AffineExpr swSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[1]);
-
-  SmallVector<AffineExpr> inputExprs = {nDim, ohDim * shSym + khDim,
-                                        owDim * swSym + kwDim, icDim};
-
+  // Convert the input to a (BMK) column tensor.
   auto nloops = colTensorShape.size();
 
   auto parallel = utils::IteratorType::parallel;
@@ -101,85 +146,68 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
 
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::get(nloops, 0, inputExprs, context),
       AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value, 3> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+
+        SmallVector<Value, 3> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
       });
 
-  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
-  SmallVector<ReassociationIndices> outputReassocIndices;
-  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
-  if (n == 1) {
-    img2ColTensorReassocIndices = {{0, 1, 2}, {3, 4, 5}};
-    outputReassocIndices = {{0, 1, 2}, {3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {oh * ow, fh * fw * ic}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({oh * ow, oc}, outputType.getElementType());
-  } else {
-    img2ColTensorReassocIndices = {{0}, {1, 2}, {3, 4, 5}};
-    outputReassocIndices = {{0}, {1, 2}, {3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {n, oh * ow, fh * fw * ic}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
-  }
-
-  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
-  auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
-
-  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
-      img2ColTensorReassocIndices);
-
-  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedFilterType, filter, filterReassocIndices);
-
-  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedOutputType, output, outputReassocIndices);
-
-  Value result;
-  if (n == 1) {
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, reshapedOutputType,
-        ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
-        ArrayRef<Value>{reshapedOutput});
-    result = matmulOp.getResults().front();
-  } else {
-    // For cases where batch is not 1, we need to keep the batch dimension
-    // separate. Because the filter does not share the same batch dimension,
-    // the batch dimension is only used in indexing the input and output. Thus
-    // we cannot use existing linalg named ops like linalg.batch_matmul.
-    // i.e. (B x) M x K * K x N = (B x) M x N
-    AffineExpr bDim, mDim, nDim, kDim;
-    bindDims(context, bDim, mDim, nDim, kDim);
-    auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
-    auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
-    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
-    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
-                                                         parallel, reduction};
-
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, reshapedOutputType,
-        /*inputs=*/ValueRange{reshapedImg2ColTensor, reshapedFilter},
-        /*outputs=*/ValueRange{reshapedOutput},
-        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
-        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
-          Value add = createAdd(loc, mul, args[2], nestedBuilder);
-          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
-        });
-    result = genericOp.getResults().front();
-  }
+  // Because the filter does not share the same batch dimension,
+  // the batch dimension is only used in indexing the input and output. Thus
+  // we cannot use existing linalg named ops like linalg.batch_matmul.
+  // i.e. (B x) M x K * K x N = (B x) M x N
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+        Value add = createAdd(loc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
 
   auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
       loc, outputType, result, outputReassocIndices);
@@ -367,33 +395,33 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto filterShape = filterType.getShape();
   auto outputShape = outputType.getShape();
 
-  int n = outputShape[0];
-  int oc = outputShape[1];
-  int oh = outputShape[2];
-  int ow = outputShape[3];
-  int ic = filterShape[1];
-  int fh = filterShape[2];
-  int fw = filterShape[3];
+  int64_t n = outputShape[0];
+  int64_t oc = outputShape[1];
+  int64_t oh = outputShape[2];
+  int64_t ow = outputShape[3];
+  int64_t ic = filterShape[1];
+  int64_t fh = filterShape[2];
+  int64_t fw = filterShape[3];
 
   auto loc = convOp.getLoc();
-
-  SmallVector<int64_t, 4> colTensorShape = {n, ic, fh, fw, oh, ow};
-
-  Value colTensor = rewriter.create<tensor::EmptyOp>(
-      loc, colTensorShape, inputType.getElementType());
-
   MLIRContext *context = rewriter.getContext();
 
-  AffineExpr nDim, icDim, khDim, kwDim, ohDim, owDim;
-  bindDims(context, nDim, icDim, khDim, kwDim, ohDim, owDim);
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
 
-  auto shSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[0]);
-  auto swSym = rewriter.getAffineConstantExpr(
-      convOp.getStrides().getValues<int64_t>()[1]);
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
+  auto reshapedOutputType =
+      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
 
-  SmallVector<AffineExpr, 4> inputExprs = {nDim, icDim, ohDim * shSym + khDim,
-                                           owDim * swSym + kwDim};
+  // Convert the input to a (BKN) tensor.
+  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
 
   auto nloops = colTensorShape.size();
 
@@ -402,83 +430,67 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
 
   SmallVector<AffineMap, 4> img2colIndexingMaps = {
-      AffineMap::get(nloops, 0, inputExprs, context),
       AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+        // Get the iterators named based on the matmul (batch, k, n).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value, 3> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
+        auto icIndex = kIndices[0];
+        auto fhIndex = kIndices[1];
+        auto fwIndex = kIndices[2];
+
+        SmallVector<Value, 3> nIndices = unrollIndex(
+            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = nIndices[0];
+        auto owIndex = nIndices[1];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
       });
 
-  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
-  auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
-  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedFilterType, filter, filterReassocIndices);
-
-  SmallVector<ReassociationIndices> img2ColTensorReassocIndices;
-  SmallVector<ReassociationIndices> outputReassocIndices;
-  RankedTensorType reshapedImg2ColTensorType, reshapedOutputType;
-  if (n == 1) {
-    img2ColTensorReassocIndices = {{0, 1, 2, 3}, {4, 5}};
-    outputReassocIndices = {{0, 1}, {2, 3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {fh * fw * ic, oh * ow}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({oc, oh * ow}, outputType.getElementType());
-  } else {
-    img2ColTensorReassocIndices = {{0}, {1, 2, 3}, {4, 5}};
-    outputReassocIndices = {{0}, {1}, {2, 3}};
-
-    reshapedImg2ColTensorType = RankedTensorType::get(
-        {n, fh * fw * ic, oh * ow}, inputType.getElementType());
-    reshapedOutputType =
-        RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
-  }
-
-  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
-      img2ColTensorReassocIndices);
-
-  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
-      loc, reshapedOutputType, output, outputReassocIndices);
-
-  Value result;
-  if (n == 1) {
-    auto matmulOp = rewriter.create<linalg::MatmulOp>(
-        loc, reshapedOutputType,
-        ArrayRef<Value>{reshapedFilter, reshapedImg2ColTensor},
-        ArrayRef<Value>{reshapedOutput});
-    result = matmulOp.getResults().front();
-  } else {
-    // For cases where batch is not 1, we need to keep the batch dimension
-    // separate. Because the filter does not share the same batch dimension,
-    // the batch dimension is only used in indexing the input and output. Thus
-    // we cannot use existing linalg named ops like linalg.batch_matmul.
-    // i.e. M x K * (B x) K x N = (B x) M x N
-    AffineExpr bDim, mDim, nDim, kDim;
-    bindDims(context, bDim, mDim, nDim, kDim);
-    auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
-    auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
-    auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
-    SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
-                                                         parallel, reduction};
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, reshapedOutputType,
-        /*inputs=*/ValueRange{reshapedFilter, reshapedImg2ColTensor},
-        /*outputs=*/ValueRange{reshapedOutput},
-        ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
-        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          Value mul = createMul(loc, args[0], args[1], nestedBuilder);
-          Value add = createAdd(loc, mul, args[2], nestedBuilder);
-          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
-        });
-    result = genericOp.getResults().front();
-  }
+  // Because the filter does not share the same batch dimension,
+  // the batch dimension is only used in indexing the input and output. Thus
+  // we cannot use existing linalg named ops like linalg.batch_matmul.
+  // i.e. M x K * (B x) K x N = (B x) M x N
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+        Value add = createAdd(loc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
 
   auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
       loc, outputType, result, outputReassocIndices);

diff  --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index e33e51ddababb..ffcba1086f3f6 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -29,36 +29,71 @@ transform.sequence failures(propagate) {
 
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>]
-// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
-// CHECK: linalg.yield %[[IN_DATA]] : f32
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
+
+// Unrolled output shape indices.
+// CHECK: %[[C14:.+]] = arith.constant 14 : index
+// CHECK: %[[OWINDEX:.+]] = arith.remui %[[MINDEX]], %[[C14]] : index
+// CHECK: %[[C14_1:.+]] = arith.constant 14 : index
+// CHECK: %[[OHINDEX:.+]] = arith.divui %[[MINDEX]], %[[C14_1]] : index
+
+// Unrolled filter shape indices.
+// CHECK: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[ICINDEX:.+]] = arith.remui %[[KINDEX]], %[[C4]] : index
+// CHECK: %[[C12:.+]] = arith.constant 12 : index
+// CHECK: %[[FWREM:.+]] = arith.remui %[[KINDEX]], %[[C12]] : index
+// CHECK: %[[C4_2:.+]] = arith.constant 4 : index
+// CHECK: %[[FWINDEX:.+]] = arith.divui %[[FWREM]], %[[C4_2]] : index
+// CHECK: %[[C12_3:.+]] = arith.constant 12 : index
+// CHECK: %[[FHINDEX:.+]] = arith.divui %[[KINDEX]], %[[C12_3]] : index
+
+// Compute input indices.
+// CHECK: %[[SH:.+]] = arith.constant 1 : index
+// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
+// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
+// CHECK: %[[SW:.+]] = arith.constant 1 : index
+// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
+// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
+// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
+// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
 
 // CHECK: IR printer: transformed
-// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //      CHECK: @conv_16433136
-//      CHECK: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
-//      CHECK: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
-//      CHECK: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
-//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x14x14x3x3x4xf32>
+//      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+//      CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+//      CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 //      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
 //           CHECK-SAME: #[[MAP0]]
+//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+//                CHECK: linalg.yield %{{.+}} : f32
+//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
 //           CHECK-SAME: #[[MAP1]]
-//                CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
-//                CHECK: linalg.yield %[[IN_DATA]] : f32
-//      CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
-//           CHECK-SAME: [0, 1, 2], [3, 4, 5]
-//           CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
-//      CHECK-DAG: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
-//           CHECK-SAME: [0, 1, 2], [3]
-//           CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
-//      CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]]
-//           CHECK-SAME: [0, 1, 2], [3]
-//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
-//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+//           CHECK-SAME: #[[MAP2]]
+//           CHECK-SAME: #[[MAP3]]
+//           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
+//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
+//                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+//                CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//                CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+//                CHECK:     linalg.yield %[[ADD]] : f32
+//                CHECK: } -> tensor<1x196x16xf32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 //      CHECK: return %[[RESULT]]
 
 func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
@@ -156,27 +191,24 @@ transform.sequence failures(propagate) {
 
 // -----
 
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d3, d2 + d4, d5)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 //  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 //  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 
 //      CHECK: func.func @batch_nhwc_conv
 // CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
-//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x14x14x3x3x4xf32>
+//  CHECK-DAG:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
+//  CHECK-DAG:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
 //      CHECK:   %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
-// CHECK-SAME:   outs(%[[IT]] : tensor<8x14x14x3x3x4xf32>)
-//      CHECK:   %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2], [3, 4, 5]] : tensor<8x14x14x3x3x4xf32> into tensor<8x196x36xf32>
-//      CHECK:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
-//      CHECK:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
+// CHECK-SAME:      indexing_maps = [#[[MAP]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:   outs(%[[IT]] : tensor<8x196x36xf32>)
 //      CHECK:   %[[MATMUL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:   ins(%[[CS_INPUT]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME:   ins(%[[IMG2COL]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
 // CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
 //      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
 //      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
@@ -201,27 +233,55 @@ transform.sequence failures(propagate) {
 
 // -----
 
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4 + d2, d5 + d3)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 //  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 //  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 
 //      CHECK: func.func @batch_nchw_conv
 // CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
-//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x4x3x3x14x14xf32>
+//  CHECK-DAG:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
+//  CHECK-DAG:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
 //      CHECK:   %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
-// CHECK-SAME:   outs(%[[IT]] : tensor<8x4x3x3x14x14xf32>)
-//      CHECK:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
-//      CHECK:   %[[CS_INPUT:.+]] = tensor.collapse_shape %[[IMG2COL]] {{\[}}[0], [1, 2, 3], [4, 5]] : tensor<8x4x3x3x14x14xf32> into tensor<8x36x196xf32>
-//      CHECK:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
+// CHECK-SAME:      indexing_maps = [#[[MAP]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:   outs(%[[IT]] : tensor<8x36x196xf32>)
+//      Collapsed indices.
+//      CHECK:       %[[BINDEX:.+]] = linalg.index 0 : index
+//      CHECK:       %[[KINDEX:.+]] = linalg.index 1 : index
+//      CHECK:       %[[NINDEX:.+]] = linalg.index 2 : index
+
+//      Unrolled filter shape indices.
+//      CHECK:       %[[C3:.+]] = arith.constant 3 : index
+//      CHECK:       %[[FWINDEX:.+]] = arith.remui %[[KINDEX]], %[[C3]] : index
+//      CHECK:       %[[C9:.+]] = arith.constant 9 : index
+//      CHECK:       %[[FHREM:.+]] = arith.remui %[[KINDEX]], %[[C9]] : index
+//      CHECK:       %[[C3_1:.+]] = arith.constant 3 : index
+//      CHECK:       %[[FHINDEX:.+]] = arith.divui %[[FHREM]], %[[C3_1]] : index
+//      CHECK:       %[[C9_2:.+]] = arith.constant 9 : index
+//      CHECK:       %[[ICINDEX:.+]] = arith.divui %[[KINDEX]], %[[C9_2]] : index
+
+//      Unrolled output shape indices.
+//      CHECK:       %[[C14:.+]] = arith.constant 14 : index
+//      CHECK:       %[[OWINDEX:.+]] = arith.remui %[[NINDEX]], %[[C14]] : index
+//      CHECK:       %[[C14_3:.+]] = arith.constant 14 : index
+//      CHECK:       %[[OHINDEX:.+]] = arith.divui %[[NINDEX]], %[[C14_3]] : index
+
+//      Compute input indices.
+//      CHECK:       %[[SH:.+]] = arith.constant 1 : index
+//      CHECK:       %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
+//      CHECK:       %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
+//      CHECK:       %[[SW:.+]] = arith.constant 1 : index
+//      CHECK:       %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
+//      CHECK:       %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+//      CHECK:       %[[EXTRACTED_INPUT:.+]] = tensor.extract
+//      CHECK-SAME:  %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
+//      CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
 //      CHECK:   %[[MATMUL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:   ins(%[[CS_FILTER]], %[[CS_INPUT]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
+// CHECK-SAME:   ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
 // CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
 //      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
 //      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32


        


More information about the Mlir-commits mailing list