[PATCH] D78327: [mlir][Linalg] Create a named batchmatmul op and pipe it through.

Mahesh Ravishankar via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 18 23:58:02 PDT 2020


mravishankar added inline comments.


================
Comment at: mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc:2
+ods_def<BatchMatmulOp>:
+def batch_matmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
----------------
mravishankar wrote:
> nicolasvasilache wrote:
> > nicolasvasilache wrote:
> > > silvas wrote:
> > > > mravishankar wrote:
> > > > > What is the reference for this specification? ONNX/TF both seem to have a batch dimension for B as well. Without that, this is effectively broadcasting B.
> > > > This isn't enough to legalize e.g. tf.BatchMatMul or torch.matmul, which allow leading batch dimensions on both sides.
> > > > 
> > > > https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
> > > > https://pytorch.org/docs/stable/torch.html#torch.matmul
> > > > 
> > > > In IREE we have a batch matmul op that handles batch on both sides:
> > > > https://github.com/google/iree/blob/f80f39c7e96c2af15741e9c774eb8b54bf38df28/iree/compiler/Dialect/VMLA/IR/VMLAOps.td#L323
> > > > 
> > > > I expect that in a typical lowering flow, we will legalize tf.BatchMatMul or torch.matmul by reshaping all the batch dimensions into a single dimension on both sides (possibly a dummy "1" dimension in case of no batch on one side). Then we can expand this op into generic form and fuse/cleanup those reshapes which will eliminate batch dimensions on either side.
> > > > 
> > > > I don't see a situation where we would create this op. 
> > > > 
> > > > My intuition is that batch matmul with a batch dimension on only one side is not that interesting: fundamentally it is the same as a regular matmul, because you can just fold the batch dimension into the free dimension of the respective operand (e.g. in the case you have here, you can reshape the two dimensions Batch,M in the LHS into a single dimension of extent Batch*M). Batch matmul is only interesting from a lowering perspective when you have a batch dimension on both sides, which introduces a distinct data-reuse behavior compared to a normal matmul.
> > > > 
> > > > So in terms of defining a set of "primitives" or lowering to library calls (e.g. https://devblogs.nvidia.com/cublas-strided-batched-matrix-multiply/), having a batch on both sides seems to be the only relevant case. So I would recommend defining this as:
> > > > ```
> > > > def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
> > > > ```
> > > It's just something to get started; the semantics are variadic and will require extensions.
> > > 
> > > Once these are implemented it will be easy to update.
> > > If you have a strong preference for another op, let me know what you'd prefer (an op that also exercises reduction).
> > > It can't be dot/matvec/matmul for now because that's already taken and more work is needed to replace them.
> > I went with @silvas' suggestion; we can iterate on the semantics later once we have variadic support.
> Thanks for the update. I am not sure I follow what "the semantics are variadic" implies, i.e. I don't see anything variadic about the op as defined here, but I might be misreading the terms.
> My concern was merely whether the named ops are supposed to have implicit broadcast semantics (in theory they can, but based on the discussion on Discourse that seems to lead to complications with things like dynamic broadcasting). As it was defined previously, I read B as having broadcast semantics. Anyway, it's OK now, so thanks for taking care of it.
Actually, strike the last comment. The spec has nothing to do with broadcasting. But the current spec is indeed preferable.
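
The folding argument in the quoted review (a batch dimension on only one operand reduces to an ordinary matmul over Batch*M rows) can be checked with a small reference sketch. It is plain Python over nested lists, not MLIR or TC code; the helper names (`matmul`, `batch_matmul_shared_b`, `batch_matmul`, `fold_batch_into_m`, `unfold_m`) are hypothetical and exist only for illustration. The loop nests mirror the two specs: the original one where B has no batch dimension, and the revised one with `B: f32(Batch, K, N)`.

```python
from typing import List

Matrix = List[List[float]]
Tensor3 = List[Matrix]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """C(m, n) = sum_k A(m, k) * B(k, n)."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            for k in range(K):  # k is the reduction dimension
                C[m][n] += A[m][k] * B[k][n]
    return C


def batch_matmul_shared_b(A: Tensor3, B: Matrix) -> Tensor3:
    """Original spec: C(b, m, n) = sum_k A(b, m, k) * B(k, n); B has no batch dim."""
    Batch, M, K, N = len(A), len(A[0]), len(B), len(B[0])
    C = [[[0.0] * N for _ in range(M)] for _ in range(Batch)]
    for b in range(Batch):
        for m in range(M):
            for n in range(N):
                for k in range(K):  # reduction over k, as in std_addf<k>
                    C[b][m][n] += A[b][m][k] * B[k][n]
    return C


def batch_matmul(A: Tensor3, B: Tensor3) -> Tensor3:
    """Revised spec: C(b, m, n) = sum_k A(b, m, k) * B(b, k, n)."""
    assert len(A) == len(B), "batch extents must match"
    return [matmul(A_b, B_b) for A_b, B_b in zip(A, B)]


def fold_batch_into_m(A: Tensor3) -> Matrix:
    """Reshape A(Batch, M, K) into A'(Batch * M, K) by concatenating rows."""
    return [row for A_b in A for row in A_b]


def unfold_m(C: Matrix, batch: int) -> Tensor3:
    """Reshape C'(Batch * M, N) back into C(Batch, M, N)."""
    M = len(C) // batch
    return [C[b * M:(b + 1) * M] for b in range(batch)]


A = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # shape (Batch=2, M=2, K=2)
B = [[1, 0, 2], [0, 1, 3]]                # shape (K=2, N=3)

# Batch on one side only: identical to one regular matmul on the folded LHS.
assert batch_matmul_shared_b(A, B) == unfold_m(matmul(fold_batch_into_m(A), B), len(A))
```

With a batch dimension on both sides, each slice multiplies a different B, so the revised op does not reduce to a single regular matmul in general; that is the case with the distinct data-reuse behavior mentioned above.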


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D78327/new/

https://reviews.llvm.org/D78327




