[Mlir-commits] [mlir] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (PR #126377)
llvmlistbot at llvm.org
Sat Feb 8 08:53:57 PST 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {darker}-->
:warning: The Python code formatter darker found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
darker --check --diff -r 564b9b7f4db05b5ce3558041b164f21dfe051a91...e3f2a7c23d2d10809e91ba419ab1092dda6264eb mlir/python/mlir/dialects/linalg/__init__.py mlir/test/python/dialects/linalg/ops.py
``````````
</details>
<details>
<summary>
View the diff from darker here.
</summary>
``````````diff
--- python/mlir/dialects/linalg/__init__.py 2025-02-08 16:44:27.000000 +0000
+++ python/mlir/dialects/linalg/__init__.py 2025-02-08 16:53:27.252780 +0000
@@ -152,11 +152,11 @@
def matmul(
inputs: Sequence[Union[Operation, OpView, Value]],
*,
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Sequence[AffineMapAttr],
- cast: Optional[Union[TypeFn, Attribute]]=None
+ cast: Optional[Union[TypeFn, Attribute]] = None,
):
inputs = [_get_op_result_or_value(input) for input in inputs]
if len(outs) > 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
@@ -165,22 +165,22 @@
op = MatmulOp(
result_tensors=result_types,
inputs=inputs,
outputs=[init],
indexing_maps=indexing_maps,
- cast=cast
+ cast=cast,
)
fill_builtin_region(op.operation)
return op
def contract(
inputs: Sequence[Union[Operation, OpView, Value]],
*,
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Sequence[AffineMapAttr],
- cast: Optional[Union[TypeFn, Attribute]]=None
+ cast: Optional[Union[TypeFn, Attribute]] = None,
):
inputs = [_get_op_result_or_value(input) for input in inputs]
if len(outs) > 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
@@ -189,9 +189,9 @@
op = ContractOp(
result_tensors=result_types,
inputs=inputs,
outputs=[init],
indexing_maps=indexing_maps,
- cast=cast
+ cast=cast,
)
fill_builtin_region(op.operation)
return op
--- test/python/dialects/linalg/ops.py 2025-02-08 16:44:27.000000 +0000
+++ test/python/dialects/linalg/ops.py 2025-02-08 16:53:27.402488 +0000
@@ -305,50 +305,62 @@
# CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
op4 = linalg.MatmulOp(
result_tensors=(C.type,),
inputs=(A, B),
outputs=(C,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+ op5 = linalg.matmul(
+ (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
op4 = linalg.MatmulOp(
result_tensors=(C.type,),
inputs=(A, Btransposed),
outputs=(C,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+ op5 = linalg.matmul(
+ (A, Btransposed),
+ outs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
# And now with memrefs...
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
op4 = linalg.MatmulOp(
result_tensors=[],
inputs=(Amem, Bmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+ linalg.matmul(
+ (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
op4 = linalg.MatmulOp(
result_tensors=[],
inputs=(Amem, Btransposedmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+ linalg.matmul(
+ (Amem, Btransposedmem),
+ outs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
print(module)
# CHECK-LABEL: TEST: testContractOp
@@ -393,52 +405,66 @@
# CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
RankedTensorType.get(c_shape, f32),
# CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
MemRefType.get(c_shape, f32),
)
- def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ def matmul_as_contract_op(
+ A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+ ):
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
op4 = linalg.ContractOp(
result_tensors=(C.type,),
inputs=(A, B),
outputs=(C,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+ op5 = linalg.contract(
+ (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
op4 = linalg.ContractOp(
result_tensors=(C.type,),
inputs=(A, Btransposed),
outputs=(C,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+ op5 = linalg.contract(
+ (A, Btransposed),
+ outs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
# And now with memrefs...
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
op4 = linalg.ContractOp(
result_tensors=[],
inputs=(Amem, Bmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+ linalg.contract(
+ (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
op4 = linalg.ContractOp(
result_tensors=[],
inputs=(Amem, Btransposedmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+ linalg.contract(
+ (Amem, Btransposedmem),
+ outs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
print(module)
``````````
</details>
https://github.com/llvm/llvm-project/pull/126377
More information about the Mlir-commits mailing list