[Mlir-commits] [mlir] cb1911a - [mlir][linalg] Add named op for matmul_transpose_a

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jun 26 13:09:04 PDT 2023


Author: Nicolas Vasilache
Date: 2023-06-26T20:09:00Z
New Revision: cb1911a2db14b02a3f7855723d9138d26c41e66f

URL: https://github.com/llvm/llvm-project/commit/cb1911a2db14b02a3f7855723d9138d26c41e66f
DIFF: https://github.com/llvm/llvm-project/commit/cb1911a2db14b02a3f7855723d9138d26c41e66f.diff

LOG: [mlir][linalg] Add named op for matmul_transpose_a

A matmul with a transposed LHS operand allows better memory access
patterns on several architectures, including common GPUs. Having a named
op for it allows this kind of matmul to be handled more explicitly.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 09516b95cc4bb..a89a112574861 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -400,6 +400,79 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matmul_transpose_a
+  cpp_class_name: MatmulTransposeAOp
+  doc: |-
+    Performs a matrix multiplication of two 2D inputs with lhs operand
+    transposed.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast_signed
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul_transpose_b
   cpp_class_name: MatmulTransposeBOp
@@ -621,6 +694,76 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_matmul_transpose_a
+  cpp_class_name: BatchMatmulTransposeAOp
+  doc: |-
+    Performs a batched matrix multiplication of two 3D inputs where lhs operand
+    has its non-batch dimensions transposed.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul_transpose_b
   cpp_class_name: BatchMatmulTransposeBOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 263d2109b4a9f..4c3e8fb25700b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -107,6 +107,22 @@ def quantized_matmul(
     )
 
 
+@linalg_structured_op
+def matmul_transpose_a(A=TensorDef(T1, S.K, S.N),
+                       B=TensorDef(T2, S.K, S.M),
+                       C=TensorDef(U, S.M, S.N, output=True),
+                       cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+  """Performs a matrix multiplication of two 2D inputs with lhs operand
+  transposed.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
+
+
 @linalg_structured_op
 def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
                        B=TensorDef(T2, S.N, S.K),
@@ -164,6 +180,22 @@ def batch_matmul(
     )
 
 
+@linalg_structured_op
+def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M),
+                             B=TensorDef(T2, Batch, S.K, S.N),
+                             C=TensorDef(U, Batch, S.M, S.N, output=True)):
+  """Performs a batched matrix multiplication of two 3D inputs where lhs operand
+  has its non-batch dimensions transposed.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \
+                    * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+
+
 @linalg_structured_op
 def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
                              B=TensorDef(T2, Batch, S.N, S.K),

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index e89fb8d682e22..9396e2197e3d6 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1143,6 +1143,17 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
 
 // -----
 
+// CHECK-LABEL: func @matmul_transpose_a
+//       CHECK:   linalg.matmul_transpose_a
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul_transpose_a ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_transpose_b
 //       CHECK:   linalg.matmul_transpose_b
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
@@ -1154,6 +1165,17 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 
 // -----
 
+// CHECK-LABEL: func @batchmatmul_transpose_a
+//       CHECK:   linalg.batch_matmul_transpose_a
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batchmatmul_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @batchmatmul_transpose_b
 //       CHECK:   linalg.batch_matmul_transpose_b
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)


        


More information about the Mlir-commits mailing list