[Mlir-commits] [mlir] 2ad297d - [mlir][spirv] Handle zero-element tensors in spirv type conversion

Jakub Kuderski llvmlistbot at llvm.org
Wed Aug 2 14:37:16 PDT 2023


Author: Jakub Kuderski
Date: 2023-08-02T17:36:54-04:00
New Revision: 2ad297db2c0ea168822b4958dbe3f3c1d3198d79

URL: https://github.com/llvm/llvm-project/commit/2ad297db2c0ea168822b4958dbe3f3c1d3198d79
DIFF: https://github.com/llvm/llvm-project/commit/2ad297db2c0ea168822b4958dbe3f3c1d3198d79.diff

LOG: [mlir][spirv] Handle zero-element tensors in spirv type conversion

Return a null type (treating the tensor type as illegal) instead of
crashing. Add the missing type conversion tests.

Fixes: https://github.com/llvm/llvm-project/issues/61044

Reviewed By: qedawkins

Differential Revision: https://reviews.llvm.org/D156942

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c8d7aef8964201..51ae2c10087e50 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -390,8 +390,14 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  auto arrayElemCount = *tensorSize / *scalarSize;
-  auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
+  int64_t arrayElemCount = *tensorSize / *scalarSize;
+  if (arrayElemCount == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot handle zero-element tensors\n");
+    return nullptr;
+  }
+
+  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
   if (!arrayElemType)
     return nullptr;
   std::optional<int64_t> arrayElemSize =

diff  --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 904560d0572ba1..19de613bf5b073 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -split-input-file -convert-tensor-to-spirv -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-tensor-to-spirv \
+// RUN:   --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // tensor.extract
@@ -27,3 +28,38 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
   // CHECK: spirv.ReturnValue %[[VAL]]
   return %extract : i32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Type conversion
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @tensor_0d
+// CHECK-NEXT:    spirv.Constant 1 : i32
+func.func @tensor_0d() -> () {
+  %x = arith.constant dense<1> : tensor<i32>
+  return
+}
+
+// CHECK-LABEL: func @tensor_1d
+// CHECK-NEXT:    spirv.Constant dense<[1, 2, 3]> : tensor<3xi32> : !spirv.array<3 x i32>
+func.func @tensor_1d() -> () {
+  %x = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
+  return
+}
+
+// CHECK-LABEL: func @tensor_2d
+// CHECK-NEXT:    spirv.Constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
+func.func @tensor_2d() -> () {
+  %x = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+  return
+}
+
+// We do not handle zero-element tensors yet. Just make sure we do not crash on them.
+// CHECK-LABEL: func @tensor_2d_empty
+// CHECK-NEXT:    arith.constant dense<>
+func.func @tensor_2d_empty() -> () {
+  %x = arith.constant dense<> : tensor<2x0xi32>
+  return
+}
