[Mlir-commits] [mlir] 0e083ce - [mlir][tosa] Update tosa.matmul lowering to linalg.batch_matmul
Rob Suderman
llvmlistbot at llvm.org
Wed Jun 9 11:12:26 PDT 2021
Author: Rob Suderman
Date: 2021-06-09T11:05:36-07:00
New Revision: 0e083cef7003ba822be9b5dccbb01f9bfbb9dd34
URL: https://github.com/llvm/llvm-project/commit/0e083cef7003ba822be9b5dccbb01f9bfbb9dd34
DIFF: https://github.com/llvm/llvm-project/commit/0e083cef7003ba822be9b5dccbb01f9bfbb9dd34.diff
LOG: [mlir][tosa] Update tosa.matmul lowering to linalg.batch_matmul
tosa.matmul is a batched matmul, so lower it to linalg.batch_matmul
instead of linalg.matmul, and update the tests accordingly.
Reviewed By: sjarus
Differential Revision: https://reviews.llvm.org/D103937
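As a rough illustration of what the new lowering computes, here is a minimal plain-loop reference model in C++. It is a sketch, not MLIR code: `batchMatmulRef` is a hypothetical name. It assumes dense row-major buffers and the TOSA MATMUL shape convention a: [N, H, C], b: [N, C, W], result: [N, H, W]. The zero-initialized output mirrors the `linalg.fill` of the init tensor, and the loop nest mirrors the per-batch accumulation performed by `linalg.batch_matmul`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference model (not part of MLIR) of the lowered IR:
// a zero-filled init tensor followed by a batched matmul.
// Shapes (assumed TOSA convention): a [n, h, c], b [n, c, w] -> [n, h, w].
std::vector<float> batchMatmulRef(const std::vector<float>& a,
                                  const std::vector<float>& b,
                                  std::size_t n, std::size_t h,
                                  std::size_t c, std::size_t w) {
  assert(a.size() == n * h * c && b.size() == n * c * w);
  // Counterpart of linalg.fill with 0: the accumulator starts at zero.
  std::vector<float> out(n * h * w, 0.0f);
  // Each batch is an independent 2-D matmul; batches never mix.
  for (std::size_t batch = 0; batch < n; ++batch)
    for (std::size_t i = 0; i < h; ++i)
      for (std::size_t j = 0; j < w; ++j)
        for (std::size_t k = 0; k < c; ++k)
          out[(batch * h + i) * w + j] +=
              a[(batch * h + i) * c + k] * b[(batch * c + k) * w + j];
  return out;
}
```

Because the batch dimension only selects which pair of 2-D slices is multiplied, a rank-2 `linalg.matmul` cannot express this directly on rank-3 operands, which is why the lowering switches to `linalg.batch_matmul`.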
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 89a13750f99b0..19e35928cb06c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1019,7 +1019,7 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
loc, outputTy.getShape(), outputTy.getElementType());
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
- rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+ rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
ValueRange{zeroTensor});
return success();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 15d0bf5129dc4..3aec680ec9d0e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -844,13 +844,13 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
// CHECK-LABEL: @matmul
-func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>, %arg2: tensor<1x6xf32>) -> (tensor<1x5x6xf32>) {
// CHECK: [[C0:%.+]] = constant 0
- // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
- // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<5x6xf32>, f32 -> tensor<5x6xf32>
- // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILLED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
- %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x6xf32>) -> (tensor<5x6xf32>)
- return %0 : tensor<5x6xf32>
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6]
+ // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<1x5x6xf32>, f32 -> tensor<1x5x6xf32>
+ // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
+ %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>)
+ return %0 : tensor<1x5x6xf32>
}
// -----
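The updated test exercises the rank-3 shape rule: operands of shape [1, 5, 3] and [1, 3, 6] produce a [1, 5, 6] result, matching the `linalg.init_tensor [1, 5, 6]` in the CHECK lines. The following sketch encodes that rule. `inferMatmulShape` is a hypothetical helper, not an MLIR or TOSA API. It handles only static shapes and assumes the batch dimensions must match exactly, with no broadcasting.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>

using Shape3 = std::array<int64_t, 3>;

// Hypothetical helper: result shape of a batched matmul
// [N, H, C] x [N, C, W] -> [N, H, W], or nullopt if shapes are incompatible.
std::optional<Shape3> inferMatmulShape(const Shape3& a, const Shape3& b) {
  // Batch sizes must agree, and the contraction dims (C) must match.
  if (a[0] != b[0] || a[2] != b[1]) return std::nullopt;
  return Shape3{a[0], a[1], b[2]};
}
```

For the test's operands this yields {1, 5, 6}. A mismatched contraction dimension, such as [1, 5, 3] with [1, 4, 6], is rejected.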