[Mlir-commits] [mlir] 3718082 - [MLIR][Linalg] introduce batch-reduce GEMM

Lorenzo Chelini llvmlistbot at llvm.org
Mon Sep 19 03:51:19 PDT 2022


Author: Lorenzo Chelini
Date: 2022-09-19T12:50:27+02:00
New Revision: 3718082e2b11cd1c66317bc067aac7cb094226cc

URL: https://github.com/llvm/llvm-project/commit/3718082e2b11cd1c66317bc067aac7cb094226cc
DIFF: https://github.com/llvm/llvm-project/commit/3718082e2b11cd1c66317bc067aac7cb094226cc.diff

LOG: [MLIR][Linalg] introduce batch-reduce GEMM

The batch-reduce GEMM kernel multiplies a sequence of input tensor blocks
(which together form a batch) and reduces the partial multiplication results
into a single output tensor block.

See: https://ieeexplore.ieee.org/document/9139809 for more details.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134163

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd943e2183252..57c6100c3ae39 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -653,6 +653,76 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_reduce_matmul
+  cpp_class_name: BatchReduceMatmulOp
+  doc: |-
+    Performs a batch-reduce matrix multiplication of two 3D inputs.
+    The partial multiplication results are reduced into a 2D output.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
+  iterator_types:
+  - reduction
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
   cpp_class_name: MatvecOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 983842cde1325..1aa112dcf9186 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -150,6 +150,21 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                        TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
                            U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
+@linalg_structured_op
+def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+                        B=TensorDef(T2, Batch, S.K, S.N),
+                        C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs a batch-reduce matrix multiplication of two 3D inputs.
+  The partial multiplication results are reduced into a 2D output.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  # Cast A and B to U individually, then multiply (not cast the product).
+  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+      U, B[D.b, D.k, D.n])
 
 @linalg_structured_op
 def matvec(A=TensorDef(T1, S.M, S.N),

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7fdabbae1c159..e43b13a5e5955 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -248,3 +248,29 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
 // CHECK:            %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
 // CHECK:            %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
 // CHECK:            linalg.yield %[[ADD]] : f32
+
+// -----
+
+// Generalizing linalg.batch_reduce_matmul yields a linalg.generic that reduces
+// over both the batch dimension (d0) and the contraction dimension (d3).
+func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
+  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
+                             outs(%out: memref<8x8xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @batch_reduce_gemm
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xf32>, memref<7x9x8xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>)
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index f4126b4cf9ea9..5a7b7ffbb59fd 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -794,3 +794,29 @@ func.func @conv_interface_wrong_num_operands(
     }) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operand_segment_sizes = array<i32: 2, 1>, strides = dense<1> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   return %0 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @batch_reduce_matmul(%{{.*}}: tensor
+func.func @batch_reduce_matmul(%lhs: tensor<8x128x256xf32>, %rhs: tensor<8x256x512xf32>, %acc: tensor<128x512xf32>) -> tensor<128x512xf32> {
+  // CHECK: %{{.+}} = linalg.batch_reduce_matmul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<8x128x256xf32>, tensor<8x256x512xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<128x512xf32>) -> tensor<128x512xf32>
+  %res = linalg.batch_reduce_matmul
+           ins(%lhs, %rhs : tensor<8x128x256xf32>, tensor<8x256x512xf32>)
+           outs(%acc : tensor<128x512xf32>) -> tensor<128x512xf32>
+  return %res : tensor<128x512xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_reduce_matmul(%{{.*}}: memref
+func.func @batch_reduce_matmul(%lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>, %acc: memref<?x?xf32>) {
+  // CHECK: linalg.batch_reduce_matmul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%acc : memref<?x?xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list