[Mlir-commits] [mlir] d0bf55c - [mlir][linalg] Add named op for matmul_transpose_b

Thomas Raoux llvmlistbot at llvm.org
Wed Jan 11 11:50:16 PST 2023


Author: Thomas Raoux
Date: 2023-01-11T19:49:48Z
New Revision: d0bf55cfdbe1ac388496924a6b96a281a2b25c21

URL: https://github.com/llvm/llvm-project/commit/d0bf55cfdbe1ac388496924a6b96a281a2b25c21
DIFF: https://github.com/llvm/llvm-project/commit/d0bf55cfdbe1ac388496924a6b96a281a2b25c21.diff

LOG: [mlir][linalg] Add named op for matmul_transpose_b

A matmul where the RHS operand is transposed allows better memory access
patterns on several architectures, including common GPUs. Having a named
op for it makes it possible to handle this kind of matmul more explicitly.
This change also adds the batched variant, batch_matmul_transpose_b.

Differential Revision: https://reviews.llvm.org/D141430

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 249f60ba0297..cbe40fcec5e0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -479,6 +479,78 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matmul_transpose_b
+  cpp_class_name: MatmulTransposeBOp
+  doc: |-
+    Performs a matrix multiplication of two 2D inputs with rhs operand transposed.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast_signed
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
   cpp_class_name: BatchMatmulOp
@@ -653,6 +725,76 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_matmul_transpose_b
+  cpp_class_name: BatchMatmulTransposeBOp
+  doc: |-
+    Performs a batched matrix multiplication of two 3D inputs where rhs operand has its non-batch
+    dimensions transposed.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_reduce_matmul
   cpp_class_name: BatchReduceMatmulOp

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 381fdc39354a..3a17350f25ac 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1070,3 +1070,25 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
   linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_b
+//       CHECK:   linalg.matmul_transpose_b
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batchmatmul_transpose_b
+//       CHECK:   linalg.batch_matmul_transpose_b
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list