[Mlir-commits] [mlir] 2ad297d - [mlir][spirv] Handle zero-element tensors in spirv type conversion
Jakub Kuderski
llvmlistbot at llvm.org
Wed Aug 2 14:37:16 PDT 2023
Author: Jakub Kuderski
Date: 2023-08-02T17:36:54-04:00
New Revision: 2ad297db2c0ea168822b4958dbe3f3c1d3198d79
URL: https://github.com/llvm/llvm-project/commit/2ad297db2c0ea168822b4958dbe3f3c1d3198d79
DIFF: https://github.com/llvm/llvm-project/commit/2ad297db2c0ea168822b4958dbe3f3c1d3198d79.diff
LOG: [mlir][spirv] Handle zero-element tensors in spirv type conversion
Return gracefully instead of crashing. Add missing type conversion
tests.
Fixes: https://github.com/llvm/llvm-project/issues/61044
Reviewed By: qedawkins
Differential Revision: https://reviews.llvm.org/D156942
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c8d7aef8964201..51ae2c10087e50 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -390,8 +390,14 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  auto arrayElemCount = *tensorSize / *scalarSize;
-  auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
+  int64_t arrayElemCount = *tensorSize / *scalarSize;
+  if (arrayElemCount == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot handle zero-element tensors\n");
+    return nullptr;
+  }
+
+  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
   if (!arrayElemType)
     return nullptr;
   std::optional<int64_t> arrayElemSize =
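The arithmetic behind the new guard can be illustrated outside MLIR. The following is a minimal, hypothetical C++ sketch (`arrayElemCountFor` is a made-up name, not an MLIR API) that models only the size computation: the array length is the tensor's byte size divided by the element's byte size, so a tensor with any zero dimension yields a count of 0. The sketch assumes the tensor's byte size is the element size times the product of its dimensions. The change treats a count of 0 as a conversion failure. SPIR-V's `OpTypeArray` requires a length of at least 1, which is presumably why an empty array type could not be built before.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical, MLIR-free model of the size checks in convertTensorType.
// tensorSize and scalarSize stand in for the byte sizes the real code queries;
// either may be unknown, hence std::optional. Returns the element count for
// the converted array type, or std::nullopt when the conversion must fail.
std::optional<int64_t> arrayElemCountFor(std::optional<int64_t> tensorSize,
                                         std::optional<int64_t> scalarSize) {
  // Unknown sizes: give up, mirroring the earlier early returns.
  if (!tensorSize || !scalarSize)
    return std::nullopt;
  // The sketch assumes element types have a positive byte size.
  assert(*scalarSize > 0 && "element size assumed positive");
  int64_t arrayElemCount = *tensorSize / *scalarSize;
  // A zero-element tensor such as tensor<2x0xi32> occupies 0 bytes, so the
  // count is 0: report failure instead of requesting a zero-length array.
  if (arrayElemCount == 0)
    return std::nullopt;
  return arrayElemCount;
}
```

With this sketch, `arrayElemCountFor(24, 4)` (the 24 bytes of `tensor<2x3xi32>` over 4-byte `i32` elements) returns 6, while `arrayElemCountFor(0, 4)` returns `std::nullopt`, corresponding to the new early return.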
diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 904560d0572ba1..19de613bf5b073 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -split-input-file -convert-tensor-to-spirv -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-tensor-to-spirv \
+// RUN: --verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// tensor.extract
@@ -27,3 +28,38 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// CHECK: spirv.ReturnValue %[[VAL]]
return %extract : i32
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Type conversion
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @tensor_0d
+// CHECK-NEXT: spirv.Constant 1 : i32
+func.func @tensor_0d() -> () {
+ %x = arith.constant dense<1> : tensor<i32>
+ return
+}
+
+// CHECK-LABEL: func @tensor_1d
+// CHECK-NEXT: spirv.Constant dense<[1, 2, 3]> : tensor<3xi32> : !spirv.array<3 x i32>
+func.func @tensor_1d() -> () {
+ %x = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
+ return
+}
+
+// CHECK-LABEL: func @tensor_2d
+// CHECK-NEXT: spirv.Constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
+func.func @tensor_2d() -> () {
+ %x = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+ return
+}
+
+// We do not handle zero-element tensors yet. Just make sure we do not crash on them.
+// CHECK-LABEL: func @tensor_2d_empty
+// CHECK-NEXT: arith.constant dense<>
+func.func @tensor_2d_empty() -> () {
+ %x = arith.constant dense<> : tensor<2x0xi32>
+ return
+}
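The four new test cases follow a simple rule, which the hypothetical C++ helper below (`expectedSpirvType` is an illustrative name, not part of MLIR) restates for static shapes. It mirrors only the CHECK lines above and is not the actual lowering code: a rank-0 tensor constant becomes a scalar `spirv.Constant`, a multi-dimensional tensor is flattened into an `!spirv.array` whose length is the product of its dimensions, and a shape with any zero dimension is left unconverted. Other shapes, such as single-element rank-1 tensors, are outside what these tests exercise, and the sketch's answer for them should not be taken as the real behavior.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper mirroring the CHECK lines in the new tests: maps a
// statically shaped tensor constant to the type it is expected to lower to.
// std::nullopt means the op is expected to remain an arith.constant.
std::optional<std::string> expectedSpirvType(const std::vector<int64_t> &shape,
                                             const std::string &elemType) {
  // Rank-0 tensors (e.g. tensor<i32>) become a plain scalar constant.
  if (shape.empty())
    return elemType;
  // Otherwise the tensor is flattened: 2x3 -> 6 elements.
  int64_t numElements = 1;
  for (int64_t dim : shape) {
    assert(dim >= 0 && "sketch only handles static shapes");
    numElements *= dim;
  }
  // Any zero dimension gives zero elements, which the conversion rejects.
  if (numElements == 0)
    return std::nullopt;
  return "!spirv.array<" + std::to_string(numElements) + " x " + elemType + ">";
}
```

For the four test functions this yields `i32`, `!spirv.array<3 x i32>`, `!spirv.array<6 x i32>`, and `std::nullopt` (the `tensor<2x0xi32>` constant stays an `arith.constant dense<>`).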