[Mlir-commits] [mlir] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (PR #126377)
Rolf Morel
llvmlistbot at llvm.org
Sat Feb 8 08:50:25 PST 2025
https://github.com/rolfmorel created https://github.com/llvm/llvm-project/pull/126377
Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op.
This required the following miscellaneous fixes:
1) Make linalg.matmul consistent in whether indexing_maps occurs before
or after the operands, i.e. per the test cases it comes _before_.
TODO: fix linalg.batch_matmul as well.
2) The tablegen definition for linalg.contract did not state that it accepts
an optional cast attribute.
>From e3f2a7c23d2d10809e91ba419ab1092dda6264eb Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 08:41:22 -0800
Subject: [PATCH] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via
Python API
Now that linalg.matmul is in tablegen, "hand write" the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.
This required the following miscellaneous fixes:
1) Make linalg.matmul consistent in whether indexing_maps occurs before
or after the operands, i.e. per the test cases it comes _before_.
TODO: fix linalg.batch_matmul as well.
2) The tablegen definition for linalg.contract did not state that it accepts
an optional cast attribute.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 3 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 10 +-
mlir/python/mlir/dialects/linalg/__init__.py | 48 +++++
mlir/test/python/dialects/linalg/ops.py | 186 ++++++++++++++++++
4 files changed, 241 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..6146ff09482fbad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -752,7 +752,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- AffineMapArrayAttr:$indexing_maps
+ AffineMapArrayAttr:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}
void MatmulOp::print(OpAsmPrinter &p) {
- SmallVector<StringRef, 3> elidedAttrs = {
- "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
- printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
- elidedAttrs);
-
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
[&](Attribute attr) { p.printAttribute(attr); });
p << "]";
}
+
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
}
/// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..07e7c0ccd70025d 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,51 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+ inputs: Sequence[Union[Operation, OpView, Value]],
+ *,
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Sequence[AffineMapAttr],
+ cast: Optional[Union[TypeFn, Attribute]]=None
+):
+ inputs = [_get_op_result_or_value(input) for input in inputs]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = MatmulOp(
+ result_tensors=result_types,
+ inputs=inputs,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast
+ )
+ fill_builtin_region(op.operation)
+ return op
+
+
+def contract(
+ inputs: Sequence[Union[Operation, OpView, Value]],
+ *,
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Sequence[AffineMapAttr],
+ cast: Optional[Union[TypeFn, Attribute]]=None
+):
+ inputs = [_get_op_result_or_value(input) for input in inputs]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = ContractOp(
+ result_tensors=result_types,
+ inputs=inputs,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast
+ )
+ fill_builtin_region(op.operation)
+ return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..6baea4f917c128c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,189 @@ def f(a, b):
module.operation.verify()
print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ # And now with memrefs...
+
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_as_contract_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+ # And now with memrefs...
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ print(module)
More information about the Mlir-commits
mailing list