[Mlir-commits] [mlir] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (PR #126377)

Rolf Morel llvmlistbot at llvm.org
Mon Feb 10 03:15:02 PST 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/126377

>From e3f2a7c23d2d10809e91ba419ab1092dda6264eb Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 08:41:22 -0800
Subject: [PATCH 01/11] [MLIR][Linalg] Expose linalg.matmul and linalg.contract
 via Python API

Now that linalg.matmul is defined in tablegen, hand-write the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.

Required the following misc. fixes:
1) make linalg.matmul consistent about whether indexing_maps is printed
   before or after the operands, i.e. per the test cases it comes _before_.
   TODO: fix linalg.batch_matmul as well
2) the tablegen for linalg.contract did not declare that it accepts an
   optional cast attr.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |   3 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  10 +-
 mlir/python/mlir/dialects/linalg/__init__.py  |  48 +++++
 mlir/test/python/dialects/linalg/ops.py       | 186 ++++++++++++++++++
 4 files changed, 241 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 110ed7d2fc00e2a..6146ff09482fbad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -752,7 +752,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
   let arguments = (ins
     Variadic<AnyType>:$inputs,
     Variadic<AnyShaped>:$outputs,
-    AffineMapArrayAttr:$indexing_maps
+    AffineMapArrayAttr:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
   // NB: The only reason this op has a region - and it get populated at op build
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b50931f15826ce2..d40cec02df6338d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                         elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       MatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
                           [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
 }
 
 /// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 742262a9c496952..07e7c0ccd70025d 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -147,3 +147,51 @@ def __init__(
 
 
 generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+    # NOTE(review): empty `outs` slips past the len check; outs[0] raises IndexError.
+    op = MatmulOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
+
+
+def contract(
+    inputs: Sequence[Union[Operation, OpView, Value]],
+    *,
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]]=None
+):
+    inputs = [_get_op_result_or_value(input) for input in inputs]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+    # NOTE(review): empty `outs` slips past the len check; outs[0] raises IndexError.
+    op = ContractOp(
+        result_tensors=result_types,
+        inputs=inputs,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast
+    )
+    fill_builtin_region(op.operation)
+    return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index ac7186c24bed84e..6baea4f917c128c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -256,3 +256,189 @@ def f(a, b):
 
     module.operation.verify()
     print(module)
+
+
+# CHECK-LABEL: TEST: testMatmulOp
+@run
+def testMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+                # And now with memrefs...
+
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.MatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)
+
+
+# CHECK-LABEL: TEST: testContractOp
+@run
+def testContractOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (4, 8)
+            b_shape = (8, 12)
+            b_transposed_shape = (12, 8)
+            c_shape = (4, 12)
+
+            dimM = ir.AffineDimExpr.get(0)
+            dimN = ir.AffineDimExpr.get(1)
+            dimK = ir.AffineDimExpr.get(2)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+            # CHECK: #[[$B_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+            a_map = ir.AffineMap.get(3, 0, [dimM, dimK])
+            b_map = ir.AffineMap.get(3, 0, [dimK, dimN])
+            c_map = ir.AffineMap.get(3, 0, [dimM, dimN])
+            b_transposed_map = ir.AffineMap.get(3, 0, [dimN, dimK])
+
+            # CHECK: func.func @matmul_as_contract_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                # And now with memrefs...
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                op4 = linalg.ContractOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map]
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+
+        print(module)

>From 57b2509c212f99e9838e0afb4790fc551ff337dc Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 09:00:51 -0800
Subject: [PATCH 02/11] Fix linalg.matmul tests which encoded different
 placements of indexing_maps

---
 mlir/test/Dialect/Linalg/named-ops.mlir | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ed8683522c74a1b..68ea97be911a66a 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1269,7 +1269,7 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
 // CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
 // CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
@@ -1294,7 +1294,7 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 // CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
 // CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
 // CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
@@ -1315,6 +1315,7 @@ func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: m
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_a
 //       CHECK:   linalg.matmul
+//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
 //  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
 
@@ -1335,6 +1336,7 @@ func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %ar
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_a_dim1
 //       CHECK:   linalg.matmul
+//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
 //  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
 
@@ -1355,6 +1357,7 @@ func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: m
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_b
 //       CHECK:   linalg.matmul
+//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
 //  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
 
@@ -1376,7 +1379,7 @@ func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: m
 // CHECK-LABEL:   func.func @matmul_bcast_a_b(
 // CHECK-SAME:                                %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
 // CHECK-SAME:                                %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
@@ -1397,6 +1400,7 @@ func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %ar
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_b_dim1
 //       CHECK:   linalg.matmul
+//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
 //  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
 
@@ -1420,7 +1424,7 @@ func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>,
 // CHECK-SAME:                                      %[[VAL_0:.*]]: memref<?xf32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: memref<?x?xf32>,
 // CHECK-SAME:                                      %[[VAL_2:.*]]: memref<?x?xf32>) {
-// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>)
 // CHECK:           return
 // CHECK:         }
 
@@ -1444,7 +1448,7 @@ func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf3
 // CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5xf32>,
 // CHECK-SAME:                                  %[[VAL_1:.*]]: memref<7x5xf32>,
 // CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
@@ -1468,7 +1472,7 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
 // CHECK-SAME:                                          %[[VAL_0:.*]]: memref<5x3xf32>,
 // CHECK-SAME:                                          %[[VAL_1:.*]]: memref<5xf32>,
 // CHECK-SAME:                                          %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 

>From 18ff3a681e9b13d845adf3eb9bcc16177e6fb2e9 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 09:08:39 -0800
Subject: [PATCH 03/11] Python formatting

---
 mlir/python/mlir/dialects/linalg/__init__.py |  8 +--
 mlir/test/python/dialects/linalg/ops.py      | 76 +++++++++++++-------
 2 files changed, 55 insertions(+), 29 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 07e7c0ccd70025d..b56252f9a617ceb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -154,7 +154,7 @@ def matmul(
     *,
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Sequence[AffineMapAttr],
-    cast: Optional[Union[TypeFn, Attribute]]=None
+    cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
     inputs = [_get_op_result_or_value(input) for input in inputs]
     if len(outs) > 1:
@@ -167,7 +167,7 @@ def matmul(
         inputs=inputs,
         outputs=[init],
         indexing_maps=indexing_maps,
-        cast=cast
+        cast=cast,
     )
     fill_builtin_region(op.operation)
     return op
@@ -178,7 +178,7 @@ def contract(
     *,
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Sequence[AffineMapAttr],
-    cast: Optional[Union[TypeFn, Attribute]]=None
+    cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
     inputs = [_get_op_result_or_value(input) for input in inputs]
     if len(outs) > 1:
@@ -191,7 +191,7 @@ def contract(
         inputs=inputs,
         outputs=[init],
         indexing_maps=indexing_maps,
-        cast=cast
+        cast=cast,
     )
     fill_builtin_region(op.operation)
     return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 6baea4f917c128c..b487d72cbf9f358 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -303,50 +303,62 @@ def testMatmulOp():
             )
             def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op4 = linalg.MatmulOp(
+                res = linalg.MatmulOp(
                     result_tensors=(C.type,),
                     inputs=(A, B),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
-                linalg.fill_builtin_region(op4.operation)
+                linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.matmul((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+                res = linalg.matmul(
+                    (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op4 = linalg.MatmulOp(
+                res = linalg.MatmulOp(
                     result_tensors=(C.type,),
                     inputs=(A, Btransposed),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
-                linalg.fill_builtin_region(op4.operation)
+                linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.matmul((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                res = linalg.matmul(
+                    (A, Btransposed),
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
 
                 # And now with memrefs...
 
                 # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                op4 = linalg.MatmulOp(
+                res = linalg.MatmulOp(
                     result_tensors=[],
                     inputs=(Amem, Bmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
-                linalg.fill_builtin_region(op4.operation)
+                linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.matmul((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+                linalg.matmul(
+                    (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                op4 = linalg.MatmulOp(
+                res = linalg.MatmulOp(
                     result_tensors=[],
                     inputs=(Amem, Btransposedmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
-                linalg.fill_builtin_region(op4.operation)
+                linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.matmul((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+                linalg.matmul(
+                    (Amem, Btransposedmem),
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
 
         print(module)
 
@@ -395,28 +407,36 @@ def testContractOp():
                 # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
                 MemRefType.get(c_shape, f32),
             )
-            def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+            def matmul_as_contract_op(
+                A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+            ):
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op4 = linalg.ContractOp(
                     result_tensors=(C.type,),
                     inputs=(A, B),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.contract((A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map])
+                op5 = linalg.contract(
+                    (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op4 = linalg.ContractOp(
                     result_tensors=(C.type,),
                     inputs=(A, Btransposed),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                op5 = linalg.contract((A, Btransposed), outs=(C,), indexing_maps=[a_map, b_transposed_map, c_map])
+                op5 = linalg.contract(
+                    (A, Btransposed),
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
                 # And now with memrefs...
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
@@ -424,21 +444,27 @@ def matmul_as_contract_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
                     result_tensors=[],
                     inputs=(Amem, Bmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_map, c_map]
+                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.contract((Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map])
+                linalg.contract(
+                    (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+                )
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 op4 = linalg.ContractOp(
                     result_tensors=[],
                     inputs=(Amem, Btransposedmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_transposed_map, c_map]
+                    indexing_maps=[a_map, b_transposed_map, c_map],
                 )
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.contract((Amem, Btransposedmem), outs=(Cmem,), indexing_maps=[a_map, b_transposed_map, c_map])
+                linalg.contract(
+                    (Amem, Btransposedmem),
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
 
         print(module)

>From 2e6bf665bbeca637a5e8d21689c6b128fc23275d Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 13:46:41 -0800
Subject: [PATCH 04/11] Make indexing_maps optional on matmul, as it should be

---
 mlir/python/mlir/dialects/linalg/__init__.py | 2 +-
 mlir/test/python/dialects/linalg/ops.py      | 8 ++------
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index b56252f9a617ceb..590229818c7440a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -153,7 +153,7 @@ def matmul(
     inputs: Sequence[Union[Operation, OpView, Value]],
     *,
     outs: Sequence[Union[Operation, OpView, Value]],
-    indexing_maps: Sequence[AffineMapAttr],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
     inputs = [_get_op_result_or_value(input) for input in inputs]
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b487d72cbf9f358..36012303a52a231 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -307,13 +307,10 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                     result_tensors=(C.type,),
                     inputs=(A, B),
                     outputs=(C,),
-                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                res = linalg.matmul(
-                    (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
-                )
+                res = linalg.matmul((A, B), outs=(C,))
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 res = linalg.MatmulOp(
@@ -337,12 +334,11 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                     result_tensors=[],
                     inputs=(Amem, Bmem),
                     outputs=(Cmem,),
-                    indexing_maps=[a_map, b_map, c_map],
                 )
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 linalg.matmul(
-                    (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+                    (Amem, Bmem), outs=(Cmem,)
                 )
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)

>From 85426f6019a0faf7b592b91b23133b4298493c24 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 8 Feb 2025 15:24:21 -0800
Subject: [PATCH 05/11] Fix matmul's indexing_maps issue: fall back to default maps when none are given

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td   | 18 +++++++++++++++++-
 mlir/python/mlir/dialects/linalg/__init__.py   |  5 +++++
 mlir/test/python/dialects/linalg/ops.py        |  4 +---
 3 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6146ff09482fbad..dc1c93355e2582b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -555,6 +555,22 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 // Op definition for MatmulOp
 //===----------------------------------------------------------------------===//
 
+
+// DONOTMERGE(rolfmorel): explain why the below is necessary
+def DefaultValuedMatmulIndexingMapsAttr :
+    Attr<AffineMapArrayAttr.predicate, AffineMapArrayAttr.summary> {
+  let storageType = AffineMapArrayAttr.storageType;
+  let returnType = AffineMapArrayAttr.returnType;
+  let convertFromStorage = AffineMapArrayAttr.convertFromStorage;
+  let constBuilderCall = "$_builder.getAffineMapArrayAttr($0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0)";
+  let defaultValue = "SmallVector<AffineMap>()";
+  let valueType = AffineMapArrayAttr.valueType;
+  let isOptional = 1;
+
+  let baseAttr = AffineMapArrayAttr;
+}
+
+
 def MatmulOp : LinalgStructuredBase_Op<"matmul", [
                AttrSizedOperandSegments,
                LinalgContractionOpInterface]> {
@@ -606,7 +622,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     let arguments = (ins
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
       DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 590229818c7440a..ab47946e7eab772 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,6 +149,11 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)
 
 
+@register_attribute_builder("DefaultValuedMatmulIndexingMapsAttr")
+def _DefaultValuedMatmulIndexingMapsAttr(x, context):
+    return ArrayAttr.get([AffineMapAttr.get(v) for v in x])
+
+
 def matmul(
     inputs: Sequence[Union[Operation, OpView, Value]],
     *,
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 36012303a52a231..c332f7001fcb237 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -337,9 +337,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 )
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.matmul(
-                    (Amem, Bmem), outs=(Cmem,)
-                )
+                linalg.matmul((Amem, Bmem), outs=(Cmem,))
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 res = linalg.MatmulOp(

>From e98838da2635fea28043b3a274e5a69dc4c570c4 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 9 Feb 2025 10:37:03 -0800
Subject: [PATCH 06/11] Get rid of "whole new type" by introducing "context
 dependent" default valued attr class

Should probably be moved to CommonAttrConstraints.td
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 25 +++++++++++--------
 mlir/python/mlir/dialects/linalg/__init__.py  |  5 ----
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index dc1c93355e2582b..f8cf63e6e8a3f0a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -557,17 +557,19 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 
 
 // DONOTMERGE(rolfmorel): explain why the below is necessary
-def DefaultValuedMatmulIndexingMapsAttr :
-    Attr<AffineMapArrayAttr.predicate, AffineMapArrayAttr.summary> {
-  let storageType = AffineMapArrayAttr.storageType;
-  let returnType = AffineMapArrayAttr.returnType;
-  let convertFromStorage = AffineMapArrayAttr.convertFromStorage;
-  let constBuilderCall = "$_builder.getAffineMapArrayAttr($0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0)";
-  let defaultValue = "SmallVector<AffineMap>()";
-  let valueType = AffineMapArrayAttr.valueType;
+class DefaultValuedContextDependentAttr<Attr attr,
+                                        string builderCall,
+                                        string default> :
+    Attr<attr.predicate, attr.summary> {
+  let storageType = attr.storageType;
+  let returnType = attr.returnType;
+  let convertFromStorage = attr.convertFromStorage;
+  let constBuilderCall = builderCall;  // DONOTMERGE(rolfmorel): explain why this needs to be a parameter
+  let defaultValue = default;
+  let valueType = attr.valueType;
   let isOptional = 1;
 
-  let baseAttr = AffineMapArrayAttr;
+  let baseAttr = attr;
 }
 
 
@@ -622,7 +624,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     let arguments = (ins
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
-      DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
+      DefaultValuedContextDependentAttr<AffineMapArrayAttr,
+          builderCall = [{ $_builder.getAffineMapArrayAttr(
+            $0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0
+          )}], default = "SmallVector<AffineMap>()">:$indexing_maps,
       DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index ab47946e7eab772..590229818c7440a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,11 +149,6 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)
 
 
-@register_attribute_builder("DefaultValuedMatmulIndexingMapsAttr")
-def _DefaultValuedMatmulIndexingMapsAttr(x, context):
-    return ArrayAttr.get([AffineMapAttr.get(v) for v in x])
-
-
 def matmul(
     inputs: Sequence[Union[Operation, OpView, Value]],
     *,

>From 931b47369710207c6e29c66124bcbff0318feb81 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 9 Feb 2025 10:52:18 -0800
Subject: [PATCH 07/11] Move builderCall parameter from custom attr class to
 DefaultValued(Optional)Attr

---
 .../include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 4 ++--
 mlir/include/mlir/IR/CommonAttrConstraints.td             | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f8cf63e6e8a3f0a..6a7e7876d938107 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -624,10 +624,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     let arguments = (ins
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
-      DefaultValuedContextDependentAttr<AffineMapArrayAttr,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "SmallVector<AffineMap>()",
           builderCall = [{ $_builder.getAffineMapArrayAttr(
             $0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0
-          )}], default = "SmallVector<AffineMap>()">:$indexing_maps,
+          )}]>:$indexing_maps,
       DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 599f5ecba5803b0..f677009334fa21d 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -90,7 +90,7 @@ class DialectAttr<Dialect d, Pred condition, string summary = ""> :
 // Attribute modifier definition
 
 // Decorates an attribute to have an (unvalidated) default value if not present.
-class DefaultValuedAttr<Attr attr, string val> :
+class DefaultValuedAttr<Attr attr, string val, string builderCall = ""> :
     Attr<attr.predicate, attr.summary> {
   // Construct this attribute with the input attribute and change only
   // the default value.
@@ -98,7 +98,7 @@ class DefaultValuedAttr<Attr attr, string val> :
   let storageType = attr.storageType;
   let returnType = attr.returnType;
   let convertFromStorage = attr.convertFromStorage;
-  let constBuilderCall = attr.constBuilderCall;
+  let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
   let defaultValue = val;
   let valueType = attr.valueType;
 
@@ -107,7 +107,7 @@ class DefaultValuedAttr<Attr attr, string val> :
 
 // Decorates an optional attribute to have an (unvalidated) default value
 // return by ODS generated accessors if not present.
-class DefaultValuedOptionalAttr<Attr attr, string val> :
+class DefaultValuedOptionalAttr<Attr attr, string val, string builderCall = ""> :
     Attr<attr.predicate, attr.summary> {
   // Construct this attribute with the input attribute and change only
   // the default value.
@@ -115,7 +115,7 @@ class DefaultValuedOptionalAttr<Attr attr, string val> :
   let storageType = attr.storageType;
   let returnType = attr.returnType;
   let convertFromStorage = attr.convertFromStorage;
-  let constBuilderCall = attr.constBuilderCall;
+  let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
   let defaultValue = val;
   let valueType = attr.valueType;
   let isOptional = 1;

>From 538566a7b5f746bf7e1d6bdd6706a10b6dda62c7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 9 Feb 2025 10:57:08 -0800
Subject: [PATCH 08/11] Remove the now-unused DefaultValuedContextDependentAttr class

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td   | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6a7e7876d938107..f60b369a1200815 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -555,24 +555,6 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 // Op definition for MatmulOp
 //===----------------------------------------------------------------------===//
 
-
-// DONOTMERGE(rolfmorel): explain why the below is necessary
-class DefaultValuedContextDependentAttr<Attr attr,
-                                        string builderCall,
-                                        string default> :
-    Attr<attr.predicate, attr.summary> {
-  let storageType = attr.storageType;
-  let returnType = attr.returnType;
-  let convertFromStorage = attr.convertFromStorage;
-  let constBuilderCall = builderCall;  // DONOTMERGE(rolfmorel): explain why this needs to be a parameter
-  let defaultValue = default;
-  let valueType = attr.valueType;
-  let isOptional = 1;
-
-  let baseAttr = attr;
-}
-
-
 def MatmulOp : LinalgStructuredBase_Op<"matmul", [
                AttrSizedOperandSegments,
                LinalgContractionOpInterface]> {

>From c81c97f559efcbe1e5da69d22647da722004e595 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 9 Feb 2025 12:14:35 -0800
Subject: [PATCH 09/11] Allow access to `$_builder` in Attr's `defaultValue`

Reverts changes to DefaultValued(Optional)Attr.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  8 ++---
 mlir/include/mlir/IR/CommonAttrConstraints.td | 11 ++++---
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 31 +++++++++++++------
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        | 22 ++++++++-----
 mlir/tools/mlir-tblgen/RewriterGen.cpp        |  2 +-
 5 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f60b369a1200815..29cb8035b583b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -606,10 +606,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     let arguments = (ins
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "SmallVector<AffineMap>()",
-          builderCall = [{ $_builder.getAffineMapArrayAttr(
-            $0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0
-          )}]>:$indexing_maps,
+      DefaultValuedOptionalAttr<
+        AffineMapArrayAttr,
+        "MatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
       DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index f677009334fa21d..2beb1e8110afe8a 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -50,6 +50,9 @@ class Attr<Pred condition, string summary = ""> :
 
   // Default value for attribute.
   // Requires a constBuilderCall defined.
+  //
+  // Format: `$_builder` will be expanded to the relevant builder, e.g. to allow
+  //         access to the current context.
   string defaultValue = ?;
 
   // The value type of this attribute. This corresponds to the mlir::Type that
@@ -90,7 +93,7 @@ class DialectAttr<Dialect d, Pred condition, string summary = ""> :
 // Attribute modifier definition
 
 // Decorates an attribute to have an (unvalidated) default value if not present.
-class DefaultValuedAttr<Attr attr, string val, string builderCall = ""> :
+class DefaultValuedAttr<Attr attr, string val> :
     Attr<attr.predicate, attr.summary> {
   // Construct this attribute with the input attribute and change only
   // the default value.
@@ -98,7 +101,7 @@ class DefaultValuedAttr<Attr attr, string val, string builderCall = ""> :
   let storageType = attr.storageType;
   let returnType = attr.returnType;
   let convertFromStorage = attr.convertFromStorage;
-  let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
+  let constBuilderCall = attr.constBuilderCall;
   let defaultValue = val;
   let valueType = attr.valueType;
 
@@ -107,7 +110,7 @@ class DefaultValuedAttr<Attr attr, string val, string builderCall = ""> :
 
 // Decorates an optional attribute to have an (unvalidated) default value
 // return by ODS generated accessors if not present.
-class DefaultValuedOptionalAttr<Attr attr, string val, string builderCall = ""> :
+class DefaultValuedOptionalAttr<Attr attr, string val> :
     Attr<attr.predicate, attr.summary> {
   // Construct this attribute with the input attribute and change only
   // the default value.
@@ -115,7 +118,7 @@ class DefaultValuedOptionalAttr<Attr attr, string val, string builderCall = "">
   let storageType = attr.storageType;
   let returnType = attr.returnType;
   let convertFromStorage = attr.convertFromStorage;
-  let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
+  let constBuilderCall = attr.constBuilderCall;
   let defaultValue = val;
   let valueType = attr.valueType;
   let isOptional = 1;
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a970cbc5cacebe3..629e863dac5e3af 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1334,8 +1334,9 @@ static void emitAttrGetterWithReturnType(FmtContext &fctx,
       PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
                       " must have a constBuilder");
     }
-    std::string defaultValue = std::string(
-        tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+    std::string defaultValue =
+        std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+                          tgfmt(attr.getDefaultValue(), &fctx)));
     body << "    if (!attr)\n      return "
          << tgfmt(attr.getConvertFromStorageCall(),
                   &fctx.withSelf(defaultValue))
@@ -1467,6 +1468,7 @@ void OpEmitter::genPropertiesSupport() {
         os << "   if (!attr) attr = dict.get(\"result_segment_sizes\");";
       }
 
+      fctx.withBuilder(odsBuilder);
       setPropMethod << "{\n"
                     << formatv(propFromAttrFmt,
                                tgfmt(prop.getConvertFromAttributeCall(),
@@ -1479,7 +1481,7 @@ void OpEmitter::genPropertiesSupport() {
                                  prop.getStorageTypeValueOverride());
       } else if (prop.hasDefaultValue()) {
         setPropMethod << formatv(attrGetDefaultFmt, name,
-                                 prop.getDefaultValue());
+                                 tgfmt(prop.getDefaultValue(), &fctx));
       } else {
         setPropMethod << formatv(attrGetNoDefaultFmt, name);
       }
@@ -2919,6 +2921,9 @@ getBuilderSignature(const Builder &builder) {
   arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
   arguments.emplace_back("::mlir::OperationState &", builderOpState);
 
+  FmtContext fctx;
+  fctx.withBuilder(odsBuilder);
+
   for (unsigned i = 0, e = params.size(); i < e; ++i) {
     // If no name is provided, generate one.
     std::optional<StringRef> paramName = params[i].getName();
@@ -2931,7 +2936,7 @@ getBuilderSignature(const Builder &builder) {
       defaultValue = *defaultParamValue;
 
     arguments.emplace_back(params[i].getCppType(), std::move(name),
-                           defaultValue);
+                           tgfmt(defaultValue, &fctx));
   }
 
   return arguments;
@@ -3189,6 +3194,9 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
     }
   }
 
+  FmtContext fctx;
+  fctx.withBuilder(odsBuilder);
+
   for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
     Argument arg = op.getArg(i);
     if (const auto *operand =
@@ -3210,7 +3218,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
       StringRef type = prop.getInterfaceType();
       std::string defaultValue;
       if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) {
-        defaultValue = prop.getDefaultValue();
+        defaultValue = tgfmt(prop.getDefaultValue(), &fctx);
       }
       bool isOptional = prop.hasDefaultValue();
       paramList.emplace_back(type, propArg->name, StringRef(defaultValue),
@@ -3242,7 +3250,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
     if (i >= defaultValuedAttrStartIndex) {
       if (attrParamKind == AttrParamKind::UnwrappedValue &&
           canUseUnwrappedRawValue(attr))
-        defaultValue += attr.getDefaultValue();
+        defaultValue += tgfmt(attr.getDefaultValue(), &fctx);
       else
         defaultValue += "nullptr";
     }
@@ -4172,6 +4180,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
       staticVerifierEmitter(staticVerifierEmitter),
       emitHelper(op, /*emitForOp=*/false) {
 
+  FmtContext fctx;
+  fctx.withBuilder(odsBuilder);
+
   genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public);
   bool useProperties = emitHelper.hasProperties();
   if (useProperties) {
@@ -4212,7 +4223,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
         if (prop.hasStorageTypeValueOverride())
           os << " = " << prop.getStorageTypeValueOverride();
         else if (prop.hasDefaultValue())
-          os << " = " << prop.getDefaultValue();
+          os << " = " << tgfmt(prop.getDefaultValue(), &fctx);
         comparatorOs << "        rhs." << name << " == this->" << name
                      << " &&\n";
         // Emit accessors using the interface type.
@@ -4454,7 +4465,6 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
   if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
     m->body() << "  return odsOperands;";
 
-  FmtContext fctx;
   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
 
   // Generate named accessor with Attribute return type.
@@ -4481,8 +4491,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
       // Use the default value if attribute is not set.
       // TODO: this is inefficient, we are recreating the attribute for every
       // call. This should be set instead.
-      std::string defaultValue = std::string(
-          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+      std::string defaultValue =
+          std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+                            tgfmt(attr.getDefaultValue(), &fctx)));
       body << "if (!attr)\n  attr = " << defaultValue << ";\n";
     }
     body << "return attr;\n";
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index f03a3bfd398ed68..fe724e86d670785 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1999,7 +1999,7 @@ static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
     fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
     body << getter << "Attr() != "
          << tgfmt(attr.getConstBuilderTemplate(), &fctx,
-                  attr.getDefaultValue());
+                  tgfmt(attr.getDefaultValue(), &fctx));
   }
   if (optionalAndDefault)
     body << ")";
@@ -2007,8 +2007,10 @@ static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
 
 static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
                                     PropertyVariable &propElement) {
-  body << op.getGetterName(propElement.getVar()->name)
-       << "() != " << propElement.getVar()->prop.getDefaultValue();
+  FmtContext fctx;
+  fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
+  body << op.getGetterName(propElement.getVar()->name) << "() != "
+       << tgfmt(propElement.getVar()->prop.getDefaultValue(), &fctx);
 }
 
 /// Elide the variadic segment size attributes if necessary.
@@ -2045,8 +2047,9 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
       const StringRef &name = namedAttr.name;
       FmtContext fctx;
       fctx.withBuilder("odsBuilder");
-      std::string defaultValue = std::string(
-          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+      std::string defaultValue =
+          std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+                            tgfmt(attr.getDefaultValue(), &fctx)));
       body << "  {\n";
       body << "     ::mlir::Builder odsBuilder(getContext());\n";
       body << "     ::mlir::Attribute attr = " << op.getGetterName(name)
@@ -2059,8 +2062,10 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
   // Similarly, elide default-valued properties.
   for (const NamedProperty &prop : op.getProperties()) {
     if (prop.prop.hasDefaultValue()) {
+      FmtContext fctx;
+      fctx.withBuilder("odsBuilder");
       body << "  if (" << op.getGetterName(prop.name)
-           << "() == " << prop.prop.getDefaultValue() << ") {";
+           << "() == " << tgfmt(prop.prop.getDefaultValue(), &fctx) << ") {";
       body << "    elidedProps.push_back(\"" << prop.name << "\");\n";
       body << "  }\n";
     }
@@ -2094,8 +2099,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
       const StringRef &name = namedAttr.name;
       FmtContext fctx;
       fctx.withBuilder("odsBuilder");
-      std::string defaultValue = std::string(
-          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+      std::string defaultValue =
+          std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+                            tgfmt(attr.getDefaultValue(), &fctx)));
       body << "  {\n";
       body << "     ::mlir::Builder odsBuilder(getContext());\n";
       body << "     ::mlir::Attribute attr = " << op.getGetterName(name)
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a041c4d3277798d..f6eb5bdfe568e00 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -879,7 +879,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
   if (attr.hasDefaultValue()) {
     os << "if (!tblgen_attr) tblgen_attr = "
        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
-                            attr.getDefaultValue()))
+                            tgfmt(attr.getDefaultValue(), &fmtCtx)))
        << ";\n";
   } else if (attr.isOptional()) {
     // For a missing attribute that is optional according to definition, we

>From f83f7aea76847217f25007778f59410d0568d318 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 10 Feb 2025 03:06:00 -0800
Subject: [PATCH 10/11] Switch argument format to that of OpDSL-derived linalg
 ops

---
 mlir/python/mlir/dialects/linalg/__init__.py | 14 ++++++--------
 mlir/test/python/dialects/linalg/ops.py      | 16 ++++++++--------
 2 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 590229818c7440a..5cda4769d593f35 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -150,13 +150,12 @@ def __init__(
 
 
 def matmul(
-    inputs: Sequence[Union[Operation, OpView, Value]],
-    *,
+    *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    inputs = [_get_op_result_or_value(input) for input in inputs]
+    ins = [_get_op_result_or_value(input) for input in ins]
     if len(outs) > 1:
         raise ValueError(f"{outs=} must have length 1.")
     init = _get_op_result_or_value(outs[0])
@@ -164,7 +163,7 @@ def matmul(
 
     op = MatmulOp(
         result_tensors=result_types,
-        inputs=inputs,
+        inputs=ins,
         outputs=[init],
         indexing_maps=indexing_maps,
         cast=cast,
@@ -174,13 +173,12 @@ def matmul(
 
 
 def contract(
-    inputs: Sequence[Union[Operation, OpView, Value]],
-    *,
+    *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Sequence[AffineMapAttr],
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    inputs = [_get_op_result_or_value(input) for input in inputs]
+    ins = [_get_op_result_or_value(input) for input in ins]
     if len(outs) > 1:
         raise ValueError(f"{outs=} must have length 1.")
     init = _get_op_result_or_value(outs[0])
@@ -188,7 +186,7 @@ def contract(
 
     op = ContractOp(
         result_tensors=result_types,
-        inputs=inputs,
+        inputs=ins,
         outputs=[init],
         indexing_maps=indexing_maps,
         cast=cast,
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index c332f7001fcb237..02ed61ab4a8ead8 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -310,7 +310,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 )
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
-                res = linalg.matmul((A, B), outs=(C,))
+                res = linalg.matmul(A, B, outs=(C,))
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 res = linalg.MatmulOp(
@@ -322,7 +322,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 res = linalg.matmul(
-                    (A, Btransposed),
+                    A, Btransposed,
                     outs=(C,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )
@@ -337,7 +337,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 )
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
-                linalg.matmul((Amem, Bmem), outs=(Cmem,))
+                linalg.matmul(Amem, Bmem, outs=(Cmem,))
 
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 res = linalg.MatmulOp(
@@ -349,7 +349,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 linalg.matmul(
-                    (Amem, Btransposedmem),
+                    Amem, Btransposedmem,
                     outs=(Cmem,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )
@@ -414,7 +414,7 @@ def matmul_as_contract_op(
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op5 = linalg.contract(
-                    (A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
+                    A, B, outs=(C,), indexing_maps=[a_map, b_map, c_map]
                 )
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
@@ -427,7 +427,7 @@ def matmul_as_contract_op(
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op5 = linalg.contract(
-                    (A, Btransposed),
+                    A, Btransposed,
                     outs=(C,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )
@@ -443,7 +443,7 @@ def matmul_as_contract_op(
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 linalg.contract(
-                    (Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
+                    Amem, Bmem, outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
                 )
 
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
@@ -456,7 +456,7 @@ def matmul_as_contract_op(
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 linalg.contract(
-                    (Amem, Btransposedmem),
+                    Amem, Btransposedmem,
                     outs=(Cmem,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )

>From 7a5234f4fe89d79124fc4adee322cab312e31920 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 10 Feb 2025 03:14:23 -0800
Subject: [PATCH 11/11] Python formatting fix

---
 mlir/test/python/dialects/linalg/ops.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 02ed61ab4a8ead8..94f8ea4faf4a806 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -322,7 +322,8 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 res = linalg.matmul(
-                    A, Btransposed,
+                    A,
+                    Btransposed,
                     outs=(C,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )
@@ -349,7 +350,8 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
                 linalg.fill_builtin_region(res.operation)
                 # CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 linalg.matmul(
-                    Amem, Btransposedmem,
+                    Amem,
+                    Btransposedmem,
                     outs=(Cmem,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )
@@ -427,7 +429,8 @@ def matmul_as_contract_op(
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                 op5 = linalg.contract(
-                    A, Btransposed,
+                    A,
+                    Btransposed,
                     outs=(C,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )
@@ -456,7 +459,8 @@ def matmul_as_contract_op(
                 linalg.fill_builtin_region(op4.operation)
                 # CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                 linalg.contract(
-                    Amem, Btransposedmem,
+                    Amem,
+                    Btransposedmem,
                     outs=(Cmem,),
                     indexing_maps=[a_map, b_transposed_map, c_map],
                 )



More information about the Mlir-commits mailing list