[Mlir-commits] [mlir] [mlir][linalg] Produce canonical linalg.generic for im2col (PR #134675)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 7 08:56:13 PDT 2025


https://github.com/fabrizio-indirli created https://github.com/llvm/llvm-project/pull/134675

Before this patch, the Im2Col transform produced a non-canonical linalg.generic whose input tensor was not listed among the operation's inputs: instead, the op body read it directly with tensor.extract, after computing the access offsets itself. This patch modifies the Im2Col rewrite to produce a canonical linalg.generic whose input is listed in its `ins()`, whose access offsets are computed through an indexing map, and whose body contains only a `linalg.yield` op.

>From 51960f1766d430e564446e1b35413cd4f3e7232a Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 7 Apr 2025 16:22:08 +0100
Subject: [PATCH] [mlir][linalg] Produce canonical linalg.generic for im2col

Before this patch, the Im2Col transform produced a non-canonical
linalg.generic whose input tensor was not listed among the
operation's inputs: instead, the op body read it directly with
tensor.extract, after computing the access offsets itself.
This patch modifies the Im2Col rewrite to produce a canonical
linalg.generic whose input is listed in its `ins()`, whose access
offsets are computed through an indexing map, and whose body
contains only a `linalg.yield` op.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
---
 .../Transforms/ConvertConv2DToImg2Col.cpp     | 76 +++++++++++--------
 .../Linalg/convert-conv2d-to-img2col.mlir     | 32 +++-----
 2 files changed, 53 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 81d44ba04fa1d..4999e8bc4ecae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -64,6 +65,19 @@ static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
   return *multiIndex;
 }
 
+// Generate the affine expression `oIndex * stride + fIndex` for the convolved
+// input index (oIndex: output iterator, fIndex: filter iterator). The operands
+// are bound to symbols (s0, s1) if `useSymbols` is set, else to dims (d0, d1).
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+                                   bool useSymbols = true) {
+  AffineExpr oExpr, fExpr;
+  if (useSymbols)
+    bindSymbols(b.getContext(), oExpr, fExpr);
+  else
+    bindDims(b.getContext(), oExpr, fExpr);
+  return stride * oExpr + fExpr;
+}
+
 // Given indices corresponding to iterators in the output (oIndex) and filter
 // (fIndex) for a convolution, compute the convolved index for the
 // input as `oIndex * stride + fIndex`.
@@ -71,7 +85,5 @@ static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
                                Value fIndex, int64_t stride) {
-  AffineExpr oExpr, fExpr;
-  bindSymbols(b.getContext(), oExpr, fExpr);
-  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
+  AffineMap convMap = AffineMap::get(0, 2, getConvolvedExpr(b, stride));
   return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
 }
 
@@ -556,44 +570,40 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
 
+  // Recover the original iteration indices from the problem/input sizes:
+  // m (of size oh * ow) splits into (oh, ow) and k (of size fh * fw * ic)
+  // splits into (fh, fw, ic); `delinearize` takes the row-major strides.
+  SmallVector<AffineExpr> mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+  SmallVector<AffineExpr> kIndicesExprs = delinearize(
+      rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{fw * ic, ic, 1});
+
+  // Compute the input indexing map, which maps the (batch, m, k) iteration
+  // space of the im2col tensor to the input offsets:
+  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+  int64_t strideH = convOp.getStrides().getValues<int64_t>()[0];
+  int64_t strideW = convOp.getStrides().getValues<int64_t>()[1];
+  AffineExpr bIndexExpr = rewriter.getAffineDimExpr(0U);
+  AffineExpr hIndexExpr =
+      getConvolvedExpr(rewriter, strideH, /*useSymbols=*/false)
+          .replaceDims({mIndicesExprs[0], kIndicesExprs[0]});
+  AffineExpr wIndexExpr =
+      getConvolvedExpr(rewriter, strideW, /*useSymbols=*/false)
+          .replaceDims({mIndicesExprs[1], kIndicesExprs[1]});
+  AffineExpr cIndexExpr = kIndicesExprs[2];
+  // Build the map over all `nloops` dims rather than inferring the dim count.
+  AffineMap inMap = AffineMap::get(
+      nloops, 0, {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr}, context);
+
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
-        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
-        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> mIndices = unrollIndex(
-            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = mIndices[0];
-        auto owIndex = mIndices[1];
-
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
-        auto fhIndex = kIndices[0];
-        auto fwIndex = kIndices[1];
-        auto icIndex = kIndices[2];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
-        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
-        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
-            loc, input, extractionIndices);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
       });
 
   // Because we didn't transpose the filters we don't actually have a batched
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index c17f20b2d03ab..80ff27da430bf 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -291,31 +291,19 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK-NEXT: linalg.yield %[[IN_DATA]] : f32
 
 // CHECK: IR printer: transformed
 // CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //      CHECK: @conv_2d_nhwc_fhwc
 //      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
 //      CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
@@ -324,13 +312,15 @@ module attributes {transform.with_named_sequence} {
//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: #[[MAP0]]
+//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+//           CHECK-SAME: ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+//           CHECK-SAME: outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
-//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-//                CHECK: linalg.yield %{{.+}} : f32
+//                CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+//           CHECK-NEXT: linalg.yield %[[IN_DATA]] : f32
//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
-//           CHECK-SAME: #[[MAP1]]
//           CHECK-SAME: #[[MAP2]]
//           CHECK-SAME: #[[MAP3]]
+//           CHECK-SAME: #[[MAP4]]
//           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
//                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)



More information about the Mlir-commits mailing list