[Mlir-commits] [mlir] d0bf55c - [mlir][linalg] Add named op for matmul_transpose_b
Thomas Raoux
llvmlistbot at llvm.org
Wed Jan 11 11:50:16 PST 2023
Author: Thomas Raoux
Date: 2023-01-11T19:49:48Z
New Revision: d0bf55cfdbe1ac388496924a6b96a281a2b25c21
URL: https://github.com/llvm/llvm-project/commit/d0bf55cfdbe1ac388496924a6b96a281a2b25c21
DIFF: https://github.com/llvm/llvm-project/commit/d0bf55cfdbe1ac388496924a6b96a281a2b25c21.diff
LOG: [mlir][linalg] Add named op for matmul_transpose_b
A matmul whose RHS operand is transposed allows better memory access
patterns on several architectures, including common GPUs. Having a named
op for it makes it possible to handle this kind of matmul more explicitly.
Differential Revision: https://reviews.llvm.org/D141430
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 249f60ba0297..cbe40fcec5e0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -479,6 +479,78 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: matmul_transpose_b
+ cpp_class_name: MatmulTransposeBOp
+ doc: |-
+ Performs a matrix multiplication of two 2D inputs with rhs operand transposed.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ - !LinalgOperandDefConfig
+ name: B
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+ - !LinalgOperandDefConfig
+ name: C
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ - !LinalgOperandDefConfig
+ name: cast
+ kind: type_fn_attr
+ default_fn: cast_signed
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ attr_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ attr_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
cpp_class_name: BatchMatmulOp
@@ -653,6 +725,76 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: batch_matmul_transpose_b
+ cpp_class_name: BatchMatmulTransposeBOp
+ doc: |-
+ Performs a batched matrix multiplication of two 3D inputs where rhs operand has its non-batch
+ dimensions transposed.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+ - !LinalgOperandDefConfig
+ name: B
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+ - !LinalgOperandDefConfig
+ name: C
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_reduce_matmul
cpp_class_name: BatchReduceMatmulOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 381fdc39354a..3a17350f25ac 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1070,3 +1070,25 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
return
}
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_b
+// CHECK: linalg.matmul_transpose_b
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @batchmatmul_transpose_b
+// CHECK: linalg.batch_matmul_transpose_b
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
More information about the Mlir-commits
mailing list