[Mlir-commits] [mlir] cb1911a - [mlir][linalg] Add named op for matmul_transpose_a
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jun 26 13:09:04 PDT 2023
Author: Nicolas Vasilache
Date: 2023-06-26T20:09:00Z
New Revision: cb1911a2db14b02a3f7855723d9138d26c41e66f
URL: https://github.com/llvm/llvm-project/commit/cb1911a2db14b02a3f7855723d9138d26c41e66f
DIFF: https://github.com/llvm/llvm-project/commit/cb1911a2db14b02a3f7855723d9138d26c41e66f.diff
LOG: [mlir][linalg] Add named op for matmul_transpose_a
A matmul with a transposed LHS operand allows better memory access
patterns on several architectures, including common GPUs. Having a named
op for it allows this kind of matmul to be handled in a more explicit way.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 09516b95cc4bb..a89a112574861 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -400,6 +400,79 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: matmul_transpose_a
+ cpp_class_name: MatmulTransposeAOp
+ doc: |-
+ Performs a matrix multiplication of two 2D inputs with lhs operand
+ transposed.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ - !LinalgOperandDefConfig
+ name: B
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ - !LinalgOperandDefConfig
+ name: C
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+ - !LinalgOperandDefConfig
+ name: cast
+ kind: type_fn_attr
+ default_fn: cast_signed
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ attr_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ attr_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul_transpose_b
cpp_class_name: MatmulTransposeBOp
@@ -621,6 +694,76 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: batch_matmul_transpose_a
+ cpp_class_name: BatchMatmulTransposeAOp
+ doc: |-
+ Performs a batched matrix multiplication of two 3D inputs where lhs operand
+ has its non-batch dimensions transposed.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+ - !LinalgOperandDefConfig
+ name: B
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+ - !LinalgOperandDefConfig
+ name: C
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul_transpose_b
cpp_class_name: BatchMatmulTransposeBOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 263d2109b4a9f..4c3e8fb25700b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -107,6 +107,22 @@ def quantized_matmul(
)
+@linalg_structured_op
+def matmul_transpose_a(A=TensorDef(T1, S.K, S.N),
+ B=TensorDef(T2, S.K, S.M),
+ C=TensorDef(U, S.M, S.N, output=True),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+ """Performs a matrix multiplication of two 2D inputs with lhs operand
+ transposed.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
+
+
@linalg_structured_op
def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.N, S.K),
@@ -164,6 +180,22 @@ def batch_matmul(
)
+@linalg_structured_op
+def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.M, S.N, output=True)):
+ """Performs a batched matrix multiplication of two 3D inputs where lhs operand
+ has its non-batch dimensions transposed.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \
+ * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+
+
@linalg_structured_op
def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
B=TensorDef(T2, Batch, S.N, S.K),
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index e89fb8d682e22..9396e2197e3d6 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1143,6 +1143,17 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
// -----
+// CHECK-LABEL: func @matmul_transpose_a
+// CHECK: linalg.matmul_transpose_a
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul_transpose_a ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @matmul_transpose_b
// CHECK: linalg.matmul_transpose_b
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
@@ -1154,6 +1165,17 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
// -----
+// CHECK-LABEL: func @batchmatmul_transpose_a
+// CHECK: linalg.batch_matmul_transpose_a
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batchmatmul_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @batchmatmul_transpose_b
// CHECK: linalg.batch_matmul_transpose_b
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
More information about the Mlir-commits
mailing list