[Mlir-commits] [mlir] 2c4f569 - Add linalg.batch_matvec named op

Ahmed Taei llvmlistbot at llvm.org
Wed Jun 30 11:37:27 PDT 2021


Author: Ahmed Taei
Date: 2021-06-30T11:37:21-07:00
New Revision: 2c4f5690ab5e435691aafe554725dbbd521b3754

URL: https://github.com/llvm/llvm-project/commit/2c4f5690ab5e435691aafe554725dbbd521b3754
DIFF: https://github.com/llvm/llvm-project/commit/2c4f5690ab5e435691aafe554725dbbd521b3754.diff

LOG: Add linalg.batch_matvec named op

    Similarly to batch_matmul, the outermost dim is a batching dim,
    and this op performs |b| matrix-vector products:
    C[b, i] = sum_k(A[b, i, k] * B[b, k])

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D104739

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index e536b44fe6fb2..8781e16bba34e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -247,6 +247,68 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: A
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_matvec
+  cpp_class_name: BatchMatvecOp
+  doc: |-
+    Performs a batched matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: B
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: dot
   cpp_class_name: DotOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5867109279aa4..561cd2e7d08db 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -66,6 +66,21 @@ def vecmat(
   x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])
 
 
+@linalg_structured_op
+def batch_matvec(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K),
+    C=TensorDef(U,  Batch, S.M, output=True)):
+  """Performs a batched matrix-vector multiplication.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.k)
+  implements(ContractionOpInterface)
+  C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k])
+
+
 @linalg_structured_op
 def dot(
     A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 412309a0f7434..405c7b156da6b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -490,3 +490,28 @@ func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
 
 // CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
 // CHECK-NEXT:      linalg.yield %[[BBARG0]] : f32
+
+// -----
+
+func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi8>,  %out: memref<?x?xf32>) {
+  linalg.batch_matvec ins(%lhs, %rhs: memref<?x?x?xi8>, memref<?x?xi8>)
+                     outs(%out: memref<?x?xf32>)
+  return
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: @generalize_batch_matm_vec
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xi8>, memref<?x?xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
+// CHECK:            %[[BBARG0_F32:.+]] = sitofp %[[BBARG0]] : i8 to f32
+// CHECK:            %[[BBARG1_F32:.+]] = sitofp %[[BBARG1]] : i8 to f32
+// CHECK:            %[[MUL:.+]] = mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
+// CHECK:            %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]]
+// CHECK:            linalg.yield %[[ADD]] : f32


        


More information about the Mlir-commits mailing list