[Mlir-commits] [mlir] 2eb7bbb - [mlir][tosa] Use 3D tensors in tosa.matmul

Rob Suderman llvmlistbot at llvm.org
Wed Jun 30 12:28:22 PDT 2021


Author: Suraj Sudhir
Date: 2021-06-30T12:22:52-07:00
New Revision: 2eb7bbbe65b6374e6137772f1c2c46e6daa5c33d

URL: https://github.com/llvm/llvm-project/commit/2eb7bbbe65b6374e6137772f1c2c46e6daa5c33d
DIFF: https://github.com/llvm/llvm-project/commit/2eb7bbbe65b6374e6137772f1c2c46e6daa5c33d.diff

LOG: [mlir][tosa] Use 3D tensors in tosa.matmul

Signed-off-by: Suraj Sudhir <suraj.sudhir at arm.com>

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D105213

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 639934b1acb48..3a1f9d26be118 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -208,13 +208,13 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor2Dto3D:$a,
-    Tosa_Tensor2Dto3D:$b,
+    Tosa_Tensor3D:$a,
+    Tosa_Tensor3D:$b,
     OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
   );
 
   let results = (outs
-    Tosa_Tensor2Dto3D:$c
+    Tosa_Tensor3D:$c
   );
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5969d98408a9e..08324a15a07b5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -124,8 +124,6 @@ def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>;
 def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>;
 def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>;
 
-def Tosa_Tensor2Dto3D : TensorRankOf<[Tosa_AnyNumber], [2,3]>;
-
 def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>;
 
 def Tosa_Int32TensorUpto4D : TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>;

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 6ef301081bb56..ec169d0e16ebf 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -39,9 +39,9 @@ func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %
 
 // -----
 // CHECK-LABEL: test_matmul
-func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> {
-  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32>
-  return %0 : tensor<14x28xf32>
+func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list