[Mlir-commits] [mlir] 6d2fd3d - [mlir][linalg] Replace monomorphic contraction ops with polymorphic variants.

Stella Laurenzo llvmlistbot at llvm.org
Mon Mar 1 21:22:35 PST 2021


Author: Stella Laurenzo
Date: 2021-03-01T21:19:53-08:00
New Revision: 6d2fd3d9cdd6ed24784ec47741e7e70c236a140e

URL: https://github.com/llvm/llvm-project/commit/6d2fd3d9cdd6ed24784ec47741e7e70c236a140e
DIFF: https://github.com/llvm/llvm-project/commit/6d2fd3d9cdd6ed24784ec47741e7e70c236a140e.diff

LOG: [mlir][linalg] Replace monomorphic contraction ops with polymorphic variants.

* Moves `batch_matmul`, `matmul`, `matvec`, `vecmat`, `dot` to the new mechanism.
* This is not just an NFC change: in addition to using a new code generation mechanism, it also activates symbolic casting, allowing mixed-precision operands and results.
* These definitions were generated from DSL by the tool: https://github.com/stellaraccident/mlir-linalgpy/blob/main/mlir_linalg/oplib/core.py (will be upstreamed in a subsequent set of changes).

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D97719

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 93bc5760ed0c..5752af9bea9a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,12 +1,12 @@
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: polymorphic_matmul
-  cpp_op_name: PolymorphicMatmulOp
+  name: matmul
+  cpp_op_name: MatmulOp
   doc: |-
-    Type polymorphic matrix multiplication.
+    Performs a matrix multiplication of two 2D inputs.
 
-    This op is presently here to test a new path for generation and will replace
-    the existing 'matmul' op when ready. Do not use.
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
@@ -60,4 +60,249 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_matmul
+  cpp_op_name: BatchMatmulOp
+  doc: |-
+    Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: B
+    usage: input
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: C
+    usage: output
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matvec
+  cpp_op_name: MatvecOp
+  doc: |-
+    Performs a matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: y
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: x
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0)>
+  iterator_types:
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: x
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: x
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: y
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: vecmat
+  cpp_op_name: VecmatOp
+  doc: |-
+    Performs a vector-matrix multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: y
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1, s0)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: x
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0)>
+  iterator_types:
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: x
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: x
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: y
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: dot
+  cpp_op_name: DotOp
+  doc: |-
+    Performs a dot product of two vectors to a scalar result.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0] -> (s0)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: B
+    usage: input
+    shape: affine_map<()[s0] -> (s0)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: C
+    usage: output
+    shape: affine_map<()[s0] -> ()>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0)[s0] -> (d0)>
+    - affine_map<(d0)[s0] -> (d0)>
+    - affine_map<(d0)[s0] -> ()>
+  iterator_types:
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 338cc6eaa4d6..37b972b73cf5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -1,9 +1,3 @@
-ods_def<MatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
-  C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
-}
-
 ods_def<MatmulColumnMajorOp>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
@@ -30,12 +24,6 @@ def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
   C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
 }
 
-ods_def<MatvecOp>
-implements_interface<LinalgContractionOpInterface> :
-def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
-  x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
-}
-
 ods_def<MatvecI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
@@ -54,12 +42,6 @@ def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
   x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
 }
 
-ods_def<VecmatOp>
-implements_interface<LinalgContractionOpInterface> :
-def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
-  x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
-}
-
 ods_def<VecmatI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
@@ -78,12 +60,6 @@ def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
   x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
 }
 
-ods_def<DotOp>
-implements_interface<LinalgContractionOpInterface> :
-def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
-  C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
-}
-
 ods_def<DotI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
@@ -102,12 +78,6 @@ def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
   C() = std_addi<m>(C(), std_muli(A(m), B(m)));
 }
 
-ods_def<BatchMatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
-  C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
-}
-
 ods_def<BatchMatmulI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index fc1183ec0d85..251dfe609606 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
 
 func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                           outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -16,7 +16,7 @@ func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>,
 // -----
 
 func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                           outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -31,7 +31,7 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
 // -----
 // Verifies floating point to integer cast.
 func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                           outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
   return %0: tensor<16x32xi16>
 }
@@ -48,7 +48,7 @@ func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x3
 // -----
 // Verifies sign extension cast.
 func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                           outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -65,7 +65,7 @@ func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
 // -----
 // Verifies that different argument types is legal.
 func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
                           outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -82,7 +82,7 @@ func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32x
 // -----
 // Somewhat non-sensical but checks integer truncation cast.
 func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                           outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
   return %0: tensor<16x32xi16>
 }
@@ -99,7 +99,7 @@ func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x3
 // -----
 // Verifies integer to floating point cast.
 func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                           outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -116,7 +116,7 @@ func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
 // -----
 // Verifies floating point extension cast.
 func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
                           outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -133,7 +133,7 @@ func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x3
 // -----
 // Verifies floating point truncation.
 func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
                           outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }


        


More information about the Mlir-commits mailing list