[Mlir-commits] [mlir] 0e083ce - [mlir][tosa] Update tosa.matmul lowering to linalg.batch_matmul

Rob Suderman llvmlistbot at llvm.org
Wed Jun 9 11:12:26 PDT 2021


Author: Rob Suderman
Date: 2021-06-09T11:05:36-07:00
New Revision: 0e083cef7003ba822be9b5dccbb01f9bfbb9dd34

URL: https://github.com/llvm/llvm-project/commit/0e083cef7003ba822be9b5dccbb01f9bfbb9dd34
DIFF: https://github.com/llvm/llvm-project/commit/0e083cef7003ba822be9b5dccbb01f9bfbb9dd34.diff

LOG: [mlir][tosa] Update tosa.matmul lowering to linalg.batch_matmul

tosa.matmul is a batched matmul, so lower it to linalg.batch_matmul
instead of linalg.matmul, and update the tests to use rank-3 operands.

Reviewed By: sjarus

Differential Revision: https://reviews.llvm.org/D103937
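
For readers unfamiliar with the distinction, the following is a minimal
standalone sketch of what "batched matmul" means for the shapes used in
the updated test ([N, M, K] times [N, K, P] gives [N, M, P]). It is not
part of the patch and does not use any MLIR API; the helper name
batchMatmulRef and the flat row-major buffer layout are assumptions made
purely for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not an MLIR function): computes a batched
// matmul over flat row-major buffers.
//   a   has shape [N, M, K]
//   b   has shape [N, K, P]
//   out has shape [N, M, P]
std::vector<float> batchMatmulRef(const std::vector<float> &a,
                                  const std::vector<float> &b,
                                  std::size_t N, std::size_t M,
                                  std::size_t K, std::size_t P) {
  assert(a.size() == N * M * K && "lhs must hold N*M*K elements");
  assert(b.size() == N * K * P && "rhs must hold N*K*P elements");

  // Zero-filled accumulator; this plays the role of the linalg.fill result
  // that the lowering passes as the output operand.
  std::vector<float> out(N * M * P, 0.0f);

  // Each batch index n selects an independent M x K by K x P product.
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t m = 0; m < M; ++m)
      for (std::size_t p = 0; p < P; ++p)
        for (std::size_t k = 0; k < K; ++k)
          out[(n * M + m) * P + p] +=
              a[(n * M + m) * K + k] * b[(n * K + k) * P + p];
  return out;
}
```

With N = 1 this reduces to an ordinary 2-D matmul, which is why the test
in this patch only needs a leading unit dimension on each tensor.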

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 89a13750f99b0..19e35928cb06c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1019,7 +1019,7 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
         loc, outputTy.getShape(), outputTy.getElementType());
     Value zeroTensor =
         rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
-    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+    rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
         op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
         ValueRange{zeroTensor});
     return success();

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 15d0bf5129dc4..3aec680ec9d0e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -844,13 +844,13 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
 
 
 // CHECK-LABEL: @matmul
-func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>, %arg2: tensor<1x6xf32>) -> (tensor<1x5x6xf32>) {
   // CHECK: [[C0:%.+]] = constant 0
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<5x6xf32>, f32 -> tensor<5x6xf32>
-  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILLED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
-  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x6xf32>)  -> (tensor<5x6xf32>)
-  return %0 : tensor<5x6xf32>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6]
+  // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<1x5x6xf32>, f32 -> tensor<1x5x6xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x6xf32>)  -> (tensor<1x5x6xf32>)
+  return %0 : tensor<1x5x6xf32>
 }
 
 // -----


        

