[Mlir-commits] [mlir] [mlir][tosa] Fix lowering of tosa.matmul with dynamic outputs (PR #72724)

Spenser Bauman llvmlistbot at llvm.org
Fri Nov 17 16:22:41 PST 2023


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/72724

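The existing lowering of tosa.matmul constructs an illegal tensor.empty operation when the output type is more dynamic than the input types, i.e. when a result dimension is dynamic even though the corresponding operand dimensions are static. For example: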

  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>

When constructing the tensor.empty operation, consult the output type rather than the input types to decide whether a dimension is dynamic.

From bff2ec3b458a72ce21275e5ba7979454aa0d5bbf Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Fri, 17 Nov 2023 19:10:06 -0500
Subject: [PATCH] [mlir][tosa] Fix lowering of tosa.matmul with dynamic outputs

The existing lowering of tosa.matmul will construct illegal tensor.empty
operations when the output type is more dynamic than the input types.

  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>

When constructing the tensor.empty operation, consult the output type
rather than the input types to decide whether a dimension is dynamic.
---
 .../Conversion/TosaToLinalg/TosaToLinalgNamed.cpp  |  9 +++------
 .../TosaToLinalg/tosa-to-linalg-named.mlir         | 14 ++++++++++++++
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 99a65f63038a43f..9e374be534985e5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -540,21 +540,18 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
     auto outputTy = cast<ShapedType>(op.getType());
     auto outputElementTy = outputTy.getElementType();
 
-    auto firstOperandTy = cast<ShapedType>(op->getOperand(0).getType());
-    auto secondOperandTy = cast<ShapedType>(op->getOperand(1).getType());
-
     SmallVector<Value> dynDims;
     dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
 
-    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
+    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
       dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
     }
 
-    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
+    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
       dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
     }
 
-    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
+    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
       dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1cf7c8dee606899..4edc75331932803 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -68,6 +68,20 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
 
 // -----
 
+// CHECK-LABEL: @matmul_dyn_output
+func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x1x1xf32> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<1x1x8xf32>
+  // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
+  // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  return %0 : tensor<?x1x1xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 



More information about the Mlir-commits mailing list