[Mlir-commits] [mlir] 0516f49 - Add linalg.mmt4d named op
Ahmed Taei
llvmlistbot at llvm.org
Thu Jul 1 12:41:16 PDT 2021
Author: Ahmed Taei
Date: 2021-07-01T12:41:08-07:00
New Revision: 0516f49c081590305a9db972ebc7fceb942b8ce3
URL: https://github.com/llvm/llvm-project/commit/0516f49c081590305a9db972ebc7fceb942b8ce3
DIFF: https://github.com/llvm/llvm-project/commit/0516f49c081590305a9db972ebc7fceb942b8ce3.diff
LOG: Add linalg.mmt4d named op
This op performs a matrix-matrix-transpose multiplication of 4-D inputs as follows:
```
C[m1, n1, m0, n0] = sum_{k1, k0}(A[m1, k1, m0, k0] * B[n1, k1, n0, k0])
```
Reviewed By: Benoit
Differential Revision: https://reviews.llvm.org/D105244
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8781e16bba34..a8baf23bbfaa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -62,6 +62,79 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: mmt4d
+ cpp_class_name: Mmt4DOp
+ doc: |-
+ Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+ Differences from linalg.matmul:
+ * The right hand side is transposed, whence the 't' in 'mmt'.
+ * The input and output tensors have a 4D shape instead of a 2D shape. They
+ are interpreted as 2D matrices with one level of 2D tile subdivision,
+ whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+ '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+ as: MxK tiles, each of shape M0xK0.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: lhs
+ usage: InputOperand
+ type_var: LhsType
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
+ - !LinalgOperandDefConfig
+ name: rhs
+ usage: InputOperand
+ type_var: RhsType
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
+ - !LinalgOperandDefConfig
+ name: accum
+ usage: OutputOperand
+ type_var: AccumType
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1,
+ d5)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3,
+ d5)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1,
+ d3)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: accum
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: accum
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: AccumType
+ operands:
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: AccumType
+ operands:
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
cpp_class_name: BatchMatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 561cd2e7d08d..095d94956f5b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -21,6 +21,26 @@ def matmul(
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+@linalg_structured_op
+def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+ rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+ accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
+ output=True)):
+ """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+ Differences from linalg.matmul:
+ * The right hand side is transposed, whence the 't' in 'mmt'.
+ * The input and output tensors have a 4D shape instead of a 2D shape. They
+ are interpreted as 2D matrices with one level of 2D tile subdivision,
+ whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+ '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+ as: MxK tiles, each of shape M0xK0.
+ """
+ domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)
+ implements(ContractionOpInterface)
+ accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+
+
@linalg_structured_op
def batch_matmul(
A=TensorDef(T1, Batch, S.M, S.K),
More information about the Mlir-commits
mailing list