[Mlir-commits] [mlir] 1be48fd - [mlir][TosaToLinalg] Fix TosaToLinalg to restrict `tosa.cast` types to integer or float (#128859)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 26 10:53:15 PST 2025


Author: Longsheng Mou
Date: 2025-02-26T10:53:12-08:00
New Revision: 1be48fdf8bb25f82889aa75ca130e7aaf86295fe

URL: https://github.com/llvm/llvm-project/commit/1be48fdf8bb25f82889aa75ca130e7aaf86295fe
DIFF: https://github.com/llvm/llvm-project/commit/1be48fdf8bb25f82889aa75ca130e7aaf86295fe.diff

LOG: [mlir][TosaToLinalg] Fix TosaToLinalg to restrict `tosa.cast` types to integer or float (#128859)

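This PR fixes a bug where `TosaToLinalg` incorrectly allowed `tosa.cast`
to accept element types other than integer or float (for example, a
`!quant.uniform` result type). Such casts now produce a match failure, so
the conversion reports "failed to legalize operation 'tosa.cast'".
Fixes #116342.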
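The guard in the diff below checks `isIntOrFloat()` on both the source and destination types before calling `getIntOrFloatBitWidth()`, which in MLIR asserts when it is called on any other kind of type. The C++ sketch below models that ordering outside MLIR. `Kind`, `ElementType`, and `castIsBitExtend` are hypothetical stand-ins invented for illustration, not MLIR APIs. `Kind::Quantized` plays the role of a type such as `!quant.uniform<...>`, and returning `std::nullopt` plays the role of `notifyMatchFailure` followed by `return nullptr`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical, simplified stand-in for an MLIR element type.
// Quantized models a type such as !quant.uniform<...>, which is
// neither an integer nor a float type.
enum class Kind { Integer, Float, Quantized };

struct ElementType {
  Kind kind;
  unsigned width;  // Only meaningful for Integer and Float.

  bool isIntOrFloat() const {
    return kind == Kind::Integer || kind == Kind::Float;
  }

  // Models the precondition of MLIR's getIntOrFloatBitWidth():
  // calling it on any other kind of type is a programming error.
  unsigned getIntOrFloatBitWidth() const {
    assert(isIntOrFloat() && "only integer and float types have a bit width");
    return width;
  }
};

// Hypothetical model of the cast-lowering decision. Returns std::nullopt
// when either type is unsupported (the analogue of notifyMatchFailure plus
// returning nullptr); otherwise returns whether the cast widens the value.
std::optional<bool> castIsBitExtend(ElementType src, ElementType dst) {
  // The guard added by this commit: reject before querying bit widths.
  if (!src.isIntOrFloat() || !dst.isIntOrFloat())
    return std::nullopt;
  return src.getIntOrFloatBitWidth() < dst.getIntOrFloatBitWidth();
}
```

For example, `castIsBitExtend({Kind::Integer, 32}, {Kind::Quantized, 0})` yields `std::nullopt`, mirroring the `i32` to quantized cast exercised by the new test, while an `i8` to `i32` cast yields `true`. Without the guard, the bit-width query would reach the assertion instead of failing gracefully.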

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 607667fcc6945..dfab37497d5c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -524,6 +524,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::CastOp>(op)) {
     Type srcTy = elementTy;
     Type dstTy = resultTypes.front();
+    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
+      (void)rewriter.notifyMatchFailure(op, "unsupported type");
+      return nullptr;
+    }
+
     bool bitExtend =
         srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 460e207d62de6..5db3f56cf459e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -54,3 +54,11 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
+
+// -----
+
+func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.cast'}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+  return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+}

