[Mlir-commits] [mlir] [mlir][python] fix linalg.pack/unpack (PR #127729)

Maksim Levental llvmlistbot at llvm.org
Wed Feb 19 07:58:13 PST 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/127729

>From d18ca566050d5816950dce5fae6fdb4b97fb9208 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Tue, 18 Feb 2025 19:29:36 -0500
Subject: [PATCH] [mlir][python] fix linalg.pack

---
 mlir/python/mlir/dialects/LinalgOps.td       |  1 +
 mlir/python/mlir/dialects/linalg/__init__.py | 39 +++++++++++++++++++-
 mlir/test/python/dialects/linalg/ops.py      | 31 ++++++++++++++++
 3 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgOps.td"
 include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
 
 #endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index c5fbb833ee399..b99344d34db89 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,11 @@
 from .opdsl.ops.core_named_ops import *
 
 from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_result_or_op_results as _get_op_result_or_op_results,
+    _dispatch_mixed_values,
+)
 from ...extras.meta import region_op
 
 
@@ -202,3 +206,36 @@ def contract(
     return create_op(
         ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
     )
+
+
+def pack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    padding_value=None,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # "packed" here means results packing (a multi-result value such as %1:2), not linalg.pack; unused below
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        PackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            padding_value=padding_value,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
+    )
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 307a88709ad52..d199558750e1e 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -566,3 +566,34 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 )
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testPackOp
+@run
+def testPackOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((129, 47, 16, 16), f32),
+                RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+            )
+            def tensor_pack(src, dst):
+                return linalg.pack(
+                    src,
+                    dst,
+                    inner_dims_pos=[1, 0],
+                    inner_tiles=[32, 8],
+                    padding_value=arith.constant(f32, 0.0),
+                )
+
+        # CHECK-LABEL:   func.func @tensor_pack(
+        # CHECK-SAME:                           %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
+        # CHECK-SAME:                           %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+        # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+        # CHECK:           return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+        # CHECK:         }
+        print(module)



More information about the Mlir-commits mailing list