[Mlir-commits] [mlir] [mlir][python] fix linalg.pack/unpack (PR #127729)

Maksim Levental llvmlistbot at llvm.org
Wed Feb 19 19:08:59 PST 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/127729

>From d18ca566050d5816950dce5fae6fdb4b97fb9208 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Tue, 18 Feb 2025 19:29:36 -0500
Subject: [PATCH 1/2] [mlir][python] fix linalg.pack

---
 mlir/python/mlir/dialects/LinalgOps.td       |  1 +
 mlir/python/mlir/dialects/linalg/__init__.py | 39 +++++++++++++++++++-
 mlir/test/python/dialects/linalg/ops.py      | 31 ++++++++++++++++
 3 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgOps.td"
 include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
 
 #endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index c5fbb833ee399..b99344d34db89 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,11 @@
 from .opdsl.ops.core_named_ops import *
 
 from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_result_or_op_results as _get_op_result_or_op_results,
+    _dispatch_mixed_values,
+)
 from ...extras.meta import region_op
 
 
@@ -202,3 +206,36 @@ def contract(
     return create_op(
         ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
     )
+
+
+def pack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    padding_value=None,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        PackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            padding_value=padding_value,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
+    )
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 307a88709ad52..d199558750e1e 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -566,3 +566,34 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 )
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testPackOp
+@run
+def testPackOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((129, 47, 16, 16), f32),
+                RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+            )
+            def tensor_pack(src, dst):
+                return linalg.pack(
+                    src,
+                    dst,
+                    inner_dims_pos=[1, 0],
+                    inner_tiles=[32, 8],
+                    padding_value=arith.constant(f32, 0.0),
+                )
+
+        # CHECK-LABEL:   func.func @tensor_pack(
+        # CHECK-SAME:                           %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
+        # CHECK-SAME:                           %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+        # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+        # CHECK:           return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+        # CHECK:         }
+        print(module)

>From 6e772de864d6ad7e07cc5412c1e1d86c7a19b521 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 19 Feb 2025 21:35:38 -0500
Subject: [PATCH 2/2] fixup

---
 mlir/docs/Dialects/Linalg/_index.md          |  1 +
 mlir/python/mlir/dialects/linalg/__init__.py | 51 +++++++++++++++++---
 mlir/test/python/dialects/linalg/ops.py      | 29 +++++++----
 3 files changed, 65 insertions(+), 16 deletions(-)

diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md
index 976f0fd3c7e91..b519e4159f186 100644
--- a/mlir/docs/Dialects/Linalg/_index.md
+++ b/mlir/docs/Dialects/Linalg/_index.md
@@ -695,3 +695,4 @@ the same IR.
 ## Operations
 
 [include "Dialects/LinalgOps.md"]
+[include "Dialects/LinalgRelayoutOps.md"]
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index b99344d34db89..63586a5bb8bbb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -153,7 +153,7 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)
 
 
-def create_op(
+def _create_matmul_like_op(
     op_type,
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
@@ -183,7 +183,11 @@ def matmul(
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
+    )
 
 
 def batch_matmul(
@@ -192,8 +196,10 @@ def batch_matmul(
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(
-        BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
     )
 
 
@@ -203,8 +209,10 @@ def contract(
     indexing_maps: Sequence[AffineMapAttr],
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(
-        ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
     )
 
 
@@ -239,3 +247,34 @@ def pack(
             ip=ip,
         )
     )
+
+
+def unpack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        UnPackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
+    )
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index d199558750e1e..e32a911b24b11 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -568,32 +568,41 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
     print(module)
 
 
-# CHECK-LABEL: TEST: testPackOp
+# CHECK-LABEL: TEST: testPackUnPackOp
 @run
-def testPackOp():
+def testPackUnPackOp():
     with Context(), Location.unknown():
         module = Module.create()
         f32 = F32Type.get()
         with InsertionPoint(module.body):
 
             @func.FuncOp.from_py_func(
-                RankedTensorType.get((129, 47, 16, 16), f32),
-                RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+                RankedTensorType.get((128, 128), f32),
+                RankedTensorType.get((16, 16, 8, 8), f32),
             )
             def tensor_pack(src, dst):
-                return linalg.pack(
+                packed = linalg.pack(
                     src,
                     dst,
                     inner_dims_pos=[1, 0],
-                    inner_tiles=[32, 8],
+                    inner_tiles=[8, 8],
                     padding_value=arith.constant(f32, 0.0),
                 )
 
+                unpacked = linalg.unpack(
+                    packed,
+                    src,
+                    inner_dims_pos=[0, 1],
+                    inner_tiles=[8, 8],
+                )
+
+                return unpacked
+
         # CHECK-LABEL:   func.func @tensor_pack(
-        # CHECK-SAME:                           %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
-        # CHECK-SAME:                           %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+        # CHECK-SAME:      %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
         # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
-        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
-        # CHECK:           return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+        # CHECK:           %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+        # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
         # CHECK:         }
         print(module)



More information about the Mlir-commits mailing list