[Mlir-commits] [mlir] [mlir][linalg] Propagate filter tensor encoding in im2col (PR #160099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 22 06:21:40 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (fabrizio-indirli)

<details>
<summary>Changes</summary>

In the im2col decomposition, propagate the filter tensor's encoding (if it has one) through the tensor.collapse_shape op, so that the consuming linalg.generic matmul op can use it. As part of this change, the rewrite now reports a match failure if the filter is not a ranked tensor.

---
Full diff: https://github.com/llvm/llvm-project/pull/160099.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+21-3) 
- (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+29-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 108abe800b13e..12e2b6f5c3f0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -155,10 +155,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a (B)MNK matmul.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
   auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -435,9 +441,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto loc = convOp.getLoc();
   MLIRContext *context = rewriter.getContext();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -560,11 +572,17 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  if (!isa<RankedTensorType>(filterType))
+    return rewriter.notifyMatchFailure(
+        convOp, "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a "row-wise" matrix
   // multiplication.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 8627fcd2576b9..af911e3b3919a 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -290,7 +290,7 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 //      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 //      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
 //                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 //                CHECK: linalg.yield %{{.+}} : f32
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +327,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK: func.func @conv2d_decompose_im2col_with_filter_encoding
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>
+//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+  // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+//  CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+//  CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+//  CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+func.func @conv2d_decompose_im2col_with_filter_encoding(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32, 42 : i32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+    %0 = linalg.conv_2d_nhwc_fhwc
+      { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+    return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

``````````

</details>


https://github.com/llvm/llvm-project/pull/160099


More information about the Mlir-commits mailing list