[Mlir-commits] [mlir] f796bc6 - [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (#126377)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 10 04:05:17 PST 2025
Author: Rolf Morel
Date: 2025-02-10T12:05:13Z
New Revision: f796bc622a7725708b8ffbe0c7a684a8557e77a3
URL: https://github.com/llvm/llvm-project/commit/f796bc622a7725708b8ffbe0c7a684a8557e77a3
DIFF: https://github.com/llvm/llvm-project/commit/f796bc622a7725708b8ffbe0c7a684a8557e77a3.diff
LOG: [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (#126377)
Now that linalg.matmul is defined in TableGen, hand-write the Python wrapper
that OpDSL used to generate. Similarly, add a Python wrapper for the new
linalg.contract op.
This required the following miscellaneous fixes:
1) Make linalg.matmul's parsing and printing consistent about whether
indexing_maps appears before or after the operands; per the test cases,
it comes _before_.
2) The TableGen definition of linalg.contract did not declare that it
accepts an optional cast attribute.
3) In ODS's C++-generating code, extend the partial support for `$_builder`
access in `Attr::defaultValue` to full support. This gives access to
the current `MLIRContext` when constructing the default value (as is
required when the default value consists of affine maps).
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/IR/CommonAttrConstraints.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/__init__.py
mlir/test/Dialect/Linalg/named-ops.mlir
mlir/test/python/dialects/linalg/ops.py
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..29cb8035b583b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -606,7 +606,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+ DefaultValuedOptionalAttr<
+ AffineMapArrayAttr,
+ "MatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+ >:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -752,7 +755,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- AffineMapArrayAttr:$indexing_maps
+ AffineMapArrayAttr:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
// NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 599f5ecba5803b0..2beb1e8110afe8a 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -50,6 +50,9 @@ class Attr<Pred condition, string summary = ""> :
// Default value for attribute.
// Requires a constBuilderCall defined.
+ //
+ // Format: `$_builder` will be expanded to the relevant builder, e.g. to allow
+ // access to the current context.
string defaultValue = ?;
// The value type of this attribute. This corresponds to the mlir::Type that
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}
void MatmulOp::print(OpAsmPrinter &p) {
- SmallVector<StringRef, 3> elidedAttrs = {
- "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
- printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
- elidedAttrs);
-
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
[&](Attribute attr) { p.printAttribute(attr); });
p << "]";
}
+
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
}
/// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..5cda4769d593f35 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,49 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+ *ins: Union[Operation, OpView, Value],
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+ cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+ ins = [_get_op_result_or_value(input) for input in ins]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = MatmulOp(
+ result_tensors=result_types,
+ inputs=ins,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast,
+ )
+ fill_builtin_region(op.operation)
+ return op
+
+
+def contract(
+ *ins: Union[Operation, OpView, Value],
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Sequence[AffineMapAttr],
+ cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+ ins = [_get_op_result_or_value(input) for input in ins]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = ContractOp(
+ result_tensors=result_types,
+ inputs=ins,
+ outputs=[init],
+ indexing_maps=indexing_maps,
+ cast=cast,
+ )
+ fill_builtin_region(op.operation)
+ return op
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ed8683522c74a1b..68ea97be911a66a 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1269,7 +1269,7 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1294,7 +1294,7 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1315,6 +1315,7 @@ func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1335,6 +1336,7 @@ func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a_dim1
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1355,6 +1357,7 @@ func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1376,7 +1379,7 @@ func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: m
// CHECK-LABEL: func.func @matmul_bcast_a_b(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1397,6 +1400,7 @@ func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %ar
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b_dim1
// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
@@ -1420,7 +1424,7 @@ func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>)
// CHECK: return
// CHECK: }
@@ -1444,7 +1448,7 @@ func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
@@ -1468,7 +1472,7 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..94f8ea4faf4a806 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,213 @@ def f(a, b):
module.operation.verify()
print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.matmul(A, B, outs=(C,))
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.MatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.matmul(
+ A,
+ Btransposed,
+ outs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+
+ # And now with memrefs...
+
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ res = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul(Amem, Bmem, outs=(Cmem,))
+
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ res = linalg.MatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.matmul(
+ Amem,
+ Btransposedmem,
+ outs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+
+ print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (4, 8)
+ b_shape = (8, 12)
+ b_transposed_shape = (12, 8)
+ c_shape = (4, 12)
+
+ dimM = ir.AffineDimExpr.get(0)
+ dimN = ir.AffineDimExpr.get(1)
+ dimK = ir.AffineDimExpr.get(2)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+ a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+ b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+ c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+ b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+ # CHECK: func.func @matmul_as_contract_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def matmul_as_contract_op(
+ A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+ ):
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ indexing_maps=[a_map, b_map, c_map],
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract(
+ A, B, outs=(C,), indexing_maps=[a_map, b_map, c_map]
+ )
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ op5 = linalg.contract(
+ A,
+ Btransposed,
+ outs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+ # And now with memrefs...
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_map, c_map],
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract(
+ Amem, Bmem, outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+ )
+
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ op4 = linalg.ContractOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.contract(
+ Amem,
+ Btransposedmem,
+ outs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+
+ print(module)
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a970cbc5cacebe3..629e863dac5e3af 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1334,8 +1334,9 @@ static void emitAttrGetterWithReturnType(FmtContext &fctx,
PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
" must have a constBuilder");
}
- std::string defaultValue = std::string(
- tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+ std::string defaultValue =
+ std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+ tgfmt(attr.getDefaultValue(), &fctx)));
body << " if (!attr)\n return "
<< tgfmt(attr.getConvertFromStorageCall(),
&fctx.withSelf(defaultValue))
@@ -1467,6 +1468,7 @@ void OpEmitter::genPropertiesSupport() {
os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
}
+ fctx.withBuilder(odsBuilder);
setPropMethod << "{\n"
<< formatv(propFromAttrFmt,
tgfmt(prop.getConvertFromAttributeCall(),
@@ -1479,7 +1481,7 @@ void OpEmitter::genPropertiesSupport() {
prop.getStorageTypeValueOverride());
} else if (prop.hasDefaultValue()) {
setPropMethod << formatv(attrGetDefaultFmt, name,
- prop.getDefaultValue());
+ tgfmt(prop.getDefaultValue(), &fctx));
} else {
setPropMethod << formatv(attrGetNoDefaultFmt, name);
}
@@ -2919,6 +2921,9 @@ getBuilderSignature(const Builder &builder) {
arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
arguments.emplace_back("::mlir::OperationState &", builderOpState);
+ FmtContext fctx;
+ fctx.withBuilder(odsBuilder);
+
for (unsigned i = 0, e = params.size(); i < e; ++i) {
// If no name is provided, generate one.
std::optional<StringRef> paramName = params[i].getName();
@@ -2931,7 +2936,7 @@ getBuilderSignature(const Builder &builder) {
defaultValue = *defaultParamValue;
arguments.emplace_back(params[i].getCppType(), std::move(name),
- defaultValue);
+ tgfmt(defaultValue, &fctx));
}
return arguments;
@@ -3189,6 +3194,9 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
}
}
+ FmtContext fctx;
+ fctx.withBuilder(odsBuilder);
+
for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
Argument arg = op.getArg(i);
if (const auto *operand =
@@ -3210,7 +3218,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
StringRef type = prop.getInterfaceType();
std::string defaultValue;
if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) {
- defaultValue = prop.getDefaultValue();
+ defaultValue = tgfmt(prop.getDefaultValue(), &fctx);
}
bool isOptional = prop.hasDefaultValue();
paramList.emplace_back(type, propArg->name, StringRef(defaultValue),
@@ -3242,7 +3250,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
if (i >= defaultValuedAttrStartIndex) {
if (attrParamKind == AttrParamKind::UnwrappedValue &&
canUseUnwrappedRawValue(attr))
- defaultValue += attr.getDefaultValue();
+ defaultValue += tgfmt(attr.getDefaultValue(), &fctx);
else
defaultValue += "nullptr";
}
@@ -4172,6 +4180,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/false) {
+ FmtContext fctx;
+ fctx.withBuilder(odsBuilder);
+
genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public);
bool useProperties = emitHelper.hasProperties();
if (useProperties) {
@@ -4212,7 +4223,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
if (prop.hasStorageTypeValueOverride())
os << " = " << prop.getStorageTypeValueOverride();
else if (prop.hasDefaultValue())
- os << " = " << prop.getDefaultValue();
+ os << " = " << tgfmt(prop.getDefaultValue(), &fctx);
comparatorOs << " rhs." << name << " == this->" << name
<< " &&\n";
// Emit accessors using the interface type.
@@ -4454,7 +4465,6 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
m->body() << " return odsOperands;";
- FmtContext fctx;
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
// Generate named accessor with Attribute return type.
@@ -4481,8 +4491,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
// Use the default value if attribute is not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
- std::string defaultValue = std::string(
- tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+ std::string defaultValue =
+ std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+ tgfmt(attr.getDefaultValue(), &fctx)));
body << "if (!attr)\n attr = " << defaultValue << ";\n";
}
body << "return attr;\n";
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index f03a3bfd398ed68..fe724e86d670785 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1999,7 +1999,7 @@ static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
body << getter << "Attr() != "
<< tgfmt(attr.getConstBuilderTemplate(), &fctx,
- attr.getDefaultValue());
+ tgfmt(attr.getDefaultValue(), &fctx));
}
if (optionalAndDefault)
body << ")";
@@ -2007,8 +2007,10 @@ static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
PropertyVariable &propElement) {
- body << op.getGetterName(propElement.getVar()->name)
- << "() != " << propElement.getVar()->prop.getDefaultValue();
+ FmtContext fctx;
+ fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
+ body << op.getGetterName(propElement.getVar()->name) << "() != "
+ << tgfmt(propElement.getVar()->prop.getDefaultValue(), &fctx);
}
/// Elide the variadic segment size attributes if necessary.
@@ -2045,8 +2047,9 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
const StringRef &name = namedAttr.name;
FmtContext fctx;
fctx.withBuilder("odsBuilder");
- std::string defaultValue = std::string(
- tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+ std::string defaultValue =
+ std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+ tgfmt(attr.getDefaultValue(), &fctx)));
body << " {\n";
body << " ::mlir::Builder odsBuilder(getContext());\n";
body << " ::mlir::Attribute attr = " << op.getGetterName(name)
@@ -2059,8 +2062,10 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
// Similarly, elide default-valued properties.
for (const NamedProperty &prop : op.getProperties()) {
if (prop.prop.hasDefaultValue()) {
+ FmtContext fctx;
+ fctx.withBuilder("odsBuilder");
body << " if (" << op.getGetterName(prop.name)
- << "() == " << prop.prop.getDefaultValue() << ") {";
+ << "() == " << tgfmt(prop.prop.getDefaultValue(), &fctx) << ") {";
body << " elidedProps.push_back(\"" << prop.name << "\");\n";
body << " }\n";
}
@@ -2094,8 +2099,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
const StringRef &name = namedAttr.name;
FmtContext fctx;
fctx.withBuilder("odsBuilder");
- std::string defaultValue = std::string(
- tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+ std::string defaultValue =
+ std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+ tgfmt(attr.getDefaultValue(), &fctx)));
body << " {\n";
body << " ::mlir::Builder odsBuilder(getContext());\n";
body << " ::mlir::Attribute attr = " << op.getGetterName(name)
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a041c4d3277798d..f6eb5bdfe568e00 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -879,7 +879,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
if (attr.hasDefaultValue()) {
os << "if (!tblgen_attr) tblgen_attr = "
<< std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
- attr.getDefaultValue()))
+ tgfmt(attr.getDefaultValue(), &fmtCtx)))
<< ";\n";
} else if (attr.isOptional()) {
// For a missing attribute that is optional according to definition, we
More information about the Mlir-commits
mailing list