[Mlir-commits] [mlir] [mlir][Linalg] Preserve encodings in static shape inference. (PR #132311)

Han-Chung Wang llvmlistbot at llvm.org
Fri Mar 21 10:26:14 PDT 2025


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/132311

>From 24d6062913128a85c388273629a64b1b0326049d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 20 Mar 2025 16:54:56 -0700
Subject: [PATCH 1/2] [mlir][Linalg] Preserve encodings in static shape
 inference.

Previously, tensor encodings were unconditionally dropped during static
shape inference. This revision preserves the encoding of each operand's
tensor type when linalg ops are rewritten with inferred static shapes.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   |  3 ++-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 23 ++++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..275c107cd70f8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2539,7 +2539,8 @@ static void createNewOperandWithStaticSizes(
     newShape.push_back(affineExprToSize[dimExpr]);
     newOperandNeeded = true;
   }
-  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
+                                     sourceType.getEncoding());
   if (newOperandNeeded) {
     changeNeeded = true;
     // Get the new operand value given its size and element type by
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..103ec55dfa441 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -649,6 +649,29 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-DAG:   #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func @static_shape_inference_with_encoding(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
+  %0 = tensor.empty() : tensor<3x4xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<3x4xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %in, %in_0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+    //  CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
+    //  CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
+    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+    //  CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
+    //  CHECK-SAME: outs({{.*}} : tensor<3x4xf32>)
+}
+
+// -----
+
 //       CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
 // CHECK-LABEL: func @insert_pad_into_fill
 //  CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)

>From 9acf0324865b15e122f47ad9076efd37adc59f4d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 21 Mar 2025 10:25:51 -0700
Subject: [PATCH 2/2] Format the linalg.generic op in the new test

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/test/Dialect/Linalg/canonicalize.mlir | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 103ec55dfa441..f99491c25d832 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -657,7 +657,11 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
 func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
   %0 = tensor.empty() : tensor<3x4xf32>
-  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<3x4xf32>) {
+  %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>)
+    outs(%0 : tensor<3x4xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
     %2 = arith.addf %in, %in_0 : f32
     linalg.yield %2 : f32



More information about the Mlir-commits mailing list