[Mlir-commits] [mlir] 19109a2 - [mlir][tosa] Fix crash when inferring shape of tosa.equal

Jacques Pienaar llvmlistbot at llvm.org
Tue May 2 08:18:48 PDT 2023


Author: Spenser Bauman
Date: 2023-05-02T08:18:35-07:00
New Revision: 19109a274e67f8f6a30a0da22b2135daa327bb79

URL: https://github.com/llvm/llvm-project/commit/19109a274e67f8f6a30a0da22b2135daa327bb79
DIFF: https://github.com/llvm/llvm-project/commit/19109a274e67f8f6a30a0da22b2135daa327bb79.diff

LOG: [mlir][tosa] Fix crash when inferring shape of tosa.equal

The tosa-infer-shapes pass crashes when trying to infer the output shape
of tosa.equal when an input is unranked.
This is because tosa-infer-shapes requires inferReturnTypeComponents to
report at least the element type of the result, even when the shape
itself cannot be inferred.
This change makes EqualOp::inferReturnTypeComponents always supply the
inferred element type (i1).

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D149582
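The fix is easiest to see in isolation. The sketch below is a standalone C++ model of the patched EqualOp::inferReturnTypeComponents. It is not MLIR code, and all names in it are hypothetical: `Components` stands in for ShapedTypeComponents, `OperandShape` for an operand shape (std::nullopt meaning unranked, as in tensor<*xf32>), `resolveBroadcast` loosely mirrors the role of resolveBroadcastShape for static shapes only, and the element type is modeled as the string "i1". The point it illustrates is that the failure path still reports the element type, so callers always receive a typed result.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::ShapedTypeComponents.
struct Components {
  std::optional<std::vector<int64_t>> dims;  // nullopt means unranked
  std::optional<std::string> elementType;    // nullopt means unknown
};

// Hypothetical operand shape: nullopt models an unranked tensor.
using OperandShape = std::optional<std::vector<int64_t>>;

// Simplified broadcast resolution (static shapes only): fails if any
// operand is unranked or if two non-1 dimensions disagree.
bool resolveBroadcast(const std::vector<OperandShape>& operands,
                      std::vector<int64_t>& out) {
  size_t rank = 0;
  for (const auto& shape : operands) {
    if (!shape) return false;  // unranked operand: cannot resolve a shape
    rank = std::max(rank, shape->size());
  }
  out.assign(rank, 1);
  for (const auto& shape : operands) {
    size_t offset = rank - shape->size();  // right-align shorter shapes
    for (size_t i = 0; i < shape->size(); ++i) {
      int64_t& dim1 = out[offset + i];
      int64_t dim2 = (*shape)[i];
      if (dim1 == 1)
        dim1 = dim2;  // 1 broadcasts to the other dimension
      else if (dim2 != 1 && dim1 != dim2)
        return false;  // incompatible dimensions
    }
  }
  return true;
}

// Models the patched logic: the element type is reported on both paths.
Components inferEqual(const std::vector<OperandShape>& operands) {
  const std::string elementType = "i1";  // equality always yields booleans
  std::vector<int64_t> outShape;
  if (!resolveBroadcast(operands, outShape))
    return Components{std::nullopt, elementType};  // unranked but typed
  return Components{outShape, elementType};
}
```

Before the patch, the failure path corresponded to returning a `Components` with both fields empty, which is the case the pass could not handle.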

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 903200f5f5cd5..aec98e256bf94 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -491,14 +491,15 @@ LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  auto elementType = IntegerType::get(context, /*width=*/1);
+
   llvm::SmallVector<int64_t> outShape;
   if (resolveBroadcastShape(operands, outShape).failed()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
     return success();
   }
 
-  inferredReturnShapes.push_back(
-      ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1)));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 505350786d08d..268eae90d5cd9 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1224,3 +1224,13 @@ func.func @test_dynamic_batch_fft2d(%arg0: tensor<?x4x8xf32>, %arg1: tensor<?x4x
   %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<?x4x8xf32>, tensor<?x4x8xf32>) -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>)
   return %output_real, %output_imag : tensor<?x4x8xf32>, tensor<?x4x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_unranked_equal
+func.func @test_unranked_equal(%arg0 : tensor<*xf32>, %arg1 : tensor<f32>) -> () {
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  %0 = "tosa.equal"(%arg0, %arg1) : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+
+  return
+}

