[Mlir-commits] [mlir] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (PR #126377)

Rolf Morel llvmlistbot at llvm.org
Sat Feb 8 08:50:25 PST 2025


https://github.com/rolfmorel created https://github.com/llvm/llvm-project/pull/126377

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op.

This required the following miscellaneous fixes:
1) Make linalg.matmul consistent in whether indexing_maps appears before
   or after the operands; per the test cases, it comes _before_.
   TODO: fix linalg.batch_matmul as well.
2) The tablegen definition of linalg.contract did not state that it
   accepts an optional cast attr.

>From e3f2a7c23d2d10809e91ba419ab1092dda6264eb Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 08:41:22 -0800
Subject: [PATCH] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via
 Python API

Now that linalg.matmul is in tablegen, "hand write" the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.

This required the following miscellaneous fixes:
1) Make linalg.matmul consistent in whether indexing_maps appears before
   or after the operands; per the test cases, it comes _before_.
   TODO: fix linalg.batch_matmul as well.
2) The tablegen definition of linalg.contract did not state that it
   accepts an optional cast attr.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |   3 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  10 +-
 mlir/python/mlir/dialects/linalg/__init__.py  |  48 +++++
 mlir/test/python/dialects/linalg/ops.py       | 186 ++++++++++++++++++
 4 files changed, 241 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..6146ff09482fbad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -752,7 +752,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
   let arguments = (ins
     Variadic<AnyType>:$inputs,
     Variadic<AnyShaped>:$outputs,
-    AffineMapArrayAttr:$indexing_maps
+    AffineMapArrayAttr:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
   // NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                         elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       MatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
                           [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
 }
 
 /// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..07e7c0ccd70025d 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,51 @@ def __init__(
 
 
 generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = MatmulOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
+
+
+def contract(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = ContractOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..6baea4f917c128c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,189 @@ def f(a, b):
 
     module.operation.verify()
     print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+                # And now with memrefs...
+
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_as_contract_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                # And now with memrefs...
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)



More information about the Mlir-commits mailing list