[Mlir-commits] [mlir] 2c4f569 - Add linalg.batch_matvec named op
Ahmed Taei
llvmlistbot at llvm.org
Wed Jun 30 11:37:27 PDT 2021
Author: Ahmed Taei
Date: 2021-06-30T11:37:21-07:00
New Revision: 2c4f5690ab5e435691aafe554725dbbd521b3754
URL: https://github.com/llvm/llvm-project/commit/2c4f5690ab5e435691aafe554725dbbd521b3754
DIFF: https://github.com/llvm/llvm-project/commit/2c4f5690ab5e435691aafe554725dbbd521b3754.diff
LOG: Add linalg.batch_matvec named op
Similarly to batch_matmul, the outermost dim is a batching dim,
and this op performs |b| matrix-vector products:
C[b, i] = sum_k(A[b, i, k] * B[b, k])
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D104739
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index e536b44fe6fb2..8781e16bba34e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -247,6 +247,68 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: A
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: batch_matvec
+ cpp_class_name: BatchMatvecOp
+ doc: |-
+ Performs a batched matrix-vector multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: A
+ usage: InputOperand
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
+ - !LinalgOperandDefConfig
+ name: B
+ usage: InputOperand
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ - !LinalgOperandDefConfig
+ name: C
+ usage: OutputOperand
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1, d2)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
cpp_class_name: DotOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5867109279aa4..561cd2e7d08db 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -66,6 +66,21 @@ def vecmat(
x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])
+@linalg_structured_op
+def batch_matvec(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K),
+ C=TensorDef(U, Batch, S.M, output=True)):
+ """Performs a batched matrix-vector multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k])
+
+
@linalg_structured_op
def dot(
A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 412309a0f7434..405c7b156da6b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -490,3 +490,28 @@ func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: linalg.yield %[[BBARG0]] : f32
+
+// -----
+
+func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi8>, %out: memref<?x?xf32>) {
+ linalg.batch_matvec ins(%lhs, %rhs: memref<?x?x?xi8>, memref<?x?xi8>)
+ outs(%out: memref<?x?xf32>)
+ return
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: @generalize_batch_matm_vec
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xi8>, memref<?x?xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
+// CHECK: %[[BBARG0_F32:.+]] = sitofp %[[BBARG0]] : i8 to f32
+// CHECK: %[[BBARG1_F32:.+]] = sitofp %[[BBARG1]] : i8 to f32
+// CHECK: %[[MUL:.+]] = mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
+// CHECK: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]]
+// CHECK: linalg.yield %[[ADD]] : f32
More information about the Mlir-commits
mailing list