[Mlir-commits] [mlir] f381768 - [MLIR][Linalg] introduce batch-reduce GEMM
Lorenzo Chelini
llvmlistbot@llvm.org
Mon Sep 19 03:12:03 PDT 2022
Author: lorenzo chelini
Date: 2022-09-19T12:11:54+02:00
New Revision: f381768a8da6bd6bde8bdff34f080bf12bf20064
URL: https://github.com/llvm/llvm-project/commit/f381768a8da6bd6bde8bdff34f080bf12bf20064
DIFF: https://github.com/llvm/llvm-project/commit/f381768a8da6bd6bde8bdff34f080bf12bf20064.diff
LOG: [MLIR][Linalg] introduce batch-reduce GEMM
The batch-reduce GEMM kernel multiplies a sequence of input tensor blocks
(which form a batch), and the partial multiplication results are reduced
into a single output tensor block.
See: https://ieeexplore.ieee.org/document/9139809 for more details.
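In pseudocode, the computation is as follows (a minimal pure-Python sketch of the
semantics; the function and variable names are illustrative, not part of this commit):

def batch_reduce_matmul_ref(A, B, C):
  """C[m, n] += sum over b and k of A[b, m, k] * B[b, k, n].

  A: BATCH x M x K, B: BATCH x K x N, C: M x N, accumulated in place.
  """
  batch, m_dim, k_dim = len(A), len(A[0]), len(A[0][0])
  n_dim = len(B[0][0])
  for b in range(batch):            # reduction over the batch dimension
    for m in range(m_dim):          # parallel
      for n in range(n_dim):        # parallel
        for k in range(k_dim):      # reduction
          C[m][n] += A[b][m][k] * B[b][k][n]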
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D134163
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd943e2183252..d86b82fcbe45a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -548,6 +548,76 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: batch_reduce_matmul
+ cpp_class_name: BatchReduceMatmulOp
+ doc: |-
+ Performs a batch-reduce matrix multiplication of two 3D inputs.
+ The partial multiplication results are reduced into a 2D output.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+ - !LinalgOperandDefConfig
+ name: B
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+ - !LinalgOperandDefConfig
+ name: C
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
+ iterator_types:
+ - reduction
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_batch_matmul
cpp_class_name: QuantizedBatchMatmulOp
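To make the indexing maps above concrete: the shared iteration space is
(d0, d1, d2, d3) = (batch, m, n, k), with d0 and d3 marked as reduction
iterators, and each map selects the coordinates at which one operand is
accessed. A small Python illustration (the names here are mine, not part
of the commit):

# Index functions corresponding to the static_indexing_maps above.
def index_A(d0, d1, d2, d3):
  return (d0, d1, d3)  # A is read at [batch, m, k]

def index_B(d0, d1, d2, d3):
  return (d0, d3, d2)  # B is read at [batch, k, n]

def index_C(d0, d1, d2, d3):
  return (d1, d2)      # C is accessed at [m, n]; batch and k are reduced away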
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 983842cde1325..b9b292d847cd9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -150,6 +150,20 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+@linalg_structured_op
+def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
+ """Performs a batch-reduce matrix multiplication of two 3D inputs.
+ The partial multiplication results are reduced into a 2D output.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+  C[D.m, D.n] += TypeFn.cast_signed(
+      U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
@linalg_structured_op
def matvec(A=TensorDef(T1, S.M, S.N),
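Once this change is built, the new op should be constructible from the MLIR
Python bindings like any other opdsl-generated named op. A usage sketch,
assuming an in-tree build of the bindings (the exact import layout may
differ in your setup, so treat this as an assumption rather than part of
the commit):

from mlir import ir
from mlir.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
  module = ir.Module.create()
  f32 = ir.F32Type.get()
  a_ty = ir.RankedTensorType.get([8, 128, 256], f32)
  b_ty = ir.RankedTensorType.get([8, 256, 512], f32)
  c_ty = ir.RankedTensorType.get([128, 512], f32)
  with ir.InsertionPoint(module.body):
    # opdsl-defined named structured ops become callable builders.
    @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
    def brgemm(a, b, c):
      return linalg.batch_reduce_matmul(a, b, outs=[c])
  print(module)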
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7fdabbae1c159..e43b13a5e5955 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -248,3 +248,27 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
// CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
+func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
+ linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
+ outs(%out: memref<8x8xf32>)
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @batch_reduce_gemm
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xf32>, memref<7x9x8xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
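As a sanity check on the semantics this test exercises, the expected result
for the 7x8x9 / 7x9x8 shapes above can be cross-checked with NumPy
(illustrative only, not part of the test suite):

import numpy as np

A = np.random.rand(7, 8, 9).astype(np.float32)  # lhs: batch x M x K
B = np.random.rand(7, 9, 8).astype(np.float32)  # rhs: batch x K x N
C = np.zeros((8, 8), dtype=np.float32)          # out: M x N

# batch_reduce_matmul accumulates sum over b of A[b] @ B[b] into C.
expected = C + np.einsum("bmk,bkn->mn", A, B)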
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index f4126b4cf9ea9..5a7b7ffbb59fd 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -794,3 +794,23 @@ func.func @conv_interface_wrong_num_operands(
}) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operand_segment_sizes = array<i32: 2, 1>, strides = dense<1> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
+
+// -----
+
+func.func @batch_reduce_matmul(%arg0: tensor<8x128x256xf32>, %arg1: tensor<8x256x512xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> {
+ // CHECK: %{{.+}} = linalg.batch_reduce_matmul
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<8x128x256xf32>, tensor<8x256x512xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<128x512xf32>) -> tensor<128x512xf32>
+ %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<8x128x256xf32>, tensor<8x256x512xf32>) outs(%arg2: tensor<128x512xf32>) -> tensor<128x512xf32>
+ return %0: tensor<128x512xf32>
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // CHECK: linalg.batch_reduce_matmul
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+ linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}