[Mlir-commits] [mlir] [mlir][linalg] Add NHWC + FHWC Img2Col (PR #68708)

Quinn Dawkins llvmlistbot at llvm.org
Wed Oct 11 10:56:04 PDT 2023


================
@@ -494,6 +494,140 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
                         reshapedResult.getOperation());
 }
 
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[1];
+  int64_t fw = filterShape[2];
+  int64_t ic = filterShape[3];
+
+  Location loc = convOp.getLoc();
+
+  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  // Convert the input to a (BMK) column tensor.
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+      });
+
+  // Because we didn't transpose the filters we don't actually have a batched
+  // matrix multiply. Instead, we have an operation consisting of "rowise" dot
----------------
qedawkins wrote:

nit: typo — "rowise" should be spelled "row-wise".

https://github.com/llvm/llvm-project/pull/68708


More information about the Mlir-commits mailing list