[Mlir-commits] [mlir] [MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. (PR #127614)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Wed Feb 19 02:04:33 PST 2025
================
@@ -193,3 +193,23 @@ def contract(
     )
     fill_builtin_region(op.operation)
     return op
+
+def batch_matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+):
+    ins = [_get_op_result_or_value(input) for input in ins]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = BatchMatmulOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+    )
+    fill_builtin_region(op.operation)
----------------
shahidact wrote:
Done.
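
For readers of the hunk above, here is a rough plain-Python sketch of what the wrapped op computes, assuming the default indexing maps, i.e. C[b, m, n] += A[b, m, k] * B[b, k, n], with the single `outs` operand acting as the accumulator. The name `batch_matmul_reference` and the nested-list representation are illustrative assumptions only; this is not part of the PR and does not use the MLIR Python bindings.

```python
def batch_matmul_reference(a, b, c):
    """Hypothetical reference: C[b, m, n] += A[b, m, k] * B[b, k, n].

    Operands are nested lists shaped [batch][rows][cols]. The init
    operand `c` is copied, so the caller's data is left untouched.
    """
    batch = len(a)
    if len(b) != batch or len(c) != batch:
        raise ValueError("all operands must share the batch dimension")

    # Copy the init operand; the result accumulates on top of it.
    out = [[row[:] for row in mat] for mat in c]

    for bi in range(batch):
        m = len(a[bi])
        k = len(a[bi][0])
        n = len(b[bi][0])
        if len(b[bi]) != k:
            raise ValueError("reduction dimensions of A and B do not match")
        for i in range(m):
            for j in range(n):
                for p in range(k):
                    out[bi][i][j] += a[bi][i][p] * b[bi][p][j]
    return out


if __name__ == "__main__":
    lhs = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]
    rhs = [[[5, 6], [7, 8]], [[2, 3], [4, 5]]]
    init = [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]
    print(batch_matmul_reference(lhs, rhs, init))
```

The three-operand shape (two inputs, one init) mirrors the `inputs=ins` / `outputs=[init]` split in the quoted wrapper; a non-default `indexing_maps` would change which dimensions are contracted and is not modelled here.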
https://github.com/llvm/llvm-project/pull/127614