[Mlir-commits] [mlir] f4b9839 - [mlir][TensorToSPIRV] Add type check for `tensor.extract` in TensorToSPIRV (#107110)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 3 19:21:31 PDT 2024


Author: Longsheng Mou
Date: 2024-09-04T10:21:27+08:00
New Revision: f4b9839d6f7c9ec2967a42f2d5546a2a2ae77ca4

URL: https://github.com/llvm/llvm-project/commit/f4b9839d6f7c9ec2967a42f2d5546a2a2ae77ca4
DIFF: https://github.com/llvm/llvm-project/commit/f4b9839d6f7c9ec2967a42f2d5546a2a2ae77ca4.diff

LOG: [mlir][TensorToSPIRV] Add type check for `tensor.extract` in TensorToSPIRV (#107110)

This patch adds a type check for `tensor.extract` in TensorToSPIRV, so that
only `tensor.extract` ops whose element type is a `spirv::ScalarType` are
converted. Fixes #74466.
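
For readers without an MLIR build at hand, the guard order introduced by this
patch can be modelled in plain C++. This is only a sketch under stated
assumptions: `ElementType`, `Kind`, `isSupportedScalar`, and
`matchFailureReason` are hypothetical names that do not exist in MLIR, and the
set of accepted bit widths is an assumption made for illustration. The only
facts taken from this commit are that `index` and `i128` element types are
rejected and that the element-type check runs before the static-shape check.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR element type; this is NOT mlir::Type.
enum class Kind { Integer, Float, Index };

struct ElementType {
  Kind kind;
  unsigned width; // bit width; unused for Kind::Index
};

// Toy analogue of `isa<spirv::ScalarType>(type)`.
// The accepted widths below are an assumption made for this sketch.
bool isSupportedScalar(const ElementType &type) {
  switch (type.kind) {
  case Kind::Integer:
    return type.width == 1 || type.width == 8 || type.width == 16 ||
           type.width == 32 || type.width == 64;
  case Kind::Float:
    return type.width == 16 || type.width == 32 || type.width == 64;
  case Kind::Index:
    return false; // `index` is not a SPIR-V scalar type
  }
  return false;
}

// Mirrors the order of the guards in the patched pattern: element type
// first, then static shape. Returns the failure reason, or std::nullopt
// when the op would be eligible for conversion.
std::optional<std::string> matchFailureReason(const ElementType &type,
                                              bool hasStaticShape) {
  if (!isSupportedScalar(type))
    return "unsupported type";
  if (!hasStaticShape)
    return "non-static tensor";
  return std::nullopt;
}
```

In this model, `tensor<2xindex>` and `tensor<2xi128>` both report
"unsupported type", which corresponds to the new tests expecting
`tensor.extract` to be left unconverted.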

Added: 
    

Modified: 
    mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
    mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index 0fb58623bdafbe..468fffdd2df91b 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -45,6 +45,8 @@ class TensorExtractPattern final
                   ConversionPatternRewriter &rewriter) const override {
     auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
 
+    if (!isa<spirv::ScalarType>(tensorType.getElementType()))
+      return rewriter.notifyMatchFailure(extractOp, "unsupported type");
     if (!tensorType.hasStaticShape())
       return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
 

diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 32d0fbea65b164..b69c2d0408d176 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -29,6 +29,24 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
 
 // -----
 
+// CHECK-LABEL: test_spirv_unsupported_type_index
+func.func @test_spirv_unsupported_type_index(%a : index) {
+  %cst = arith.constant dense<[1, 2]> : tensor<2xindex>
+  // CHECK: tensor.extract
+  %extract = tensor.extract %cst[%a] : tensor<2xindex>
+  return
+}
+
+// CHECK-LABEL: test_spirv_unsupported_type_i128
+func.func @test_spirv_unsupported_type_i128(%a : index) {
+  %cst = arith.constant dense<[1, 2]> : tensor<2xi128>
+  // CHECK: tensor.extract
+  %extract = tensor.extract %cst[%a] : tensor<2xi128>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Type conversion
 //===----------------------------------------------------------------------===//


        

