[Mlir-commits] [mlir] 93ffe17 - [mlir][tosa] Only match rfft2d of floats in linalg conversion (#93432)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 14 16:10:03 PDT 2024


Author: Kavan Bickerstaff
Date: 2024-06-14T16:09:59-07:00
New Revision: 93ffe1792fd9a985b96fee1105b399b5196a15bc

URL: https://github.com/llvm/llvm-project/commit/93ffe1792fd9a985b96fee1105b399b5196a15bc
DIFF: https://github.com/llvm/llvm-project/commit/93ffe1792fd9a985b96fee1105b399b5196a15bc.diff

LOG: [mlir][tosa] Only match rfft2d of floats in linalg conversion (#93432)

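This prevents an assertion from being triggered by the `cast<FloatType>` on the input's element type when that type is not a float (e.g. `i32`). The pattern now reports a match failure instead, so the conversion fails with a diagnostic rather than crashing.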

Fixes https://github.com/llvm/llvm-project/issues/92064

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e6ba6e6bc602d..8ad8e41414656 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2324,7 +2324,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
     auto loc = rfft2d.getLoc();
     auto input = rfft2d.getInput();
     auto elementType =
-        cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
+        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
+    if (!elementType)
+      return rewriter.notifyMatchFailure(rfft2d,
+                                         "only supports float element types");
 
     // Compute the output type and set of dynamic sizes
     llvm::SmallVector<Value> dynamicSizes;

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index ad65410e635e9..b78577275a52a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -27,3 +27,12 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
   %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
   return %2 : tensor<10x10xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @rfft2d_with_non_float_type
+func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
+  // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
+  %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
+  return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
+}


        


More information about the Mlir-commits mailing list