[Mlir-commits] [mlir] [MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. (PR #127614)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Wed Feb 19 04:55:18 PST 2025


https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/127614

>From 243a16b478d537920042e764312f8cdc6c512ff2 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Tue, 18 Feb 2025 01:39:45 -0800
Subject: [PATCH 1/4] [MLIR][Linalg] Introduce Python API for
 linalg.batch_matmul Ops.

As linalg.batch_matmul has moved into tablegen from OpDSL, its derived
python wrapper no longer exists. This patch adds the required python
wrapper.

Also refactors the BatchMatmulOp printer to make it consistent with its
parser.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  5 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 10 +-
 mlir/python/mlir/dialects/linalg/__init__.py  | 20 ++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 10 +-
 mlir/test/python/dialects/linalg/ops.py       | 99 +++++++++++++++++++
 5 files changed, 133 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6a439bfb09078..7ce6c4a353afe 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -858,7 +858,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     let arguments = (ins
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+      DefaultValuedOptionalAttr<
+       AffineMapArrayAttr,
+       "BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b756a67f3ba7a..b488e748df7ba 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4004,11 +4004,6 @@ ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void BatchMatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                           elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       BatchMatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -4018,6 +4013,11 @@ void BatchMatmulOp::print(OpAsmPrinter &p) {
                           [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
 }
 
 /// Verify the user defined indexing maps.
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 5cda4769d593f..e4890dd97e935 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -193,3 +193,23 @@ def contract(
     )
     fill_builtin_region(op.operation)
     return op
+
+def batch_matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+):
+    ins = [_get_op_result_or_value(input) for input in ins]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = BatchMatmulOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+    )
+    fill_builtin_region(op.operation)
+    return op
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 8474eeac0db5b..1bd9c8825b05e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1497,7 +1497,7 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 // CHECK-SAME:                                    %[[VAL_0:.*]]: memref<5xf32>,
 // CHECK-SAME:                                    %[[VAL_1:.*]]: memref<2x5x7xf32>,
 // CHECK-SAME:                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1520,7 +1520,7 @@ func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %ar
 // CHECK-SAME:                                              %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME:                                              %[[VAL_1:.*]]: memref<2x5x7xf32>,
 // CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1543,7 +1543,7 @@ func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<
 // CHECK-SAME:                                                    %[[VAL_0:.*]]: memref<2x3x5xf32>,
 // CHECK-SAME:                                                    %[[VAL_1:.*]]: memref<5xf32>,
 // CHECK-SAME:                                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1566,7 +1566,7 @@ func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1:
 // CHECK-SAME:                                              %[[VAL_0:.*]]: memref<2x3x5xf32>,
 // CHECK-SAME:                                              %[[VAL_1:.*]]: memref<5x7xf32>,
 // CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
@@ -1622,7 +1622,7 @@ func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: me
 // CHECK-SAME:                                                %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME:                                                %[[VAL_1:.*]]: memref<2x7x5xf32>,
 // CHECK-SAME:                                                %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 94f8ea4faf4a8..914fa4b7af261 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -466,3 +466,102 @@ def matmul_as_contract_op(
                 )
 
         print(module)
+
+# CHECK-LABEL: TEST: testBatchMatmulOp
+@run
+def testBatchMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (2, 4, 8)
+            b_shape = (2, 8, 12)
+            b_transposed_shape = (2, 12, 8)
+            c_shape = (2, 4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN])
+
+            # CHECK: func.func @batch_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<2x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<2x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<2x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<2x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<2x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<2x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<2x4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<2x4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(
+                    A,
+                    Btransposed,
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(
+                    Amem,
+                    Btransposedmem,
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+    print(module)

>From 3be92347e51f8dea036824774451d8329592e067 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Tue, 18 Feb 2025 03:17:12 -0800
Subject: [PATCH 2/4] Fix python formatting

---
 mlir/python/mlir/dialects/linalg/__init__.py | 1 +
 mlir/test/python/dialects/linalg/ops.py      | 1 +
 2 files changed, 2 insertions(+)

diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index e4890dd97e935..30d86ce085e44 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -194,6 +194,7 @@ def contract(
     fill_builtin_region(op.operation)
     return op
 
+
 def batch_matmul(
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 914fa4b7af261..307a88709ad52 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -467,6 +467,7 @@ def matmul_as_contract_op(
 
         print(module)
 
+
 # CHECK-LABEL: TEST: testBatchMatmulOp
 @run
 def testBatchMatmulOp():

>From 8f63eaa2c60be1e7a2add95bc554c4c109c97126 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Tue, 18 Feb 2025 23:41:11 -0800
Subject: [PATCH 3/4] -Refactored the linalg op python wrapper creation with a
 helper function.

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 10 ++--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 15 ++++--
 mlir/python/mlir/dialects/linalg/__init__.py  | 50 +++++++------------
 3 files changed, 36 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 7ce6c4a353afe..27e12445f2d31 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -859,9 +859,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
       DefaultValuedOptionalAttr<
-       AffineMapArrayAttr,
-       "BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
-      >:$indexing_maps
+        AffineMapArrayAttr,
+        "BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
@@ -887,9 +888,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
-            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+            "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         $_state.addOperands(operands);
+        $_state.addAttribute("cast", cast);
         $_state.addAttributes(attributes);
         $_state.addTypes(resultTensorTypes);
         (void)$_state.addRegion(),
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b488e748df7ba..42ea0e1197ef1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3951,11 +3951,18 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
 
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
   auto toType = block.getArgument(2).getType();
-  Value castValA =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
-  Value castValB =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
+  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
   Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
   Value addVal =
       helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 30d86ce085e44..c5fbb833ee399 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,7 +149,8 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)
 
 
-def matmul(
+def create_op(
+    op_type,
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
@@ -161,7 +162,7 @@ def matmul(
     init = _get_op_result_or_value(outs[0])
     result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
 
-    op = MatmulOp(
+    op = op_type(
         result_tensors=result_types,
         inputs=ins,
         outputs=[init],
@@ -172,45 +173,32 @@ def matmul(
     return op
 
 
-def contract(
+def matmul(
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
-    indexing_maps: Sequence[AffineMapAttr],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    ins = [_get_op_result_or_value(input) for input in ins]
-    if len(outs) > 1:
-        raise ValueError(f"{outs=} must have length 1.")
-    init = _get_op_result_or_value(outs[0])
-    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
-
-    op = ContractOp(
-        result_tensors=result_types,
-        inputs=ins,
-        outputs=[init],
-        indexing_maps=indexing_maps,
-        cast=cast,
-    )
-    fill_builtin_region(op.operation)
-    return op
+    return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
 
 
 def batch_matmul(
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    ins = [_get_op_result_or_value(input) for input in ins]
-    if len(outs) > 1:
-        raise ValueError(f"{outs=} must have length 1.")
-    init = _get_op_result_or_value(outs[0])
-    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+    return create_op(
+        BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    )
 
-    op = BatchMatmulOp(
-        result_tensors=result_types,
-        inputs=ins,
-        outputs=[init],
-        indexing_maps=indexing_maps,
+
+def contract(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return create_op(
+        ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
     )
-    fill_builtin_region(op.operation)
-    return op

>From 34b7ccdc087b6bf5422c112589079dbaa5a8c757 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 19 Feb 2025 04:54:34 -0800
Subject: [PATCH 4/4] -Fix indentation.

---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 27e12445f2d31..a5725d6f1507e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -888,7 +888,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
-            "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+          "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         $_state.addOperands(operands);
         $_state.addAttribute("cast", cast);



More information about the Mlir-commits mailing list