[Mlir-commits] [mlir] [mlir][python] fix linalg.pack/unpack (PR #127729)
Maksim Levental
llvmlistbot at llvm.org
Wed Feb 19 19:08:59 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/127729
>From d18ca566050d5816950dce5fae6fdb4b97fb9208 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Tue, 18 Feb 2025 19:29:36 -0500
Subject: [PATCH 1/2] [mlir][python] fix linalg.pack
---
mlir/python/mlir/dialects/LinalgOps.td | 1 +
mlir/python/mlir/dialects/linalg/__init__.py | 39 +++++++++++++++++++-
mlir/test/python/dialects/linalg/ops.py | 31 ++++++++++++++++
3 files changed, 70 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
#endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index c5fbb833ee399..b99344d34db89 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,11 @@
from .opdsl.ops.core_named_ops import *
from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_result_or_op_results as _get_op_result_or_op_results,
+ _dispatch_mixed_values,
+)
from ...extras.meta import region_op
@@ -202,3 +206,36 @@ def contract(
return create_op(
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
)
+
+
+def pack(
+ source,
+ dest,
+ inner_dims_pos,
+ inner_tiles,
+ *,
+ padding_value=None,
+ outer_dims_perm=None,
+ loc=None,
+ ip=None,
+) -> ir.Value:
+ (
+ dynamic_inner_tiles,
+ # the middle result is the "packed" form (inner_tiles supplied as a single Value, e.g. %1:2); unused here
+ _inner_tiles,
+ static_inner_tiles,
+ ) = _dispatch_mixed_values(inner_tiles)
+
+ return _get_op_result_or_op_results(
+ PackOp(
+ source=source,
+ dest=dest,
+ inner_dims_pos=inner_dims_pos,
+ inner_tiles=dynamic_inner_tiles,
+ static_inner_tiles=static_inner_tiles,
+ padding_value=padding_value,
+ outer_dims_perm=outer_dims_perm,
+ loc=loc,
+ ip=ip,
+ )
+ )
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 307a88709ad52..d199558750e1e 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -566,3 +566,34 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
)
print(module)
+
+
+# CHECK-LABEL: TEST: testPackOp
+ at run
+def testPackOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((129, 47, 16, 16), f32),
+ RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+ )
+ def tensor_pack(src, dst):
+ return linalg.pack(
+ src,
+ dst,
+ inner_dims_pos=[1, 0],
+ inner_tiles=[32, 8],
+ padding_value=arith.constant(f32, 0.0),
+ )
+
+ # CHECK-LABEL: func.func @tensor_pack(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
+ # CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+ # CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+ # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+ # CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+ # CHECK: }
+ print(module)
>From 6e772de864d6ad7e07cc5412c1e1d86c7a19b521 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 19 Feb 2025 21:35:38 -0500
Subject: [PATCH 2/2] fixup
---
mlir/docs/Dialects/Linalg/_index.md | 1 +
mlir/python/mlir/dialects/linalg/__init__.py | 51 +++++++++++++++++---
mlir/test/python/dialects/linalg/ops.py | 29 +++++++----
3 files changed, 65 insertions(+), 16 deletions(-)
diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md
index 976f0fd3c7e91..b519e4159f186 100644
--- a/mlir/docs/Dialects/Linalg/_index.md
+++ b/mlir/docs/Dialects/Linalg/_index.md
@@ -695,3 +695,4 @@ the same IR.
## Operations
[include "Dialects/LinalgOps.md"]
+[include "Dialects/LinalgRelayoutOps.md"]
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index b99344d34db89..63586a5bb8bbb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -153,7 +153,7 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
-def create_op(
+def _create_matmul_like_op(
op_type,
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
@@ -183,7 +183,11 @@ def matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
- return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+ return _get_op_result_or_op_results(
+ _create_matmul_like_op(
+ MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ )
+ )
def batch_matmul(
@@ -192,8 +196,10 @@ def batch_matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
- return create_op(
- BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ return _get_op_result_or_op_results(
+ _create_matmul_like_op(
+ BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ )
)
@@ -203,8 +209,10 @@ def contract(
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
- return create_op(
- ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ return _get_op_result_or_op_results(
+ _create_matmul_like_op(
+ ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ )
)
@@ -239,3 +247,34 @@ def pack(
ip=ip,
)
)
+
+
+def unpack(
+ source,
+ dest,
+ inner_dims_pos,
+ inner_tiles,
+ *,
+ outer_dims_perm=None,
+ loc=None,
+ ip=None,
+) -> ir.Value:
+ (
+ dynamic_inner_tiles,
+ # the middle result is the "packed" form (inner_tiles supplied as a single Value, e.g. %1:2); unused here
+ _inner_tiles,
+ static_inner_tiles,
+ ) = _dispatch_mixed_values(inner_tiles)
+
+ return _get_op_result_or_op_results(
+ UnPackOp(
+ source=source,
+ dest=dest,
+ inner_dims_pos=inner_dims_pos,
+ inner_tiles=dynamic_inner_tiles,
+ static_inner_tiles=static_inner_tiles,
+ outer_dims_perm=outer_dims_perm,
+ loc=loc,
+ ip=ip,
+ )
+ )
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index d199558750e1e..e32a911b24b11 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -568,32 +568,41 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
print(module)
-# CHECK-LABEL: TEST: testPackOp
+# CHECK-LABEL: TEST: testPackUnPackOp
@run
-def testPackOp():
+def testPackUnPackOp():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(
- RankedTensorType.get((129, 47, 16, 16), f32),
- RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+ RankedTensorType.get((128, 128), f32),
+ RankedTensorType.get((16, 16, 8, 8), f32),
)
def tensor_pack(src, dst):
- return linalg.pack(
+ packed = linalg.pack(
src,
dst,
inner_dims_pos=[1, 0],
- inner_tiles=[32, 8],
+ inner_tiles=[8, 8],
padding_value=arith.constant(f32, 0.0),
)
+ unpacked = linalg.unpack(
+ packed,
+ src,
+ inner_dims_pos=[0, 1],
+ inner_tiles=[8, 8],
+ )
+
+ return unpacked
+
# CHECK-LABEL: func.func @tensor_pack(
- # CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
- # CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
- # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
- # CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+ # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+ # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+ # CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)
More information about the Mlir-commits
mailing list