[Mlir-commits] [mlir] a72616d - [mlir][python] fix linalg.pack/unpack (#127729)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 20 08:02:40 PST 2025


Author: Maksim Levental
Date: 2025-02-20T11:02:36-05:00
New Revision: a72616de18c0814ad37b5748d6bdc60b825dd889

URL: https://github.com/llvm/llvm-project/commit/a72616de18c0814ad37b5748d6bdc60b825dd889
DIFF: https://github.com/llvm/llvm-project/commit/a72616de18c0814ad37b5748d6bdc60b825dd889.diff

LOG: [mlir][python] fix linalg.pack/unpack (#127729)

PR #123902 (https://github.com/llvm/llvm-project/pull/123902) broke the Python
bindings for `tensor.pack`/`tensor.unpack`; this PR fixes that. It also

1. adds convenience wrappers for pack/unpack
2. cleans up matmul-like ops in the linalg bindings
3. adds the missing pack/unpack ops to the Linalg dialect docs

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg/_index.md
    mlir/python/mlir/dialects/LinalgOps.td
    mlir/python/mlir/dialects/linalg/__init__.py
    mlir/test/python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md
index 976f0fd3c7e91..b519e4159f186 100644
--- a/mlir/docs/Dialects/Linalg/_index.md
+++ b/mlir/docs/Dialects/Linalg/_index.md
@@ -695,3 +695,4 @@ the same IR.
 ## Operations
 
 [include "Dialects/LinalgOps.md"]
+[include "Dialects/LinalgRelayoutOps.md"]

diff  --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgOps.td"
 include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
 
 #endif

diff  --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index c5fbb833ee399..63586a5bb8bbb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,11 @@
 from .opdsl.ops.core_named_ops import *
 
 from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_result_or_op_results as _get_op_result_or_op_results,
+    _dispatch_mixed_values,
+)
 from ...extras.meta import region_op
 
 
@@ -149,7 +153,7 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)
 
 
-def create_op(
+def _create_matmul_like_op(
     op_type,
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
@@ -179,7 +183,11 @@ def matmul(
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
+    )
 
 
 def batch_matmul(
@@ -188,8 +196,10 @@ def batch_matmul(
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(
-        BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
     )
 
 
@@ -199,6 +209,72 @@ def contract(
     indexing_maps: Sequence[AffineMapAttr],
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    return create_op(
-        ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
+    )
+
+
+def pack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    padding_value=None,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        PackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            padding_value=padding_value,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+def unpack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        UnPackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
     )

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 307a88709ad52..e32a911b24b11 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -566,3 +566,43 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 )
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testPackUnPackOp
+@run
+def testPackUnPackOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((128, 128), f32),
+                RankedTensorType.get((16, 16, 8, 8), f32),
+            )
+            def tensor_pack(src, dst):
+                packed = linalg.pack(
+                    src,
+                    dst,
+                    inner_dims_pos=[1, 0],
+                    inner_tiles=[8, 8],
+                    padding_value=arith.constant(f32, 0.0),
+                )
+
+                unpacked = linalg.unpack(
+                    packed,
+                    src,
+                    inner_dims_pos=[0, 1],
+                    inner_tiles=[8, 8],
+                )
+
+                return unpacked
+
+        # CHECK-LABEL:   func.func @tensor_pack(
+        # CHECK-SAME:      %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
+        # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+        # CHECK:           %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+        # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
+        # CHECK:         }
+        print(module)


        


More information about the Mlir-commits mailing list