[Mlir-commits] [mlir] a72616d - [mlir][python] fix linalg.pack/unpack (#127729)
llvmlistbot at llvm.org
Thu Feb 20 08:02:40 PST 2025
Author: Maksim Levental
Date: 2025-02-20T11:02:36-05:00
New Revision: a72616de18c0814ad37b5748d6bdc60b825dd889
URL: https://github.com/llvm/llvm-project/commit/a72616de18c0814ad37b5748d6bdc60b825dd889
DIFF: https://github.com/llvm/llvm-project/commit/a72616de18c0814ad37b5748d6bdc60b825dd889.diff
LOG: [mlir][python] fix linalg.pack/unpack (#127729)
PR https://github.com/llvm/llvm-project/pull/123902 moved `tensor.pack`/`tensor.unpack` into linalg and broke their Python bindings. This PR fixes that. It also
1. adds convenience wrappers for pack/unpack (see the sketch below)
2. cleans up the matmul-like ops in the linalg bindings
3. fixes the linalg docs, which were missing pack/unpack
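For context, the new wrappers accept `inner_tiles` as a mix of Python ints (static tile sizes) and SSA values (dynamic tile sizes). A minimal sketch, not part of this commit (the function name and shapes are illustrative; assumes an MLIR build with the Python bindings enabled):

    from mlir.ir import (
        Context, Location, Module, InsertionPoint,
        F32Type, IndexType, RankedTensorType, ShapedType,
    )
    from mlir.dialects import arith, func, linalg

    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        dyn = ShapedType.get_dynamic_size()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((128, 128), f32),
                # Outer dims tiled by a dynamic size are dynamic too.
                RankedTensorType.get((16, dyn, 8, dyn), f32),
            )
            def mixed_pack(src, dst):
                tile = arith.constant(IndexType.get(), 8)
                # One static tile size (8) and one dynamic one (%tile).
                return linalg.pack(
                    src, dst, inner_dims_pos=[0, 1], inner_tiles=[8, tile]
                )

        print(module)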
Added:
Modified:
mlir/docs/Dialects/Linalg/_index.md
mlir/python/mlir/dialects/LinalgOps.td
mlir/python/mlir/dialects/linalg/__init__.py
mlir/test/python/dialects/linalg/ops.py
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md
index 976f0fd3c7e91..b519e4159f186 100644
--- a/mlir/docs/Dialects/Linalg/_index.md
+++ b/mlir/docs/Dialects/Linalg/_index.md
@@ -695,3 +695,4 @@ the same IR.
## Operations
[include "Dialects/LinalgOps.md"]
+[include "Dialects/LinalgRelayoutOps.md"]
diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
#endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index c5fbb833ee399..63586a5bb8bbb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,11 @@
from .opdsl.ops.core_named_ops import *
from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_result_or_op_results as _get_op_result_or_op_results,
+ _dispatch_mixed_values,
+)
from ...extras.meta import region_op
@@ -149,7 +153,7 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
-def create_op(
+def _create_matmul_like_op(
op_type,
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
@@ -179,7 +183,11 @@ def matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
- return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+ return _get_op_result_or_op_results(
+ _create_matmul_like_op(
+ MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ )
+ )
def batch_matmul(
@@ -188,8 +196,10 @@ def batch_matmul(
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
- return create_op(
- BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ return _get_op_result_or_op_results(
+ _create_matmul_like_op(
+ BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ )
)
@@ -199,6 +209,72 @@ def contract(
indexing_maps: Sequence[AffineMapAttr],
cast: Optional[Union[TypeFn, Attribute]] = None,
):
- return create_op(
- ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ return _get_op_result_or_op_results(
+ _create_matmul_like_op(
+ ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ )
+ )
+
+
+def pack(
+ source,
+ dest,
+ inner_dims_pos,
+ inner_tiles,
+ *,
+ padding_value=None,
+ outer_dims_perm=None,
+ loc=None,
+ ip=None,
+) -> ir.Value:
+ (
+ dynamic_inner_tiles,
+ # packed here means %1:2 packing (results packing)
+ _inner_tiles,
+ static_inner_tiles,
+ ) = _dispatch_mixed_values(inner_tiles)
+
+ return _get_op_result_or_op_results(
+ PackOp(
+ source=source,
+ dest=dest,
+ inner_dims_pos=inner_dims_pos,
+ inner_tiles=dynamic_inner_tiles,
+ static_inner_tiles=static_inner_tiles,
+ padding_value=padding_value,
+ outer_dims_perm=outer_dims_perm,
+ loc=loc,
+ ip=ip,
+ )
+ )
+
+
+def unpack(
+ source,
+ dest,
+ inner_dims_pos,
+ inner_tiles,
+ *,
+ outer_dims_perm=None,
+ loc=None,
+ ip=None,
+) -> ir.Value:
+ (
+ dynamic_inner_tiles,
+ # packed here means %1:2 packing (results packing)
+ _inner_tiles,
+ static_inner_tiles,
+ ) = _dispatch_mixed_values(inner_tiles)
+
+ return _get_op_result_or_op_results(
+ UnPackOp(
+ source=source,
+ dest=dest,
+ inner_dims_pos=inner_dims_pos,
+ inner_tiles=dynamic_inner_tiles,
+ static_inner_tiles=static_inner_tiles,
+ outer_dims_perm=outer_dims_perm,
+ loc=loc,
+ ip=ip,
+ )
)
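Both wrappers funnel `inner_tiles` through `_dispatch_mixed_values` from `_ods_common`, which splits a mixed static/dynamic list into the operand/attribute pair that `PackOp`/`UnPackOp` expect. A rough sketch of that contract, inferred from the destructuring above (the exact return types are an assumption, not part of this diff):

    # Illustrative only; `tile` stands for some SSA index value built earlier.
    dynamic, packed, static = _dispatch_mixed_values([8, tile])
    # dynamic: the SSA values, in order                  -> [tile]
    # packed:  only set when a single packed value is
    #          passed instead of a list                  -> None here
    # static:  i64 array with a "dynamic" sentinel
    #          (ShapedType.get_dynamic_size()) standing
    #          in for each SSA value                     -> [8, kDynamic]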
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 307a88709ad52..e32a911b24b11 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -566,3 +566,43 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
)
print(module)
+
+
+# CHECK-LABEL: TEST: testPackUnPackOp
+@run
+def testPackUnPackOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((128, 128), f32),
+ RankedTensorType.get((16, 16, 8, 8), f32),
+ )
+ def tensor_pack(src, dst):
+ packed = linalg.pack(
+ src,
+ dst,
+ inner_dims_pos=[1, 0],
+ inner_tiles=[8, 8],
+ padding_value=arith.constant(f32, 0.0),
+ )
+
+ unpacked = linalg.unpack(
+ packed,
+ src,
+ inner_dims_pos=[0, 1],
+ inner_tiles=[8, 8],
+ )
+
+ return unpacked
+
+ # CHECK-LABEL: func.func @tensor_pack(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
+ # CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+ # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+ # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+ # CHECK: return %[[VAL_4]] : tensor<128x128xf32>
+ # CHECK: }
+ print(module)
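For reference, the `run` decorator used by this test (`@run`) follows the usual pattern in the MLIR Python binding tests; roughly (a sketch, the actual definition lives earlier in ops.py):

    def run(f):
        # Emits the "TEST: ..." line that the CHECK-LABEL above matches.
        print("\nTEST:", f.__name__)
        f()
        return f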