[PATCH] D78327: [mlir][Linalg] Create a named batchmatmul op and pipe it through.

Sean Silva via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 17 14:05:04 PDT 2020


silvas added inline comments.


================
Comment at: mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc:2
+ods_def<BatchMatmulOp>:
+def batch_matmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
----------------
mravishankar wrote:
> What is the reference for this specification ? ONNX/TF both seem to have a batch dimension for B as well. Without that this is effectively broadcasting B
This isn't enough to legalize e.g. tf.BatchMatMul or torch.matmul, which allow leading batch dimensions on both sides.

https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
https://pytorch.org/docs/stable/torch.html#torch.matmul

In IREE we have a batch matmul op that handles batch on both sides:
https://github.com/google/iree/blob/f80f39c7e96c2af15741e9c774eb8b54bf38df28/iree/compiler/Dialect/VMLA/IR/VMLAOps.td#L323

I expect that in a typical lowering flow, we will legalize tf.BatchMatMul or torch.matmul by reshaping all the batch dimensions into a single dimension on both sides (possibly a dummy "1" dimension when one side has no batch). Then we can expand this op into generic form and fuse/clean up those reshapes, which will eliminate batch dimensions on either side.
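
To illustrate the shape bookkeeping I have in mind, here is a rough sketch in plain Python (not MLIR). The helper name `collapse_batch_dims` is made up for this example, and it only computes the collapsed shape; how a dummy batch of 1 is later reconciled with the other operand's batch extent is left out.

```python
from math import prod

def collapse_batch_dims(shape):
    """Collapse all leading batch dims of a (..., rows, cols) shape into one.

    A rank-2 shape has no batch dims, so it gets a dummy batch extent of 1.
    """
    if len(shape) < 2:
        raise ValueError("expected a shape of rank 2 or higher")
    *batch, rows, cols = shape
    # math.prod of an empty sequence is 1, which yields the dummy "1" dim.
    return (prod(batch), rows, cols)

lhs = collapse_batch_dims((2, 3, 4, 5))  # two batch dims on the LHS
rhs = collapse_batch_dims((5, 6))        # no batch dim on the RHS
print(lhs, rhs)
```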

I don't see a situation where we would create this op as currently specified.

My intuition is that a batch matmul with a batch dimension on only one side is not that interesting, because it is fundamentally the same as a regular matmul: you just fold the batch dimension into the free dimension of the respective operand (e.g. in the case you have here, you can reshape the two dimensions Batch, M of the LHS into a single dimension of extent Batch*M). Batch matmul is only interesting from a lowering perspective when there is a batch dimension on both sides, which introduces data-reuse behavior distinct from that of a normal matmul.
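
As a sanity check of the folding argument, here is a small sketch in plain Python on nested lists (all function names are hypothetical and chosen for illustration only). It compares the one-sided batch matmul from the patch against a regular matmul applied to the LHS reshaped to (Batch*M, K).

```python
def matmul(a, b):
    """Regular matmul on nested lists: (M, K) x (K, N) -> (M, N)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def batch_matmul_lhs_only(a, b):
    """C(b, m, n) = sum_k A(b, m, k) * B(k, n); B is shared by every batch."""
    return [matmul(a_slice, b) for a_slice in a]

def batch_matmul_via_reshape(a, b):
    """Same result via a single regular matmul on a reshaped LHS."""
    batch, m = len(a), len(a[0])
    flat = [row for a_slice in a for row in a_slice]      # (Batch*M, K)
    c = matmul(flat, b)                                   # (Batch*M, N)
    return [c[i * m:(i + 1) * m] for i in range(batch)]   # (Batch, M, N)

A = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]  # (3, 2, 2)
B = [[1, 0, 2], [0, 1, 3]]                                     # (2, 3)
same = batch_matmul_lhs_only(A, B) == batch_matmul_via_reshape(A, B)
print(same)
```

Integer inputs are used so that the two results can be compared exactly.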

So in terms of defining a set of "primitives" or lowering to library calls (e.g. https://devblogs.nvidia.com/cublas-strided-batched-matrix-multiply/), having a batch on both sides seems to be the only relevant case. So I would recommend defining this as:
```
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
```
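
For reference, the semantics I would expect from that signature, assuming the body becomes C(b, m, n) = sum over k of A(b, m, k) * B(b, k, n), can be sketched as a plain Python loop nest. This is an illustration only, not the generated Linalg code, and the function name is just a placeholder.

```python
def batch_matmul(a, b):
    """C(b, m, n) = sum_k A(b, m, k) * B(b, k, n) on nested lists."""
    if len(a) != len(b):
        raise ValueError("batch extents of A and B differ")
    batch, m_dim, k_dim, n_dim = len(a), len(a[0]), len(b[0]), len(b[0][0])
    if len(a[0][0]) != k_dim:
        raise ValueError("reduction extents of A and B differ")
    c = [[[0] * n_dim for _ in range(m_dim)] for _ in range(batch)]
    for bi in range(batch):              # batch index is shared by A, B and C
        for m in range(m_dim):
            for n in range(n_dim):
                for k in range(k_dim):   # reduction dimension
                    c[bi][m][n] += a[bi][m][k] * b[bi][k][n]
    return c

A = [[[1, 2]], [[3, 4]]]        # (Batch=2, M=1, K=2)
B = [[[1], [1]], [[2], [0]]]    # (Batch=2, K=2, N=1)
C = batch_matmul(A, B)
print(C)
```

Each batch index reads its own slice of B, so unlike the one-sided case there is no reuse of B across the batch.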


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D78327/new/

https://reviews.llvm.org/D78327




