[Mlir-commits] [mlir] [mlir][tosa] Fix lowering of tosa.matmul with dynamic outputs (PR #72724)
llvmlistbot at llvm.org
Fri Nov 17 16:23:08 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Spenser Bauman (sabauma)
Changes:
The existing lowering of tosa.matmul constructs illegal tensor.empty operations when the output type is more dynamic than the input types, for example:
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
When constructing the tensor.empty operation, this change consults the output type rather than the input types to decide whether a dimension is dynamic.
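For illustration, a minimal sketch (not taken verbatim from the patch; the SSA value names are placeholders) of how the constructed tensor.empty differs, matching the new test case below:

```mlir
// Before the fix: the inputs are fully static, so no dynamic size operand is
// passed, which is invalid for a result type with a dynamic leading dimension.
%init_bad = tensor.empty() : tensor<?x1x1xf32>

// After the fix: the dynamic output dimension is materialized from the
// corresponding input dimension and passed to tensor.empty.
%c0 = arith.constant 0 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<1x1x8xf32>
%init_ok = tensor.empty(%dim0) : tensor<?x1x1xf32>
```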
---
Full diff: https://github.com/llvm/llvm-project/pull/72724.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+3-6)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+14)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 99a65f63038a43f..9e374be534985e5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -540,21 +540,18 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
auto outputTy = cast<ShapedType>(op.getType());
auto outputElementTy = outputTy.getElementType();
- auto firstOperandTy = cast<ShapedType>(op->getOperand(0).getType());
- auto secondOperandTy = cast<ShapedType>(op->getOperand(1).getType());
-
SmallVector<Value> dynDims;
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
- if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
+ if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
}
- if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
+ if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
}
- if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
+ if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1cf7c8dee606899..4edc75331932803 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -68,6 +68,20 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
// -----
+// CHECK-LABEL: @matmul_dyn_output
+func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x1x1xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<1x1x8xf32>
+ // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
+ // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+ // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+ %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+ return %0 : tensor<?x1x1xf32>
+}
+
+// -----
+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
``````````
https://github.com/llvm/llvm-project/pull/72724
More information about the Mlir-commits mailing list