[Mlir-commits] [mlir] 19109a2 - [mlir][tosa] Fix crash when inferring shape of tosa.equal
Jacques Pienaar
llvmlistbot at llvm.org
Tue May 2 08:18:48 PDT 2023
Author: Spenser Bauman
Date: 2023-05-02T08:18:35-07:00
New Revision: 19109a274e67f8f6a30a0da22b2135daa327bb79
URL: https://github.com/llvm/llvm-project/commit/19109a274e67f8f6a30a0da22b2135daa327bb79
DIFF: https://github.com/llvm/llvm-project/commit/19109a274e67f8f6a30a0da22b2135daa327bb79.diff
LOG: [mlir][tosa] Fix crash when inferring shape of tosa.equal
The tosa-infer-shapes pass crashes when trying to infer the output shape
of tosa.equal when an input is unranked.
This is because tosa-infer-shapes requires inferReturnTypeComponents to
supply at least the element type of each result, but the fallback path
(taken when the broadcast shape cannot be resolved) returned an empty
ShapedTypeComponents.
This change updates EqualOp::inferReturnTypeComponents to always supply
the inferred element type (i1).
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D149582
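To make the failure mode concrete, here is a minimal C++ sketch of the consumer side. `Components` and `buildTensorType` are hypothetical stand-ins, not MLIR APIs. The only assumption carried over from the commit message is that, like tosa-infer-shapes, a consumer can build an unranked type when the shape is missing but cannot build any type at all without an element type.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::ShapedTypeComponents (not the MLIR class):
// the shape is optional (absent means unranked) and so is the element type.
struct Components {
  std::optional<std::vector<int64_t>> dims;  // std::nullopt => unranked
  std::optional<std::string> elementType;    // std::nullopt => unknown
};

// Hypothetical consumer modelled on what tosa-infer-shapes needs. Missing
// dims are fine (the result is unranked), but a missing element type leaves
// nothing to build; std::nullopt here stands in for the null-type path that
// crashed the real pass.
std::optional<std::string> buildTensorType(const Components &c) {
  if (!c.elementType)
    return std::nullopt;
  std::string result = "tensor<";
  if (!c.dims) {
    result += "*x";  // unranked tensor
  } else {
    for (int64_t d : *c.dims)  // negative size models a dynamic dim ("?")
      result += (d < 0 ? std::string("?") : std::to_string(d)) + "x";
  }
  return result + *c.elementType + ">";
}
```

Under this model, the old fallback `Components{}` cannot be turned into a type, while `Components{std::nullopt, "i1"}` yields `tensor<*xi1>`, which is exactly what the fix provides.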
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 903200f5f5cd5..aec98e256bf94 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -491,14 +491,15 @@ LogicalResult tosa::EqualOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ auto elementType = IntegerType::get(context, /*width=*/1);
+
llvm::SmallVector<int64_t> outShape;
if (resolveBroadcastShape(operands, outShape).failed()) {
- inferredReturnShapes.push_back(ShapedTypeComponents());
+ inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
return success();
}
- inferredReturnShapes.push_back(
- ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1)));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 505350786d08d..268eae90d5cd9 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1224,3 +1224,13 @@ func.func @test_dynamic_batch_fft2d(%arg0: tensor<?x4x8xf32>, %arg1: tensor<?x4x
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<?x4x8xf32>, tensor<?x4x8xf32>) -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>)
return %output_real, %output_imag : tensor<?x4x8xf32>, tensor<?x4x8xf32>
}
+
+// -----
+
+// CHECK-LABEL: @test_unranked_equal
+func.func @test_unranked_equal(%arg0 : tensor<*xf32>, %arg1 : tensor<f32>) -> () {
+ // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+ %0 = "tosa.equal"(%arg0, %arg1) : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+
+ return
+}
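The fixed EqualOp::inferReturnTypeComponents can be modelled outside MLIR as a small C++ sketch. The names `Components`, `resolveBroadcast` and `inferEqual` are hypothetical, and the broadcast logic is a simplified NumPy-style rule assumed to approximate resolveBroadcastShape, which (as the new test implies) fails when an operand is unranked. Dynamic dimensions are not modelled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::ShapedTypeComponents (not the MLIR class).
struct Components {
  std::optional<std::vector<int64_t>> dims;  // std::nullopt => unranked
  std::optional<std::string> elementType;    // std::nullopt => unknown
};

using Shape = std::optional<std::vector<int64_t>>;  // std::nullopt => unranked

// Simplified broadcast, assumed to approximate resolveBroadcastShape: fail on
// any unranked operand or on two mismatched dimensions that are not 1.
bool resolveBroadcast(const std::vector<Shape> &operands,
                      std::vector<int64_t> &out) {
  size_t rank = 0;
  for (const Shape &shape : operands) {
    if (!shape)
      return false;  // unranked operand: no static broadcast possible
    rank = std::max(rank, shape->size());
  }
  out.assign(rank, 1);
  for (const Shape &shape : operands) {
    size_t offset = rank - shape->size();  // right-align shorter shapes
    for (size_t i = 0; i < shape->size(); ++i) {
      int64_t &dst = out[offset + i];
      int64_t src = (*shape)[i];
      if (dst == 1)
        dst = src;  // a size-1 dim broadcasts to the other operand's size
      else if (src != 1 && src != dst)
        return false;  // incompatible dimensions
    }
  }
  return true;
}

// Mirrors the fixed logic: the element type is computed once up front, so
// even the fallback path reports it instead of returning empty components.
Components inferEqual(const std::vector<Shape> &operands) {
  const std::string elementType = "i1";
  std::vector<int64_t> outShape;
  if (!resolveBroadcast(operands, outShape))
    return Components{std::nullopt, elementType};  // unranked, but typed
  return Components{outShape, elementType};
}
```

With an unranked f32 operand and a rank-0 f32 operand, as in @test_unranked_equal, the broadcast fails and the sketch returns an unranked result whose element type is still i1, matching the tensor<*xi1> the test checks for.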