[Mlir-commits] [mlir] 8c8336f - Add missing `linalg.batch_vecmat` named op (#70218)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 25 08:41:27 PDT 2023


Author: bjacob
Date: 2023-10-25T11:41:24-04:00
New Revision: 8c8336fcadac9ca24d50c635d25562b805e1ff1d

URL: https://github.com/llvm/llvm-project/commit/8c8336fcadac9ca24d50c635d25562b805e1ff1d
DIFF: https://github.com/llvm/llvm-project/commit/8c8336fcadac9ca24d50c635d25562b805e1ff1d.diff

LOG: Add missing `linalg.batch_vecmat` named op (#70218)

Linalg currently has these named ops:
* `matmul`
* `matvec`
* `vecmat`
* `batch_matmul`
* `batch_matvec`

But it does not have:
* `batch_vecmat`

This PR adds it for consistency. I also have a short-term need for it
( https://github.com/openxla/iree/issues/15158 ), and without it I would
need some contortions on my end.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd64b813c11e532..12d520cd382413a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_vecmat
+  cpp_class_name: BatchVecmatOp
+  doc: |-
+    Performs a batched vector-matrix multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: dot
   cpp_class_name: DotOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 19734a80a107bfe..62b7da2ae2b5337 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -517,6 +517,24 @@ def batch_matvec(
     )
 
 
+ at linalg_structured_op
+def batch_vecmat(
+    A=TensorDef(T1, Batch, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, Batch, S.N, output=True),
+):
+    """Performs a batched matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n]
+    )
+
+
 @linalg_structured_op
 def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
     """Performs a dot product of two vectors to a scalar result.

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 54cc0defc1f8cd8..2259d47eb2b2b0d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
 
 // -----
 
+func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>, %out: memref<?x?xf32>) {
+  linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
+                     outs(%out: memref<?x?xf32>)
+  return
+}
+// CHECK: #[[LHS_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[RHS_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: #[[OUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: @generalize_batch_vecmat
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+// CHECK:         ^{{.+}}(%[[LHS:.+]]: i8, %[[RHS:.+]]: i8, %[[ACC:.+]]: f32)
+// CHECK:            %[[LHS_F32:.+]] = arith.sitofp %[[LHS]] : i8 to f32
+// CHECK:            %[[RHS_F32:.+]] = arith.sitofp %[[RHS]] : i8 to f32
+// CHECK:            %[[MUL:.+]] = arith.mulf %[[LHS_F32]], %[[RHS_F32]]
+// CHECK:            %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]]
+// CHECK:            linalg.yield %[[ADD]] : f32
+
+// -----
+
 func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
   linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
                              outs(%out: memref<8x8xf32>)


        


More information about the Mlir-commits mailing list