[Mlir-commits] [mlir] 6d2fd3d - [mlir][linalg] Replace monomorphic contraction ops with polymorphic variants.
Stella Laurenzo
llvmlistbot at llvm.org
Mon Mar 1 21:22:35 PST 2021
Author: Stella Laurenzo
Date: 2021-03-01T21:19:53-08:00
New Revision: 6d2fd3d9cdd6ed24784ec47741e7e70c236a140e
URL: https://github.com/llvm/llvm-project/commit/6d2fd3d9cdd6ed24784ec47741e7e70c236a140e
DIFF: https://github.com/llvm/llvm-project/commit/6d2fd3d9cdd6ed24784ec47741e7e70c236a140e.diff
LOG: [mlir][linalg] Replace monomorphic contraction ops with polymorphic variants.
* Moves `batch_matmul`, `matmul`, `matvec`, `vecmat`, `dot` to the new mechanism.
* This is not just an NFC change, in addition to using a new code generation mechanism, it also activates symbolic casting, allowing mixed precision operands and results.
* These definitions were generated from DSL by the tool: https://github.com/stellaraccident/mlir-linalgpy/blob/main/mlir_linalg/oplib/core.py (will be upstreamed in a subsequent set of changes).
Reviewed By: nicolasvasilache, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D97719
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 93bc5760ed0c..5752af9bea9a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,12 +1,12 @@
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: polymorphic_matmul
- cpp_op_name: PolymorphicMatmulOp
+ name: matmul
+ cpp_op_name: MatmulOp
doc: |-
- Type polymorphic matrix multiplication.
+ Performs a matrix multiplication of two 2D inputs.
- This op is presently here to test a new path for generation and will replace
- the existing 'matmul' op when ready. Do not use.
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
@@ -60,4 +60,249 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: batch_matmul
+ cpp_op_name: BatchMatmulOp
+ doc: |-
+ Performs a batched matrix multiplication of two 3D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !<LinalgTensorDef>
+ name: A
+ usage: input
+ shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+ element_type_var: T1
+ - !<LinalgTensorDef>
+ name: B
+ usage: input
+ shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+ element_type_var: T2
+ - !<LinalgTensorDef>
+ name: C
+ usage: output
+ shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+ element_type_var: U
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: matvec
+ cpp_op_name: MatvecOp
+ doc: |-
+ Performs a matrix-vector multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !<LinalgTensorDef>
+ name: A
+ usage: input
+ shape: affine_map<()[s0, s1] -> (s0, s1)>
+ element_type_var: T1
+ - !<LinalgTensorDef>
+ name: y
+ usage: input
+ shape: affine_map<()[s0, s1] -> (s1)>
+ element_type_var: T2
+ - !<LinalgTensorDef>
+ name: x
+ usage: output
+ shape: affine_map<()[s0, s1] -> (s0)>
+ element_type_var: U
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ - affine_map<(d0, d1)[s0, s1] -> (d1)>
+ - affine_map<(d0, d1)[s0, s1] -> (d0)>
+ iterator_types:
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: x
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: x
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: y
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: vecmat
+ cpp_op_name: VecmatOp
+ doc: |-
+ Performs a vector-matrix multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !<LinalgTensorDef>
+ name: y
+ usage: input
+ shape: affine_map<()[s0, s1] -> (s1)>
+ element_type_var: T1
+ - !<LinalgTensorDef>
+ name: A
+ usage: input
+ shape: affine_map<()[s0, s1] -> (s1, s0)>
+ element_type_var: T2
+ - !<LinalgTensorDef>
+ name: x
+ usage: output
+ shape: affine_map<()[s0, s1] -> (s0)>
+ element_type_var: U
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1)[s0, s1] -> (d1)>
+ - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
+ - affine_map<(d0, d1)[s0, s1] -> (d0)>
+ iterator_types:
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: x
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: x
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: y
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: dot
+ cpp_op_name: DotOp
+ doc: |-
+ Performs a dot product of two vectors to a scalar result.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !<LinalgTensorDef>
+ name: A
+ usage: input
+ shape: affine_map<()[s0] -> (s0)>
+ element_type_var: T1
+ - !<LinalgTensorDef>
+ name: B
+ usage: input
+ shape: affine_map<()[s0] -> (s0)>
+ element_type_var: T2
+ - !<LinalgTensorDef>
+ name: C
+ usage: output
+ shape: affine_map<()[s0] -> ()>
+ element_type_var: U
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0)[s0] -> (d0)>
+ - affine_map<(d0)[s0] -> (d0)>
+ - affine_map<(d0)[s0] -> ()>
+ iterator_types:
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: B
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 338cc6eaa4d6..37b972b73cf5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -1,9 +1,3 @@
-ods_def<MatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
- C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
-}
-
ods_def<MatmulColumnMajorOp>
implements_interface<LinalgContractionOpInterface> :
def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
@@ -30,12 +24,6 @@ def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
}
-ods_def<MatvecOp>
-implements_interface<LinalgContractionOpInterface> :
-def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
- x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
-}
-
ods_def<MatvecI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
@@ -54,12 +42,6 @@ def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
}
-ods_def<VecmatOp>
-implements_interface<LinalgContractionOpInterface> :
-def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
- x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
-}
-
ods_def<VecmatI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
@@ -78,12 +60,6 @@ def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
}
-ods_def<DotOp>
-implements_interface<LinalgContractionOpInterface> :
-def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
- C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
-}
-
ods_def<DotI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
@@ -102,12 +78,6 @@ def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
C() = std_addi<m>(C(), std_muli(A(m), B(m)));
}
-ods_def<BatchMatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
- C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
-}
-
ods_def<BatchMatmulI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index fc1183ec0d85..251dfe609606 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
@@ -16,7 +16,7 @@ func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>,
// -----
func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
@@ -31,7 +31,7 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
// -----
// Verifies floating point to integer cast.
func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
return %0: tensor<16x32xi16>
}
@@ -48,7 +48,7 @@ func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x3
// -----
// Verifies sign extension cast.
func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
@@ -65,7 +65,7 @@ func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
// -----
// Verifies that different argument types are legal.
func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
@@ -82,7 +82,7 @@ func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32x
// -----
// Somewhat non-sensical but checks integer truncation cast.
func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
return %0: tensor<16x32xi16>
}
@@ -99,7 +99,7 @@ func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x3
// -----
// Verifies integer to floating point cast.
func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
@@ -116,7 +116,7 @@ func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
// -----
// Verifies floating point extension cast.
func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
@@ -133,7 +133,7 @@ func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x3
// -----
// Verifies floating point truncation.
func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
- %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
More information about the Mlir-commits
mailing list