[Mlir-commits] [mlir] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (PR #126377)
Rolf Morel
llvmlistbot at llvm.org
Sun Feb 9 10:58:17 PST 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/126377
>From e3f2a7c23d2d10809e91ba419ab1092dda6264eb Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 08:41:22 -0800
Subject: [PATCH 1/8] [MLIR][Linalg] Expose linalg.matmul and linalg.contract
via Python API
Now that linalg.matmul is defined in TableGen, hand-write the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.
This required the following miscellaneous fixes:
1) Make linalg.matmul consistent about whether indexing_maps is printed before
or after the operands; per the test cases, it now comes _before_.
TODO: fix linalg.batch_matmul as well
2) The TableGen definition of linalg.contract did not declare that it accepts
an optional cast attr.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 3 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 10 +-
mlir/python/mlir/dialects/linalg/__init__.py | 48 +++++
mlir/test/python/dialects/linalg/ops.py | 186 ++++++++++++++++++
4 files changed, 241 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..6146ff09482fbad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -752,7 +752,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- AffineMapArrayAttr:$indexing_maps
+ AffineMapArrayAttr:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}
void MatmulOp::print(OpAsmPrinter &p) {
- SmallVector<StringRef, 3> elidedAttrs = {
- "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
- printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
- elidedAttrs);
-
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
[&](Attribute attr) { p.printAttribute(attr); });
p << "]";
}
+
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
}
/// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..07e7c0ccd70025d 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,51 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+ inputs: Sequence[Union[Operation, OpView, Value]],
+ *,
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Sequence[AffineMapAttr],
+ cast: Optional[Union[TypeFn, Attribute]]=None
+):
+ inputs = [_get_op_result_or_value(input) for input in inputs]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = MatmulOp(
+ result_tensors=result_types,
+ inputs=inputs,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast
+ )
+ fill_builtin_region(op.operation)
+ return op
+
+
+def contract(
+ inputs: Sequence[Union[Operation, OpView, Value]],
+ *,
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Sequence[AffineMapAttr],
+ cast: Optional[Union[TypeFn, Attribute]]=None
+):
+ inputs = [_get_op_result_or_value(input) for input in inputs]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = ContractOp(
+ result_tensors=result_types,
+ inputs=inputs,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast
+ )
+ fill_builtin_region(op.operation)
+ return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..6baea4f917c128c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,189 @@ def f(a, b):
module.operation.verify()
print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ # And now with memrefs...
+
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_as_contract_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+ # And now with memrefs...
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map]
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+ print(module)
>From 57b2509c212f99e9838e0afb4790fc551ff337dc Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 09:00:51 -0800
Subject: [PATCH 2/8] Fix linalg.matmul tests which encoded different
placements of indexing_maps
---
mlir/test/Dialect/Linalg/named-ops.mlir | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ed8683522c74a1b..68ea97be911a66a 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1269,7 +1269,7 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1294,7 +1294,7 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1315,6 +1315,7 @@ func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1335,6 +1336,7 @@ func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a_dim1
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1355,6 +1357,7 @@ func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1376,7 +1379,7 @@ func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK-LABEL: func.func @matmul_bcast_a_b(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1397,6 +1400,7 @@ func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b_dim1
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1420,7 +1424,7 @@ func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>)
// CHECK: return
// CHECK: }
@@ -1444,7 +1448,7 @@ func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1468,7 +1472,7 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
>From 18ff3a681e9b13d845adf3eb9bcc16177e6fb2e9 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 09:08:39 -0800
Subject: [PATCH 3/8] Python formatting
---
mlir/python/mlir/dialects/linalg/__init__.py | 8 +--
mlir/test/python/dialects/linalg/ops.py | 76 +++++++++++++-------
2 files changed, 55 insertions(+), 29 deletions(-)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 07e7c0ccd70025d..b56252f9a617ceb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -154,7 +154,7 @@ def matmul(
*,
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Sequence[AffineMapAttr],
- cast: Optional[Union[TypeFn, Attribute]]=None
+ cast: Optional[Union[TypeFn, Attribute]] = None,
):
inputs = [_get_op_result_or_value(input) for input in inputs]
if len(outs) > 1:
@@ -167,7 +167,7 @@ def matmul(
inputs=inputs,
outputs=[init],
indexing_maps=indexing_maps,
- cast=cast
+ cast=cast,
)
fill_builtin_region(op.operation)
return op
@@ -178,7 +178,7 @@ def contract(
*,
outs: Sequence[Union[Operation, OpView, Value]],
indexing_maps: Sequence[AffineMapAttr],
- cast: Optional[Union[TypeFn, Attribute]]=None
+ cast: Optional[Union[TypeFn, Attribute]] = None,
):
inputs = [_get_op_result_or_value(input) for input in inputs]
if len(outs) > 1:
@@ -191,7 +191,7 @@ def contract(
inputs=inputs,
outputs=[init],
indexing_maps=indexing_maps,
- cast=cast
+ cast=cast,
)
fill_builtin_region(op.operation)
return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 6baea4f917c128c..b487d72cbf9f358 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -303,50 +303,62 @@ def testMatmulOp():
)
def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
# CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op4 = linalg.MatmulOp(
+ res = linalg.MatmulOp(
result_tensors=(C.type,),
inputs=(A, B),
outputs=(C,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
- linalg.fill_builtin_region(op4.operation)
+ linalg.fill_builtin_region(res.operation)
# CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+ res = linalg.matmul(
+ (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op4 = linalg.MatmulOp(
+ res = linalg.MatmulOp(
result_tensors=(C.type,),
inputs=(A, Btransposed),
outputs=(C,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
- linalg.fill_builtin_region(op4.operation)
+ linalg.fill_builtin_region(res.operation)
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+ res = linalg.matmul(
+ (A, Btransposed),
+ outs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
# And now with memrefs...
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- op4 = linalg.MatmulOp(
+ res = linalg.MatmulOp(
result_tensors=[],
inputs=(Amem, Bmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
- linalg.fill_builtin_region(op4.operation)
+ linalg.fill_builtin_region(res.operation)
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+ linalg.matmul(
+ (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- op4 = linalg.MatmulOp(
+ res = linalg.MatmulOp(
result_tensors=[],
inputs=(Amem, Btransposedmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
- linalg.fill_builtin_region(op4.operation)
+ linalg.fill_builtin_region(res.operation)
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+ linalg.matmul(
+ (Amem, Btransposedmem),
+ outs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
print(module)
@@ -395,28 +407,36 @@ def testContractOp():
# CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
MemRefType.get(c_shape, f32),
)
- def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ def matmul_as_contract_op(
+ A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+ ):
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
op4 = linalg.ContractOp(
result_tensors=(C.type,),
inputs=(A, B),
outputs=(C,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+ op5 = linalg.contract(
+ (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
op4 = linalg.ContractOp(
result_tensors=(C.type,),
inputs=(A, Btransposed),
outputs=(C,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
- op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+ op5 = linalg.contract(
+ (A, Btransposed),
+ outs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
# And now with memrefs...
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
@@ -424,21 +444,27 @@ def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
result_tensors=[],
inputs=(Amem, Bmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_map, c_map]
+ indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+ linalg.contract(
+ (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+ )
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
op4 = linalg.ContractOp(
result_tensors=[],
inputs=(Amem, Btransposedmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_transposed_map, c_map]
+ indexing_maps=[a_map, b_transposed_map, c_map],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+ linalg.contract(
+ (Amem, Btransposedmem),
+ outs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
print(module)
>From 2e6bf665bbeca637a5e8d21689c6b128fc23275d Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 13:46:41 -0800
Subject: [PATCH 4/8] Make indexing_maps optional on matmul, as it should be
---
mlir/python/mlir/dialects/linalg/__init__.py | 2 +-
mlir/test/python/dialects/linalg/ops.py | 8 ++------
2 files changed, 3 insertions(+), 7 deletions(-)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index b56252f9a617ceb..590229818c7440a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -153,7 +153,7 @@ def matmul(
inputs: Sequence[Union[Operation, OpView, Value]],
*,
outs: Sequence[Union[Operation, OpView, Value]],
- indexing_maps: Sequence[AffineMapAttr],
+ indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
cast: Optional[Union[TypeFn, Attribute]] = None,
):
inputs = [_get_op_result_or_value(input) for input in inputs]
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b487d72cbf9f358..36012303a52a231 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -307,13 +307,10 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
result_tensors=(C.type,),
inputs=(A, B),
outputs=(C,),
- indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(res.operation)
# CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
- res = linalg.matmul(
- (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
- )
+ res = linalg.matmul((A, B), outs=(C,))
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
res = linalg.MatmulOp(
@@ -337,12 +334,11 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
result_tensors=[],
inputs=(Amem, Bmem),
outputs=(Cmem,),
- indexing_maps=[a_map, b_map, c_map],
)
linalg.fill_builtin_region(res.operation)
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
linalg.matmul(
- (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+ (Amem, Bmem), outs=(Cmem,)
)
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
>From 85426f6019a0faf7b592b91b23133b4298493c24 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 15:24:21 -0800
Subject: [PATCH 5/8] Fix matmul's indexing_maps issue
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 18 +++++++++++++++++-
mlir/python/mlir/dialects/linalg/__init__.py | 5 +++++
mlir/test/python/dialects/linalg/ops.py | 4 +---
3 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6146ff09482fbad..dc1c93355e2582b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -555,6 +555,22 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
+
+// DONOTMERGE(rolfmorel): explain why the below is necessary
+def DefaultValuedMatmulIndexingMapsAttr :
+ Attr<AffineMapArrayAttr.predicate, AffineMapArrayAttr.summary> {
+ let storageType = AffineMapArrayAttr.storageType;
+ let returnType = AffineMapArrayAttr.returnType;
+ let convertFromStorage = AffineMapArrayAttr.convertFromStorage;
+ let constBuilderCall = "$_builder.getAffineMapArrayAttr($0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0)";
+ let defaultValue = "SmallVector<AffineMap>()";
+ let valueType = AffineMapArrayAttr.valueType;
+ let isOptional = 1;
+
+ let baseAttr = AffineMapArrayAttr;
+}
+
+
def MatmulOp : LinalgStructuredBase_Op<"matmul", [
AttrSizedOperandSegments,
LinalgContractionOpInterface]> {
@@ -606,7 +622,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+ DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 590229818c7440a..ab47946e7eab772 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,6 +149,11 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
+ at register_attribute_builder("DefaultValuedMatmulIndexingMapsAttr")
+def _DefaultValuedMatmulIndexingMapsAttr(x, context):
+ return ArrayAttr.get([AffineMapAttr.get(v) for v in x])
+
+
def matmul(
inputs: Sequence[Union[Operation, OpView, Value]],
*,
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 36012303a52a231..c332f7001fcb237 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -337,9 +337,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
)
linalg.fill_builtin_region(res.operation)
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
- linalg.matmul(
- (Amem, Bmem), outs=(Cmem,)
- )
+ linalg.matmul((Amem, Bmem), outs=(Cmem,))
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
res = linalg.MatmulOp(
>From e98838da2635fea28043b3a274e5a69dc4c570c4 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 9 Feb 2025 10:37:03 -0800
Subject: [PATCH 6/8] Replace the dedicated matmul indexing_maps attr type with a generic "context-dependent" default-valued attr class
This new attr class should probably be moved to CommonAttrConstraints.td.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 25 +++++++++++--------
mlir/python/mlir/dialects/linalg/__init__.py | 5 ----
2 files changed, 15 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index dc1c93355e2582b..f8cf63e6e8a3f0a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -557,17 +557,19 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// DONOTMERGE(rolfmorel): explain why the below is necessary
-def DefaultValuedMatmulIndexingMapsAttr :
- Attr<AffineMapArrayAttr.predicate, AffineMapArrayAttr.summary> {
- let storageType = AffineMapArrayAttr.storageType;
- let returnType = AffineMapArrayAttr.returnType;
- let convertFromStorage = AffineMapArrayAttr.convertFromStorage;
- let constBuilderCall = "$_builder.getAffineMapArrayAttr($0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0)";
- let defaultValue = "SmallVector<AffineMap>()";
- let valueType = AffineMapArrayAttr.valueType;
+class DefaultValuedContextDependentAttr<Attr attr,
+ string builderCall,
+ string default> :
+ Attr<attr.predicate, attr.summary> {
+ let storageType = attr.storageType;
+ let returnType = attr.returnType;
+ let convertFromStorage = attr.convertFromStorage;
+ let constBuilderCall = builderCall; // DONOTMERGE(rolfmorel): explain why this needs to be a parameter
+ let defaultValue = default;
+ let valueType = attr.valueType;
let isOptional = 1;
- let baseAttr = AffineMapArrayAttr;
+ let baseAttr = attr;
}
@@ -622,7 +624,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
+ DefaultValuedContextDependentAttr<AffineMapArrayAttr,
+ builderCall = [{ $_builder.getAffineMapArrayAttr(
+ $0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0
+ )}], default = "SmallVector<AffineMap>()">:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index ab47946e7eab772..590229818c7440a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,11 +149,6 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
- at register_attribute_builder("DefaultValuedMatmulIndexingMapsAttr")
-def _DefaultValuedMatmulIndexingMapsAttr(x, context):
- return ArrayAttr.get([AffineMapAttr.get(v) for v in x])
-
-
def matmul(
inputs: Sequence[Union[Operation, OpView, Value]],
*,
>From 931b47369710207c6e29c66124bcbff0318feb81 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 9 Feb 2025 10:52:18 -0800
Subject: [PATCH 7/8] Move builderCall parameter from custom attr class to
DefaultValued(Optional)Attr
---
.../include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 4 ++--
mlir/include/mlir/IR/CommonAttrConstraints.td | 8 ++++----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f8cf63e6e8a3f0a..6a7e7876d938107 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -624,10 +624,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- DefaultValuedContextDependentAttr<AffineMapArrayAttr,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "SmallVector<AffineMap>()",
builderCall = [{ $_builder.getAffineMapArrayAttr(
$0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0
- )}], default = "SmallVector<AffineMap>()">:$indexing_maps,
+ )}]>:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 599f5ecba5803b0..f677009334fa21d 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -90,7 +90,7 @@ class DialectAttr<Dialect d, Pred condition, string summary = ""> :
// Attribute modifier definition
// Decorates an attribute to have an (unvalidated) default value if not present.
-class DefaultValuedAttr<Attr attr, string val> :
+class DefaultValuedAttr<Attr attr, string val, string builderCall = ""> :
Attr<attr.predicate, attr.summary> {
// Construct this attribute with the input attribute and change only
// the default value.
@@ -98,7 +98,7 @@ class DefaultValuedAttr<Attr attr, string val> :
let storageType = attr.storageType;
let returnType = attr.returnType;
let convertFromStorage = attr.convertFromStorage;
- let constBuilderCall = attr.constBuilderCall;
+ let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
let defaultValue = val;
let valueType = attr.valueType;
@@ -107,7 +107,7 @@ class DefaultValuedAttr<Attr attr, string val> :
// Decorates an optional attribute to have an (unvalidated) default value
// return by ODS generated accessors if not present.
-class DefaultValuedOptionalAttr<Attr attr, string val> :
+class DefaultValuedOptionalAttr<Attr attr, string val, string builderCall = ""> :
Attr<attr.predicate, attr.summary> {
// Construct this attribute with the input attribute and change only
// the default value.
@@ -115,7 +115,7 @@ class DefaultValuedOptionalAttr<Attr attr, string val> :
let storageType = attr.storageType;
let returnType = attr.returnType;
let convertFromStorage = attr.convertFromStorage;
- let constBuilderCall = attr.constBuilderCall;
+ let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
let defaultValue = val;
let valueType = attr.valueType;
let isOptional = 1;
>From 538566a7b5f746bf7e1d6bdd6706a10b6dda62c7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 9 Feb 2025 10:57:08 -0800
Subject: [PATCH 8/8] Remove the now-unused DefaultValuedContextDependentAttr class
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 18 ------------------
1 file changed, 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6a7e7876d938107..f60b369a1200815 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -555,24 +555,6 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
-
-// DONOTMERGE(rolfmorel): explain why the below is necessary
-class DefaultValuedContextDependentAttr<Attr attr,
- string builderCall,
- string default> :
- Attr<attr.predicate, attr.summary> {
- let storageType = attr.storageType;
- let returnType = attr.returnType;
- let convertFromStorage = attr.convertFromStorage;
- let constBuilderCall = builderCall; // DONOTMERGE(rolfmorel): explain why this needs to be a parameter
- let defaultValue = default;
- let valueType = attr.valueType;
- let isOptional = 1;
-
- let baseAttr = attr;
-}
-
-
def MatmulOp : LinalgStructuredBase_Op<"matmul", [
AttrSizedOperandSegments,
LinalgContractionOpInterface]> {
More information about the Mlir-commits
mailing list