[Mlir-commits] [mlir] [MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. (PR #127614)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Wed Feb 19 02:04:33 PST 2025


================
@@ -193,3 +193,23 @@ def contract(
     )
     fill_builtin_region(op.operation)
     return op
+
+def batch_matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+):
+    ins = [_get_op_result_or_value(input) for input in ins]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = BatchMatmulOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+    )
+    fill_builtin_region(op.operation)
----------------
shahidact wrote:

Done.

https://github.com/llvm/llvm-project/pull/127614
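
The hunk passes the inputs, a single init, and optional `indexing_maps` to `BatchMatmulOp`. The pure-Python model below is a rough sketch of what that op computes. It is not the MLIR Python API and is not part of this PR. It models indexing maps as plain Python functions instead of `AffineMapAttr`, and the names `batch_matmul_ref`, `single_init`, and `DEFAULT_MAPS` are hypothetical.

The default maps are those of `linalg.batch_matmul` over `(d0, d1, d2, d3) = (batch, m, n, k)`, and the op accumulates into the init. `single_init` uses `len(outs) != 1`, which is slightly stricter than the hunk's `> 1` check. With it, an empty `outs` raises `ValueError` instead of an `IndexError` at `outs[0]`. The model returns a new value, mirroring the hunk's tensor case, where `result_types` is non-empty only for a `RankedTensorType` init.

```python
import copy
from typing import Callable, List, Optional, Sequence, Tuple

# Illustrative types only: a rank-3 tensor as nested lists, and an
# "indexing map" as a Python function standing in for an AffineMapAttr.
Tensor3 = List[List[List[float]]]
IndexMap = Callable[[int, int, int, int], Tuple[int, int, int]]

# Default linalg.batch_matmul maps over (d0, d1, d2, d3) = (batch, m, n, k).
DEFAULT_MAPS: Tuple[IndexMap, IndexMap, IndexMap] = (
    lambda d0, d1, d2, d3: (d0, d1, d3),  # A[batch][m][k]
    lambda d0, d1, d2, d3: (d0, d3, d2),  # B[batch][k][n]
    lambda d0, d1, d2, d3: (d0, d1, d2),  # C[batch][m][n]
)


def single_init(outs: Sequence[Tensor3]) -> Tensor3:
    """Hypothetical helper: require exactly one init operand."""
    if len(outs) != 1:
        raise ValueError(f"outs must have length 1, got {len(outs)}")
    return outs[0]


def _at(t: Tensor3, idx: Tuple[int, int, int]) -> float:
    """Read one element of a nested-list rank-3 tensor."""
    return t[idx[0]][idx[1]][idx[2]]


def batch_matmul_ref(
    a: Tensor3,
    b: Tensor3,
    outs: Sequence[Tensor3],
    dims: Tuple[int, int, int, int],
    indexing_maps: Optional[Sequence[IndexMap]] = None,
) -> Tensor3:
    """Reference model: C[map_c] += A[map_a] * B[map_b] over every (d0..d3)."""
    map_a, map_b, map_c = DEFAULT_MAPS if indexing_maps is None else indexing_maps
    # Tensor semantics: accumulate into a copy of the init and return it,
    # leaving the caller's init untouched.
    c = copy.deepcopy(single_init(outs))
    n_batch, n_m, n_n, n_k = dims
    for d0 in range(n_batch):
        for d1 in range(n_m):
            for d2 in range(n_n):
                for d3 in range(n_k):  # d3 (k) is the reduction dimension
                    x, y, z = map_c(d0, d1, d2, d3)
                    c[x][y][z] += _at(a, map_a(d0, d1, d2, d3)) * _at(
                        b, map_b(d0, d1, d2, d3)
                    )
    return c
```

Consider `A = [[[1.0, 2.0], [3.0, 4.0]]]` and `B = [[[5.0, 6.0], [7.0, 8.0]]]` with a zero init and `dims=(1, 2, 2, 2)`. The result is `[[[19.0, 22.0], [43.0, 50.0]]]`. Storing B transposed and passing a B map of `(d0, d2, d3)` gives the same result. That is the kind of variation the optional `indexing_maps` argument is meant to express.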

