[Mlir-commits] [mlir] [mlir][python] fix linalg.pack/unpack (PR #127729)
Maksim Levental
llvmlistbot at llvm.org
Wed Feb 19 07:55:54 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/127729
>From 9623fa682d799755e4d6f2c577bda32c08a1034c Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Tue, 18 Feb 2025 19:29:36 -0500
Subject: [PATCH] [mlir][python] fix linalg.pack
---
mlir/python/mlir/dialects/LinalgOps.td | 1 +
mlir/python/mlir/dialects/linalg/__init__.py | 36 +++++++++++++++++++-
mlir/test/python/dialects/linalg/ops.py | 31 +++++++++++++++++
3 files changed, 67 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td
index b7658c85a9c44..89fb3f219e858 100644
--- a/mlir/python/mlir/dialects/LinalgOps.td
+++ b/mlir/python/mlir/dialects/LinalgOps.td
@@ -11,5 +11,6 @@
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
#endif
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index c5fbb833ee399..38540e87c7d7b 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -58,7 +58,10 @@
from .opdsl.ops.core_named_ops import *
from ...ir import *
-from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _dispatch_mixed_values,
+)
from ...extras.meta import region_op
@@ -202,3 +205,64 @@ def contract(
     return create_op(
         ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
     )
+
+
+def pack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    padding_value=None,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    """Create a ``linalg.pack`` op and return its result.
+
+    Args:
+      source: the value to pack.
+      dest: the destination value the packed data is written into.
+      inner_dims_pos: positions of the ``source`` dimensions to tile.
+      inner_tiles: tile sizes, paired with ``inner_dims_pos``. Each entry
+        is an ``int`` (static size) or a ``Value`` (dynamic size); they are
+        split into ``static_inner_tiles`` and dynamic operands.
+      padding_value: optional scalar used to fill incomplete tiles.
+      outer_dims_perm: optional permutation of the outer dimensions.
+      loc: optional location for the op.
+      ip: optional insertion point for the op.
+
+    Returns:
+      The ``result`` of the created ``PackOp``.
+
+    Raises:
+      ValueError: if ``inner_tiles`` is a single packed ``Value`` instead
+        of a sequence of sizes.
+    """
+    (
+        dynamic_inner_tiles,
+        # "Packed" means one handle carrying all sizes (e.g. %1:2, the
+        # results of a multi-result op); PackOp has no operand for that.
+        packed_inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+    if packed_inner_tiles is not None:
+        # Fail early instead of silently ignoring the handle.
+        # NOTE(review): _dispatch_mixed_values appears to leave the static
+        # sizes unset in this case, so PackOp would otherwise fail later.
+        raise ValueError(
+            "linalg.pack: inner_tiles must be a sequence of ints and/or "
+            "Values, not a single packed Value"
+        )
+
+    return PackOp(
+        source=source,
+        dest=dest,
+        inner_dims_pos=inner_dims_pos,
+        inner_tiles=dynamic_inner_tiles,
+        static_inner_tiles=static_inner_tiles,
+        padding_value=padding_value,
+        outer_dims_perm=outer_dims_perm,
+        loc=loc,
+        ip=ip,
+    ).result
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 307a88709ad52..d199558750e1e 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -566,3 +566,34 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
)
print(module)
+
+
+# CHECK-LABEL: TEST: testPackOp
+@run
+def testPackOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((129, 47, 16, 16), f32),
+                RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
+            )
+            def tensor_pack(src, dst):
+                return linalg.pack(
+                    src,
+                    dst,
+                    inner_dims_pos=[1, 0],  # dim 1 tiled by 32, dim 0 by 8
+                    inner_tiles=[32, 8],  # 47 % 32 != 0, 129 % 8 != 0
+                    padding_value=arith.constant(f32, 0.0),  # fills partial tiles
+                )
+
+        # CHECK-LABEL: func.func @tensor_pack(
+        # CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
+        # CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+        # CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+        # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+        # CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
+        # CHECK: }
+        print(module)
More information about the Mlir-commits
mailing list