[Mlir-commits] [mlir] [mlir][linalg] Produce canonical linalg.generic for im2col (PR #134675)
llvmlistbot at llvm.org
Wed Jul 16 09:27:05 PDT 2025
https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/134675
>From 63f7be0eba8f48374f8bbb01dba07a2a33026411 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 7 Apr 2025 16:22:08 +0100
Subject: [PATCH] [mlir][linalg] Produce canonical linalg.generic for im2col
Before this patch, the Im2Col transform produced a non-canonical
linalg.generic whose input tensor was not listed among the operation's
inputs: instead, it was read via a tensor.extract inside the op body,
after the access offsets had been computed there explicitly.
This patch modifies the Im2Col rewrite to produce a canonical
linalg.generic whose input is correctly reported in its 'ins()',
whose access offsets are computed through an indexing map,
and whose body contains only a 'linalg.yield' op.
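For illustration, a rough sketch of the rewritten im2col op for the NHWC
test case below (1x16x16x4 input, 3x3x4 filter, stride 1); the value names
are made up, but the maps, types and body match the expected CHECK lines:

  #map    = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
  #mapI2C = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  // The input is now an operand and the body is a plain yield;
  // all offset arithmetic lives in #map.
  %col = linalg.generic
           {indexing_maps = [#map, #mapI2C],
            iterator_types = ["parallel", "parallel", "parallel"]}
           ins(%input : tensor<1x16x16x4xf32>)
           outs(%init : tensor<1x196x36xf32>) {
         ^bb0(%in: f32, %out: f32):
           linalg.yield %in : f32
         } -> tensor<1x196x36xf32>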
Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Change-Id: I7bab2cb56d6d5d754de4c869dee1df2eae1b5029
---
.../Transforms/ConvertConv2DToImg2Col.cpp | 255 ++++++++++--------
.../Linalg/convert-conv2d-to-img2col.mlir | 126 ++++-----
2 files changed, 184 insertions(+), 197 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 81d44ba04fa1d..1e8a37d9c64d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -51,28 +52,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
}
-// Delinearizes the given composite `index` by the basis specified in `factors`.
-static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
- ArrayRef<int64_t> factors) {
- assert(!factors.empty() && "empty factor list");
- SmallVector<Value> basis;
- for (int64_t f : factors)
- basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
- FailureOr<SmallVector<Value>> multiIndex =
- affine::delinearizeIndex(b, loc, index, basis);
- assert(!failed(multiIndex) && "Failed to linearize img2col index");
- return *multiIndex;
+// Generate the affine expression to compute the convolved index
+// for the input as `oIndex * stride + fIndex`,
+// where oIndex: output iterator; fIndex: filter iterator.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+ bool useSymbols = true) {
+ AffineExpr oExpr, fExpr;
+ if (useSymbols)
+ bindSymbols(b.getContext(), oExpr, fExpr);
+ else
+ bindDims(b.getContext(), oExpr, fExpr);
+ return AffineExpr(stride * oExpr + fExpr);
}
-// Given indices corresponding to iterators in the output (oIndex) and filter
-// (fIndex) for a convolution, compute the convolved index for the
-// input as `oIndex * stride + fIndex`.
-static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
- Value fIndex, int64_t stride) {
- AffineExpr oExpr, fExpr;
- bindSymbols(b.getContext(), oExpr, fExpr);
- AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
- return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the corresponding indices of the output and filter matrices
+struct Im2ColToOperandsExprs {
+ AffineExpr fhIndex;
+ AffineExpr fwIndex;
+ AffineExpr icIndex;
+ AffineExpr ohIndex;
+ AffineExpr owIndex;
+};
+
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the input matrix indices
+struct Im2ColToInputDimsExprs {
+ AffineExpr bIndex;
+ AffineExpr hIndex;
+ AffineExpr wIndex;
+ AffineExpr cIndex;
+};
+
+/// Construct the affine expressions that map the indices of the im2col matrix
+/// to the corresponding input tensor indices for a 2D convolution with the
+/// provided strides.
+///
+/// @param exprs Affine expressions for output and filter indices.
+/// @param strides [height, width] stride values for the convolution.
+/// @param rewriter Pattern rewriter.
+/// @return Affine expressions mapping im2col matrix indices to input
+/// offsets.
+static Im2ColToInputDimsExprs getIm2ColInputMap(Im2ColToOperandsExprs exprs,
+ ArrayRef<int64_t> strides,
+ RewriterBase &rewriter) {
+ // maps the iteration space of the im2col matrix to (output_y, filter_y)
+ auto hIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
+ // maps the iteration space of the im2col matrix to (output_x, filter_x)
+ auto wIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
+ // Compute the input indexing map, to map the indices of the im2col matrix to
+ // the original input offsets. Each element of the im2col matrix corresponds
+ // to a pair of (out_element, filter_element). First, we build the expressions
+ // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
+ // then we compose them with the maps that map the im2col matrix elements to
+ // the (out_element, filter_element) pairs.
+ auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+ auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
+ /*useSymbols*/ false);
+ hIndexExpr = hIndexExpr.compose(hIndicesMap);
+ auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
+ /*useSymbols*/ false);
+ wIndexExpr = wIndexExpr.compose(wIndicesMap);
+ auto cIndexExpr = exprs.icIndex;
+ return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}
FailureOr<std::pair<Operation *, Operation *>>
@@ -136,44 +180,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputMap(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
+
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = rewriter.create<linalg::GenericOp>(
loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
- Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
- Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
- loc, input, extractionIndices);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -421,44 +457,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
- SmallVector<AffineMap, 4> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ // Recover the original iteration indices from the problem/input sizes:
+ // given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
+ ArrayRef<int64_t>{fh * fw, fw, 1});
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.icIndex = kIndicesExprs[0];
+ i2cToOperExprs.fhIndex = kIndicesExprs[1];
+ i2cToOperExprs.fwIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputMap(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
+ inExprs.hIndex, inExprs.wIndex}},
+ rewriter.getContext())[0];
+ // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+ SmallVector<AffineMap> img2colIndexingMaps = {
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = rewriter.create<linalg::GenericOp>(
loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
- Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
- Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
- auto icIndex = kIndices[0];
- auto fhIndex = kIndices[1];
- auto fwIndex = kIndices[2];
-
- SmallVector<Value> nIndices = unrollIndex(
- nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = nIndices[0];
- auto owIndex = nIndices[1];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
- SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
- Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
- loc, input, extractionIndices);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -545,6 +573,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedOutputType, output, outputReassocIndices);
+ // Shape of the Toeplitz matrix produced by Im2col.
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
Value colTensor = rewriter.create<tensor::EmptyOp>(
loc, colTensorShape, inputType.getElementType());
@@ -556,44 +585,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputMap(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = rewriter.create<linalg::GenericOp>(
loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
- Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
- Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
- loc, input, extractionIndices);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
});
// Because we didn't transpose the filters we don't actually have a batched
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index c17f20b2d03ab..8627fcd2576b9 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -34,40 +34,35 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
-
// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Im2col maps
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Matmul maps
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK: @conv_16433136
-// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
-// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
-// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+
+// CHECK: @conv_16433136
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
-// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
-// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP0]]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-// CHECK: linalg.yield %{{.+}} : f32
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAPI2C]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+// CHECK-SAME: outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<1x196x36xf32>
+
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
@@ -180,7 +175,10 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Im2col maps
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -191,9 +189,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAPI2C]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
// CHECK-SAME: outs(%[[IT]] : tensor<8x196x36xf32>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<8x196x36xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
@@ -224,13 +226,9 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
// Im2col maps
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 floordiv 9)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0, s1] -> (s0 floordiv 14 + (s1 mod 9) floordiv 3)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<()[s0, s1] -> (s0 + s1 - (s0 floordiv 14) * 14 - (s1 floordiv 3) * 3)>
-
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 9, d2 floordiv 14 + (d1 mod 9) floordiv 3, d2 mod 14 + d1 mod 3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
@@ -242,32 +240,12 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
// CHECK-SAME: outs(%[[IT]] : tensor<8x36x196xf32>)
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]]()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]]()[%[[NINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]]()[%[[NINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
-// CHECK: %[[MATMUL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
-// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
-// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
-// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
-// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
-// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<8x16x196xf32>
// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
// CHECK: return %[[CS_FINAL]]
@@ -291,31 +269,19 @@ module attributes {transform.with_named_sequence} {
// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_2d_nhwc_fhwc
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
@@ -324,13 +290,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
+// CHECK-SAME: #[[MAP4]]
// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)