[Mlir-commits] [mlir] [mlir][linalg] Produce canonical linalg.generic for im2col (PR #134675)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 16 09:24:16 PDT 2025


https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/134675

>From 001d3e7e73bffafa3a92da73fed6a8644c0d5b22 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 7 Apr 2025 16:22:08 +0100
Subject: [PATCH] [mlir][linalg] Produce canonical linalg.generic for im2col

Before this patch, the Im2Col transform produced a non-canonical
linalg.generic whose input tensor was not listed among the operation's
inputs: instead, the op body read it directly with tensor.extract, after
computing the access offsets internally with affine.apply ops.
This patch modifies the Im2Col rewrite to produce a canonical
linalg.generic whose input is correctly listed in its 'ins()',
whose access offsets are computed through an indexing map,
and whose body contains only a 'linalg.yield' op.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Change-Id: I7bab2cb56d6d5d754de4c869dee1df2eae1b5029
---
 .../Transforms/ConvertConv2DToImg2Col.cpp     | 241 +++++++++++-------
 .../Linalg/convert-conv2d-to-img2col.mlir     | 126 ++++-----
 2 files changed, 189 insertions(+), 178 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 81d44ba04fa1d..6b34b498f409a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -64,6 +65,19 @@ static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
   return *multiIndex;
 }
 
+/// Returns `stride * oIndex + fIndex`, the input index convolved from an
+/// output iterator (oIndex) and a filter iterator (fIndex). The operands are
+/// bound to symbols (s0, s1) when `useSymbols` is set, to dims otherwise.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+                                   bool useSymbols = true) {
+  AffineExpr outExpr, filterExpr;
+  if (!useSymbols)
+    bindDims(b.getContext(), outExpr, filterExpr);
+  else
+    bindSymbols(b.getContext(), outExpr, filterExpr);
+  return outExpr * stride + filterExpr;
+}
+
 // Given indices corresponding to iterators in the output (oIndex) and filter
 // (fIndex) for a convolution, compute the convolved index for the
 // input as `oIndex * stride + fIndex`.
@@ -71,10 +85,63 @@ static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
                                Value fIndex, int64_t stride) {
-  AffineExpr oExpr, fExpr;
-  bindSymbols(b.getContext(), oExpr, fExpr);
-  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
+  AffineMap convMap = AffineMap::get(0, 2, getConvolvedExpr(b, stride));
   return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
 }

+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the corresponding indices of the output and filter matrices.
+struct Im2ColToOperandsExprs {
+  AffineExpr fhIndex;
+  AffineExpr fwIndex;
+  AffineExpr icIndex;
+  AffineExpr ohIndex;
+  AffineExpr owIndex;
+};
+
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the input matrix indices.
+struct Im2ColToInputDimsExprs {
+  AffineExpr bIndex;
+  AffineExpr hIndex;
+  AffineExpr wIndex;
+  AffineExpr cIndex;
+};
+
+/// Construct the affine expressions that map the indices of the im2col matrix
+/// to the corresponding input tensor indices for a 2D convolution with the
+/// provided strides.
+///
+/// @param exprs      Affine expressions for output and filter indices.
+/// @param strides    [height, width] stride values for the convolution.
+/// @param rewriter   Rewriter used to obtain the context and dim expressions.
+/// @return           Affine expressions mapping im2col matrix indices to input
+/// offsets.
+static Im2ColToInputDimsExprs getIm2ColInputMap(Im2ColToOperandsExprs exprs,
+                                                ArrayRef<int64_t> strides,
+                                                RewriterBase &rewriter) {
+  assert(strides.size() == 2 && "expected [height, width] strides");
+  // maps the iteration space of the im2col matrix to (output_y, filter_y)
+  auto hIndicesMap = AffineMap::inferFromExprList(
+      {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
+  // maps the iteration space of the im2col matrix to (output_x, filter_x)
+  auto wIndicesMap = AffineMap::inferFromExprList(
+      {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
+  // Compute the input indexing map, to map the indices of the im2col matrix to
+  // the original input offsets. Each element of the im2col matrix corresponds
+  // to a pair of (out_element, filter_element). First, we build the expressions
+  // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
+  // then we compose them with the maps that map the im2col matrix elements to
+  // the (out_element, filter_element) pairs.
+  auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+  auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
+                                     /*useSymbols=*/false);
+  hIndexExpr = hIndexExpr.compose(hIndicesMap);
+  auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
+                                     /*useSymbols=*/false);
+  wIndexExpr = wIndexExpr.compose(wIndicesMap);
+  auto cIndexExpr = exprs.icIndex;
+  return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
+}
+
 FailureOr<std::pair<Operation *, Operation *>>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
@@ -136,44 +204,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
 
+  // Given an index of the im2col matrix, retrieve the corresponding indices of
+  // the output and filter matrices
+  auto mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+                                   ArrayRef<int64_t>{fw * ic, ic, 1});
+  Im2ColToOperandsExprs i2cToOperExprs;
+  i2cToOperExprs.fhIndex = kIndicesExprs[0];
+  i2cToOperExprs.fwIndex = kIndicesExprs[1];
+  i2cToOperExprs.icIndex = kIndicesExprs[2];
+  i2cToOperExprs.ohIndex = mIndicesExprs[0];
+  i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+  Im2ColToInputDimsExprs inExprs = getIm2ColInputMap(
+      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+      rewriter);
+  auto inMap =
+      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+                                             inExprs.wIndex, inExprs.cIndex}},
+                                   rewriter.getContext())[0];
+
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
-        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
-        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> mIndices = unrollIndex(
-            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = mIndices[0];
-        auto owIndex = mIndices[1];
-
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
-        auto fhIndex = kIndices[0];
-        auto fwIndex = kIndices[1];
-        auto icIndex = kIndices[2];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
-        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
-        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
-            loc, input, extractionIndices);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
       });
 
   // Because the filter does not share the same batch dimension,
@@ -421,44 +481,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
 
-  SmallVector<AffineMap, 4> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+  // Recover the original iteration indices from the problem/input sizes:
+  // given an index of the im2col matrix, retrieve the corresponding indices of
+  // the output and filter matrices
+  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
+                                   ArrayRef<int64_t>{fh * fw, fw, 1});
+  auto mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
+  Im2ColToOperandsExprs i2cToOperExprs;
+  i2cToOperExprs.icIndex = kIndicesExprs[0];
+  i2cToOperExprs.fhIndex = kIndicesExprs[1];
+  i2cToOperExprs.fwIndex = kIndicesExprs[2];
+  i2cToOperExprs.ohIndex = mIndicesExprs[0];
+  i2cToOperExprs.owIndex = mIndicesExprs[1];
+  Im2ColToInputDimsExprs inExprs = getIm2ColInputMap(
+      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+      rewriter);
+  auto inMap =
+      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
+                                             inExprs.hIndex, inExprs.wIndex}},
+                                   rewriter.getContext())[0];
+  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
-        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
-        Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
-        auto icIndex = kIndices[0];
-        auto fhIndex = kIndices[1];
-        auto fwIndex = kIndices[2];
-
-        SmallVector<Value> nIndices = unrollIndex(
-            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = nIndices[0];
-        auto owIndex = nIndices[1];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
-        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
-        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
-            loc, input, extractionIndices);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
       });
 
   // Because the filter does not share the same batch dimension,
@@ -545,6 +597,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
       loc, reshapedOutputType, output, outputReassocIndices);
 
+  // Shape of the Toeplitz matrix produced by Im2col.
   SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
   Value colTensor = rewriter.create<tensor::EmptyOp>(
       loc, colTensorShape, inputType.getElementType());
@@ -556,44 +609,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
 
+  // Given an index of the im2col matrix, retrieve the corresponding indices of
+  // the output and filter matrices
+  auto mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+                                   ArrayRef<int64_t>{fw * ic, ic, 1});
+  Im2ColToOperandsExprs i2cToOperExprs;
+  i2cToOperExprs.fhIndex = kIndicesExprs[0];
+  i2cToOperExprs.fwIndex = kIndicesExprs[1];
+  i2cToOperExprs.icIndex = kIndicesExprs[2];
+  i2cToOperExprs.ohIndex = mIndicesExprs[0];
+  i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+  Im2ColToInputDimsExprs inExprs = getIm2ColInputMap(
+      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+      rewriter);
+  auto inMap =
+      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+                                             inExprs.wIndex, inExprs.cIndex}},
+                                   rewriter.getContext())[0];
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
-        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
-        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> mIndices = unrollIndex(
-            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = mIndices[0];
-        auto owIndex = mIndices[1];
-
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
-        auto fhIndex = kIndices[0];
-        auto fwIndex = kIndices[1];
-        auto icIndex = kIndices[2];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
-        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
-        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
-            loc, input, extractionIndices);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
       });
 
   // Because we didn't transpose the filters we don't actually have a batched
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index c17f20b2d03ab..8627fcd2576b9 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -34,40 +34,35 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
 // CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
-
 // CHECK: IR printer: transformed
 // CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Im2col maps
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Matmul maps
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-//      CHECK: @conv_16433136
-//      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
-//      CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
-//      CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+
+//  CHECK: @conv_16433136
+//  CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+//  CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+//  CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
 //  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
 //  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
-//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
-//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: #[[MAP0]]
-//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-//                CHECK: linalg.yield %{{.+}} : f32
-//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+//  CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+
+//  CHECK:   %[[COL_TENSOR:.+]] = linalg.generic
+//  CHECK-SAME:      indexing_maps = [#[[MAP]], #[[MAPI2C]]]
+//  CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
+//  CHECK-SAME:   ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+//  CHECK-SAME:   outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+//  CHECK:         ^bb0(%[[IN:.+]]: f32, %out: f32):
+//  CHECK:          linalg.yield %[[IN]] : f32
+//  CHECK:   } -> tensor<1x196x36xf32>
+
+//  CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
 //           CHECK-SAME: #[[MAP1]]
 //           CHECK-SAME: #[[MAP2]]
 //           CHECK-SAME: #[[MAP3]]
@@ -180,7 +175,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  Im2col maps
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+//  CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
 //  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 //  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 //  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -191,9 +189,13 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
 //      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
 //      CHECK:   %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[MAP]]]
+// CHECK-SAME:      indexing_maps = [#[[MAP]], #[[MAPI2C]]]
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
 // CHECK-SAME:   outs(%[[IT]] : tensor<8x196x36xf32>)
+// CHECK:         ^bb0(%[[IN:.+]]: f32, %out: f32):
+//      CHECK:     linalg.yield %[[IN]] : f32
+//      CHECK:   } -> tensor<8x196x36xf32>
 //      CHECK:   %[[MATMUL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
@@ -224,13 +226,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
 //  Im2col maps
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 floordiv 9)>
-//  CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0, s1] -> (s0 floordiv 14 + (s1 mod 9) floordiv 3)>
-//  CHECK-DAG: #[[MAP8:.+]] = affine_map<()[s0, s1] -> (s0 + s1 - (s0 floordiv 14) * 14 - (s1 floordiv 3) * 3)>
-
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 9, d2 floordiv 14 + (d1 mod 9) floordiv 3, d2 mod 14 + d1 mod 3)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
 //  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 //  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
@@ -242,32 +240,22 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
 //      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
 //      CHECK:   %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[MAP]]]
+// CHECK-SAME:      indexing_maps = [#[[MAP]], #[[MAP1]]]
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:   ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
 // CHECK-SAME:   outs(%[[IT]] : tensor<8x36x196xf32>)
-//      Collapsed indices.
-//      CHECK:       %[[BINDEX:.+]] = linalg.index 0 : index
-//      CHECK:       %[[KINDEX:.+]] = linalg.index 1 : index
-//      CHECK:       %[[NINDEX:.+]] = linalg.index 2 : index
-
-//      Compute input channel/convolved indices.
-//      CHECK:       %[[ICINDEX:.+]] = affine.apply #[[MAP1]]()[%[[KINDEX]]]
-//      CHECK:       %[[CONVH:.+]] = affine.apply #[[MAP7]]()[%[[NINDEX]], %[[KINDEX]]]
-//      CHECK:       %[[CONVW:.+]] = affine.apply #[[MAP8]]()[%[[NINDEX]], %[[KINDEX]]]
-
-//      Extract from the input tensor.
-//      CHECK:       %[[EXTRACTED_INPUT:.+]] = tensor.extract
-//      CHECK-SAME:  %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
-//      CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+// CHECK:         ^bb0(%[[IN:.+]]: f32, %out: f32):
+//      CHECK:     linalg.yield %[[IN]] : f32
+//      CHECK:   } -> tensor<8x36x196xf32>
 //      CHECK:   %[[MATMUL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
 // CHECK-SAME:   ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
 // CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
 //      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
 //      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
 //      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
 //      CHECK:     linalg.yield %[[ADD]] : f32
 //      CHECK:   } -> tensor<8x16x196xf32>
 //      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
 //      CHECK:   return %[[CS_FINAL]]
@@ -291,31 +269,19 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+//     CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+//     CHECK: linalg.yield %[[IN_DATA]] : f32
 
 // CHECK: IR printer: transformed
 // CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //      CHECK: @conv_2d_nhwc_fhwc
 //      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
 //      CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
@@ -324,13 +290,13 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 //      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 //      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: #[[MAP0]]
+//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
-//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-//                CHECK: linalg.yield %{{.+}} : f32
+//                CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+//                CHECK: linalg.yield %[[IN_DATA]] : f32
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
-//           CHECK-SAME: #[[MAP1]]
 //           CHECK-SAME: #[[MAP2]]
 //           CHECK-SAME: #[[MAP3]]
+//           CHECK-SAME: #[[MAP4]]
 //           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
 //           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
 //                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)



More information about the Mlir-commits mailing list