[Mlir-commits] [mlir] 5d51e00 - [mlir][linalg] Propagate filter tensor encoding in im2col (#160099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 26 02:33:16 PDT 2025


Author: fabrizio-indirli
Date: 2025-09-26T10:33:13+01:00
New Revision: 5d51e006caa6a2ca6c9944200ca69bcb1c38576a

URL: https://github.com/llvm/llvm-project/commit/5d51e006caa6a2ca6c9944200ca69bcb1c38576a
DIFF: https://github.com/llvm/llvm-project/commit/5d51e006caa6a2ca6c9944200ca69bcb1c38576a.diff

LOG: [mlir][linalg] Propagate filter tensor encoding in im2col (#160099)

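In the im2col decomposition, propagate the filter tensor encoding (if
specified) through the tensor.collapse_shape op, so that it can be used
by the consuming linalg.generic matmul op. The conv2d-to-img2col
patterns now also report a match failure for convolution ops that do
not have pure tensor semantics (e.g. ops on memrefs), since the rewrite
is limited to tensors.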

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
    mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 41670249936e6..7266687584b38 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1858,6 +1858,7 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
 
 /// Populates patterns to transform linalg.conv_2d_xxx operations into
 /// linalg.generic (for img2col packing) and linalg.matmul.
+/// Note: currently limited to Tensor semantics only.
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..ebc4dcf6bbcb5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include <cassert>
 #include <utility>
 
 namespace mlir {
@@ -124,6 +125,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -155,10 +160,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a (B)MNK matmul.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
   auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -253,6 +263,10 @@ rewriteInIm2Col(RewriterBase &rewriter,
   auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
   auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -404,6 +418,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -435,9 +453,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto loc = convOp.getLoc();
   MLIRContext *context = rewriter.getContext();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -529,6 +552,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -560,11 +587,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a "row-wise" matrix
   // multiplication.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 

diff  --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..152a392afe247 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -26,6 +26,26 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Memref semantics is not supported.
+// Check that we emit an error.
+func.func @negative_conv_memref(%arg0: memref<1x16x16x4xf32>, %arg1: memref<16x3x3x4xf32>, %arg2: memref<1x14x14x16xf32>) {
+    // expected-note at below {{when applied to this op}}
+    linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : memref<2xi64>, strides = dense<1> : memref<2xi64> }
+       ins(%arg0, %arg1: memref<1x16x16x4xf32>, memref<16x3x3x4xf32>) outs(%arg2: memref<1x14x14x16xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error at below {{failed to apply}}
+    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we get the proper handles for the img2col tensor producer
 // and the final instruction.
 
@@ -267,6 +287,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform. 
+
+// CHECK: func.func @batch_nchw_conv_with_filter_encoding(%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.*]]: tensor<16x4x3x3xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<8x16x14x14xf32>)
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+  // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x4x3x3xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+//  CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+//  CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSED_FILTER]], %[[COL_TENSOR]] : tensor<16x36xf32, 42 : i32>, tensor<8x36x196xf32>)
+func.func @batch_nchw_conv_with_filter_encoding(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32, 42 : i32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+    %0 = linalg.conv_2d_nchw_fchw
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32, 42 : i32>)
+      outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+    return %0 : tensor<8x16x14x14xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
@@ -290,7 +335,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 //      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 //      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
 //                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 //                CHECK: linalg.yield %{{.+}} : f32
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +372,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform. 
+
+// CHECK: func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%[[INPUT:.+]]: tensor<1x16x16x4xf32>, %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>)
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+  // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+//  CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+//  CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32, 42 : i32>)
+func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%input: tensor<1x16x16x4xf32>, %filter: tensor<16x3x3x4xf32, 42 : i32>, %out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_fhwc
+      { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%input, %filter: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+      outs(%out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>


        


More information about the Mlir-commits mailing list