[Mlir-commits] [mlir] 21bb638 - [MLIR][linalg] Make integer matmul ops cast before multiplying
Geoffrey Martin-Noble
llvmlistbot at llvm.org
Fri Feb 26 08:36:40 PST 2021
Author: Geoffrey Martin-Noble
Date: 2021-02-26T08:36:31-08:00
New Revision: 21bb63893e8557df7a7ab690ad98cb5979099186
URL: https://github.com/llvm/llvm-project/commit/21bb63893e8557df7a7ab690ad98cb5979099186
DIFF: https://github.com/llvm/llvm-project/commit/21bb63893e8557df7a7ab690ad98cb5979099186.diff
LOG: [MLIR][linalg] Make integer matmul ops cast before multiplying
Right now they multiply before casting, which means they would frequently
overflow. There are various reasonable ways to do this, but until we
have robust op description infra, this is a simple and safe default. More
careful treatments are likely to be hardware-specific as well (e.g.
using an i8*i8->i16 mul instruction).
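
A minimal C sketch of the overflow this avoids (hypothetical illustration,
not part of the commit; the wrapped value assumes two's-complement
truncation, which is what typical targets do):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
      int8_t a = 100, b = 100;
      /* Multiply first, then widen: the product wraps in i8 before the
         sign-extension, so 100 * 100 = 10000 truncates to 16. */
      int32_t wrapped = (int32_t)(int8_t)(a * b);
      /* Widen first, then multiply: the product is computed in i32,
         matching the new cast-before-multiply behavior. */
      int32_t exact = (int32_t)a * (int32_t)b;
      printf("%d vs %d\n", wrapped, exact); /* prints: 16 vs 10000 */
      return 0;
    }

The same reasoning applies to the matvec, vecmat, dot, and batch_matmul
variants updated below.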
Reviewed By: nicolasvasilache, mravishankar
Differential Revision: https://reviews.llvm.org/D97505
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 3ebbb48a92a0..338cc6eaa4d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -15,13 +15,13 @@ implements_interface<LinalgContractionOpInterface> :
def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
// TODO: ideally something closer to
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
- C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
}
ods_def<MatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
- C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
}
ods_def<MatmulI32I32I32Op>
@@ -39,13 +39,13 @@ def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
ods_def<MatvecI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
}
ods_def<MatvecI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
}
ods_def<MatvecI32I32I32Op>
@@ -63,13 +63,13 @@ def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
ods_def<VecmatI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
}
ods_def<VecmatI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
}
ods_def<VecmatI32I32I32Op>
@@ -87,13 +87,13 @@ def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
ods_def<DotI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
- C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
}
ods_def<DotI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
- C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
}
ods_def<DotI32I32I32Op>
@@ -112,14 +112,14 @@ ods_def<BatchMatmulI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
C(b, m, n) =
- std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
}
ods_def<BatchMatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
C(b, m, n) =
- std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 13d2e181e2ea..436d513e637a 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -373,17 +373,18 @@ func @matmul_tensors(
// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
- // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
+ // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32>
// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
+ // CHECK-DAG: %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32>
+ // CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32>
//
// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
// a later canonicalization fuses the add into vector.contract.
- // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
- // CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
- // CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
- // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
+ // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
+ // CHECK-SAME: vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32>
+ // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32>
// CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
// CHECK-SAME: vector<4x12xi32>, memref<4x12xi32>
linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)