[Mlir-commits] [mlir] [mlir][TosaToLinalg] Fix TosaToLinalg to restrict `tosa.cast` types to integer or float (PR #128859)
Longsheng Mou
llvmlistbot at llvm.org
Wed Feb 26 03:15:19 PST 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/128859
>From a0887df346456ed63880f6617e5fdba49731b7fe Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Wed, 26 Feb 2025 18:57:54 +0800
Subject: [PATCH] [mlir][TosaToLinalg] Fix TosaToLinalg to restrict `tosa.cast`
types to integer or float
This PR fixes a bug where `TosaToLinalg` incorrectly allows `tosa.cast` to
accept types other than integer or float.
---
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 5 +++++
.../Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir | 8 ++++++++
2 files changed, 13 insertions(+)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 607667fcc6945..e5994cdc777b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -524,6 +524,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::CastOp>(op)) {
Type srcTy = elementTy;
Type dstTy = resultTypes.front();
+ if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
+ (void)rewriter.notifyMatchFailure(op,"unsupported type");
+ return nullptr;
+ }
+
bool bitExtend =
srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 460e207d62de6..5db3f56cf459e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -54,3 +54,11 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
+
+// -----
+
+func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
+ // expected-error@+1 {{failed to legalize operation 'tosa.cast'}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+ return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+}
More information about the Mlir-commits
mailing list