[Mlir-commits] [mlir] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (PR #126377)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 8 08:53:57 PST 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {darker}-->


:warning: The Python code formatter, darker, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
darker --check --diff -r 564b9b7f4db05b5ce3558041b164f21dfe051a91...e3f2a7c23d2d10809e91ba419ab1092dda6264eb mlir/python/mlir/dialects/linalg/__init__.py mlir/test/python/dialects/linalg/ops.py
``````````

</details>

<details>
<summary>
View the diff from darker here.
</summary>

``````````diff
--- python/mlir/dialects/linalg/__init__.py	2025-02-08 16:44:27.000000 +0000
+++ python/mlir/dialects/linalg/__init__.py	2025-02-08 16:53:27.252780 +0000
@@ -152,11 +152,11 @@
 def matmul(
     inputs: Sequence[Union[Operation, OpView, Value]],
     *,
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Sequence[AffineMapAttr],
-    cast: Optional[Union[TypeFn, Attribute]]=None
+    cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
     inputs = [_get_op_result_or_value(input) for input in inputs]
     if len(outs) > 1:
         raise ValueError(f"{outs=} must have length 1.")
     init = _get_op_result_or_value(outs[0])
@@ -165,22 +165,22 @@
     op = MatmulOp(
         result_tensors=result_types,
         inputs=inputs,
         outputs=[init],
         indexing_maps=indexing_maps,
-        cast=cast
+        cast=cast,
     )
     fill_builtin_region(op.operation)
     return op
 
 
 def contract(
     inputs: Sequence[Union[Operation, OpView, Value]],
     *,
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Sequence[AffineMapAttr],
-    cast: Optional[Union[TypeFn, Attribute]]=None
+    cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
     inputs = [_get_op_result_or_value(input) for input in inputs]
     if len(outs) > 1:
         raise ValueError(f"{outs=} must have length 1.")
     init = _get_op_result_or_value(outs[0])
@@ -189,9 +189,9 @@
     op = ContractOp(
         result_tensors=result_types,
         inputs=inputs,
         outputs=[init],
         indexing_maps=indexing_maps,
-        cast=cast
+        cast=cast,
     )
     fill_builtin_region(op.operation)
     return op
--- test/python/dialects/linalg/ops.py	2025-02-08 16:44:27.000000 +0000
+++ test/python/dialects/linalg/ops.py	2025-02-08 16:53:27.402488 +0000
@@ -305,50 +305,62 @@
                 # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op4 = linalg.MatmulOp(
                     result_tensors=(C.type,),
                     inputs=(A, B),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+                op5 = linalg.matmul(
+                    (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op4 = linalg.MatmulOp(
                     result_tensors=(C.type,),
                     inputs=(A, Btransposed),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                op5 = linalg.matmul(
+                    (A, Btransposed),
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
 
                 # And now with memrefs...
 
                 # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 op4 = linalg.MatmulOp(
                     result_tensors=[],
                     inputs=(Amem, Bmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+                linalg.matmul(
+                    (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 op4 = linalg.MatmulOp(
                     result_tensors=[],
                     inputs=(Amem, Btransposedmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+                linalg.matmul(
+                    (Amem, Btransposedmem),
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
 
         print(module)
 
 
 # CHECK-LABEL: TEST: testContractOp
@@ -393,52 +405,66 @@
                 # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
                 RankedTensorType.get(c_shape, f32),
                 # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
                 MemRefType.get(c_shape, f32),
             )
-            def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+            def matmul_as_contract_op(
+                A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+            ):
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op4 = linalg.ContractOp(
                     result_tensors=(C.type,),
                     inputs=(A, B),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+                op5 = linalg.contract(
+                    (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op4 = linalg.ContractOp(
                     result_tensors=(C.type,),
                     inputs=(A, Btransposed),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                op5 = linalg.contract(
+                    (A, Btransposed),
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
                 # And now with memrefs...
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 op4 = linalg.ContractOp(
                     result_tensors=[],
                     inputs=(Amem, Bmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+                linalg.contract(
+                    (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 op4 = linalg.ContractOp(
                     result_tensors=[],
                     inputs=(Amem, Btransposedmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+                linalg.contract(
+                    (Amem, Btransposedmem),
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
 
         print(module)

``````````

</details>


https://github.com/llvm/llvm-project/pull/126377


More information about the Mlir-commits mailing list