[Mlir-commits] [mlir] [mlir][linalg] Propagate filter tensor encoding in im2col (PR #160099)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 26 01:51:14 PDT 2025
https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/160099
>From d73b6c1b39c464427425b56bbdd055c975f4d993 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 22 Sep 2025 14:05:17 +0100
Subject: [PATCH] [mlir][linalg] Propagate filter tensor encoding in im2col
In the im2col decomposition, propagate the filter tensor encoding
(if specified) through the tensor.collapse_shape op, so that it
can be used by the consuming linalg.generic matmul op.
Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Change-Id: I275d6fad0257d9813b9821341a6160144ae983e7
---
.../Dialect/Linalg/Transforms/Transforms.h | 1 +
.../Transforms/ConvertConv2DToImg2Col.cpp | 38 +++++++++-
.../Linalg/convert-conv2d-to-img2col.mlir | 72 ++++++++++++++++++-
3 files changed, 107 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 64d3a2448b409..e7ba9ab723604 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1858,6 +1858,7 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
+/// Note: currently limited to ops with pure tensor semantics.
/// \see rewriteInIm2Col for more details.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..ebc4dcf6bbcb5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include <cassert>
#include <utility>
namespace mlir {
@@ -124,6 +125,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+ if (!convOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected op to have pure tensor semantics");
+
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -155,10 +160,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
Location loc = convOp.getLoc();
+ assert(isa<RankedTensorType>(filterType) &&
+ "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
// Reshape output and filter to the LHS and result of a (B)MNK matmul.
SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
auto reshapedFilterType =
- RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+ RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
@@ -253,6 +263,10 @@ rewriteInIm2Col(RewriterBase &rewriter,
auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
+ if (!convOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected op to have pure tensor semantics");
+
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -404,6 +418,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+ if (!convOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected op to have pure tensor semantics");
+
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -435,9 +453,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto loc = convOp.getLoc();
MLIRContext *context = rewriter.getContext();
+ assert(isa<RankedTensorType>(filterType) &&
+ "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
- RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+ RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
@@ -529,6 +552,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+ if (!convOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected op to have pure tensor semantics");
+
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
convOp, "expected a static shape for the filter");
@@ -560,11 +587,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Location loc = convOp.getLoc();
+ assert(isa<RankedTensorType>(filterType) &&
+ "expected filter type to be a ranked tensor");
+ auto tensorFilterType = cast<RankedTensorType>(filterType);
+
// Reshape output and filter to the LHS and result of a "row-wise" matrix
// multiplication.
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
- RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+ RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+ tensorFilterType.getEncoding());
Value reshapedFilter = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..152a392afe247 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -26,6 +26,26 @@ module attributes {transform.with_named_sequence} {
// -----
+// Memref semantics are not supported by the img2col conversion.
+// Check that applying the transform to such an op emits an error.
+func.func @negative_conv_memref(%arg0: memref<1x16x16x4xf32>, %arg1: memref<16x3x3x4xf32>, %arg2: memref<1x14x14x16xf32>) {
+ // expected-note at below {{when applied to this op}}
+ linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : memref<2xi64>, strides = dense<1> : memref<2xi64> }
+ ins(%arg0, %arg1: memref<1x16x16x4xf32>, memref<16x3x3x4xf32>) outs(%arg2: memref<1x14x14x16xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error at below {{failed to apply}}
+ %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// Check that we get the proper handles for the img2col tensor producer
// and the final instruction.
@@ -267,6 +287,31 @@ module attributes {transform.with_named_sequence} {
// -----
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.
+
+// CHECK: func.func @batch_nchw_conv_with_filter_encoding(%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.*]]: tensor<16x4x3x3xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<8x16x14x14xf32>)
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+ // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x4x3x3xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSED_FILTER]], %[[COL_TENSOR]] : tensor<16x36xf32, 42 : i32>, tensor<8x36x196xf32>)
+func.func @batch_nchw_conv_with_filter_encoding(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32, 42 : i32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+ %0 = linalg.conv_2d_nchw_fchw
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32, 42 : i32>)
+ outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+ return %0 : tensor<8x16x14x14xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
@@ -290,7 +335,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +372,31 @@ module attributes {transform.with_named_sequence} {
// -----
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.
+
+// CHECK: func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%[[INPUT:.+]]: tensor<1x16x16x4xf32>, %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>)
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+ // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32, 42 : i32>)
+func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%input: tensor<1x16x16x4xf32>, %filter: tensor<16x3x3x4xf32, 42 : i32>, %out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc
+ { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%input, %filter: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+ outs(%out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// Check for signed extend when the input type is smaller than the accumulator type.
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
More information about the Mlir-commits
mailing list