[Mlir-commits] [mlir] Add missing `linalg.batch_vecmat` named op (PR #70218)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 25 08:20:18 PDT 2023
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/70218
>From 5476ede320266de91111833793ac4fcc934b45d1 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 25 Oct 2023 11:00:32 -0400
Subject: [PATCH 1/2] batch_vecmat
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 68 +++++++++++++++++++
.../linalg/opdsl/ops/core_named_ops.py | 17 +++++
.../Dialect/Linalg/generalize-named-ops.mlir | 25 +++++++
3 files changed, 110 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd64b813c11e532..12d520cd382413a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: batch_vecmat
+ cpp_class_name: BatchVecmatOp
+ doc: |-
+ Performs a batched vector-matrix multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ - !LinalgOperandDefConfig
+ name: B
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
+ - !LinalgOperandDefConfig
+ name: C
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
cpp_class_name: DotOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 19734a80a107bfe..5144c42480cbc75 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -516,6 +516,23 @@ def batch_matvec(
U, B[D.b, D.k]
)
+@linalg_structured_op
+def batch_vecmat(
+ A=TensorDef(T1, Batch, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.N, output=True),
+):
+ """Performs a batched vector-matrix multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
+ U, B[D.b, D.k, D.n]
+ )
+
@linalg_structured_op
def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 54cc0defc1f8cd8..2259d47eb2b2b0d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
// -----
+func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>, %out: memref<?x?xf32>) {
+ linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
+ outs(%out: memref<?x?xf32>)
+ return
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: @generalize_batch_vecmat
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
+// CHECK: %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
+// CHECK: %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
+// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
+// CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
outs(%out: memref<8x8xf32>)
>From b802b831c1d528cd6d3c75ade7ec4b8d2fd4ce75 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 25 Oct 2023 11:20:03 -0400
Subject: [PATCH 2/2] lint
---
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5144c42480cbc75..62b7da2ae2b5337 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -516,6 +516,7 @@ def batch_matvec(
U, B[D.b, D.k]
)
+
@linalg_structured_op
def batch_vecmat(
A=TensorDef(T1, Batch, S.K),
More information about the Mlir-commits
mailing list