[Mlir-commits] [mlir] 2ce6a42 - [MLIR] Add Linalg support for integer (generalized) matmuls

Geoffrey Martin-Noble llvmlistbot at llvm.org
Mon Feb 22 11:13:35 PST 2021


Author: Geoffrey Martin-Noble
Date: 2021-02-22T11:13:26-08:00
New Revision: 2ce6a42cc94dbe1b0456ebc34c0db238bf5530d6

URL: https://github.com/llvm/llvm-project/commit/2ce6a42cc94dbe1b0456ebc34c0db238bf5530d6
DIFF: https://github.com/llvm/llvm-project/commit/2ce6a42cc94dbe1b0456ebc34c0db238bf5530d6.diff

LOG: [MLIR] Add Linalg support for integer (generalized) matmuls

This patch adds Linalg named ops for various types of integer matmuls.
Due to limitations in the tc spec/linalg-ods-gen, ops cannot be type
polymorphic, so this instead creates new ops (improvements to the
methods for defining Linalg named ops are underway with a prototype at
https://github.com/stellaraccident/mlir-linalgpy).

To avoid the necessity of directly referencing these many new ops, this
adds additional methods to ContractionOpInterface to allow classifying
types of operations based on their indexing maps.

Reviewed By: nicolasvasilache, mravishankar

Differential Revision: https://reviews.llvm.org/D97006

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 95656ebd9983..7b88730835b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -34,6 +34,41 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
   }];
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyContractionInterface($_op); }];
+  let methods = [
+    InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op has indexing maps that correspond to a
+      row-major matmul operation.
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isRowMajorMatmul",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isRowMajorMatmul($_op.indexing_maps());
+    }]>,
+    InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op has indexing maps that correspond to a
+      column-major matmul operation.
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isColumnMajorMatmul",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isColumnMajorMatmul($_op.indexing_maps());
+    }]>,
+    InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op has indexing maps that correspond to a
+      row-major batch matmul operation.
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isRowMajorBatchMatmul",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isRowMajorBatchMatmul($_op.indexing_maps());
+    }]>,
+  ];
 }
 
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 6692f7d5831e..cd72aced29af 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -18,30 +18,118 @@ def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
   C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
 }
 
+ods_def<MatmulI16I16I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
+  C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+}
+
+ods_def<MatmulI32I32I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
+  C(m, n) = std_addi<k>(std_muli(A(m, k), B(k, n)));
+}
+
 ods_def<MatvecOp>
 implements_interface<LinalgContractionOpInterface> :
 def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
   x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
 }
 
+ods_def<MatvecI8I8I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
+  x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+}
+
+ods_def<MatvecI16I16I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
+  x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+}
+
+ods_def<MatvecI32I32I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
+  x(m) = std_addi<n>(std_muli(A(m, n), y(n)));
+}
+
 ods_def<VecmatOp>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
   x(n) = std_addf<m>(std_mulf(y(m), A(m, n)));
 }
 
+ods_def<VecmatI8I8I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
+  x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+}
+
+ods_def<VecmatI16I16I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
+  x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+}
+
+
+ods_def<VecmatI32I32I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
+  x(n) = std_addi<m>(std_muli(y(m), A(m, n)));
+}
+
 ods_def<DotOp>
 implements_interface<LinalgContractionOpInterface> :
 def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
   C() = std_addf<m>(std_mulf(A(m), B(m)));
 }
 
+ods_def<DotI8I8I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
+  C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+}
+
+ods_def<DotI16I16I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
+  C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+}
+
+
+ods_def<DotI32I32I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
+  C() = std_addi<m>(std_muli(A(m), B(m)));
+}
+
+
 ods_def<BatchMatmulOp>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
 }
 
+ods_def<BatchMatmulI8I8I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
+  C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+}
+
+ods_def<BatchMatmulI16I16I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
+  C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+}
+
+
+ods_def<BatchMatmulI32I32I32Op>
+implements_interface<LinalgContractionOpInterface> :
+def batch_matmul_i32_i32_i32(A: i32(Batch, M, K), B: i32(Batch, K, N)) -> (C: i32(Batch, M, N)) {
+  C(b, m, n) = std_addi<k>(std_muli(A(b, m, k), B(b, k, n)));
+}
+
 ods_def<ConvWOp>:
 def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
   O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));


        


More information about the Mlir-commits mailing list