[Mlir-commits] [mlir] [mlir][TensorToSPIRV] Add check for `tensor.extract` in TensorToSPIRV (PR #107110)

llvmlistbot at llvm.org
Tue Sep 3 06:45:49 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-spirv

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This patch adds an element-type check to the `tensor.extract` conversion in TensorToSPIRV, which only supports integer and float element types. Fixes #74466.
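
The check relies on MLIR treating the builtin `index` type as distinct from integer and float types. That is why `tensor<2xindex>` in the new test is rejected, while a predicate like `isIntOrIndexOrFloat()` would accept it. The snippet below is a minimal, self-contained C++ model of that distinction, not MLIR code. `ElementKind` and the two free functions are hypothetical stand-ins for `mlir::Type` and its member predicates.

```cpp
// Hypothetical stand-in for MLIR builtin element types; the real pattern
// inspects an mlir::Type, not this enum.
enum class ElementKind { Integer, Float, Index, Complex };

// Models the intent of mlir::Type::isIntOrFloat(): true only for integer
// and floating-point types. `index` is its own builtin type, so it fails.
constexpr bool isIntOrFloat(ElementKind kind) {
  return kind == ElementKind::Integer || kind == ElementKind::Float;
}

// Models the intent of mlir::Type::isIntOrIndexOrFloat(), which also
// accepts `index`. The patch deliberately uses the narrower predicate above.
constexpr bool isIntOrIndexOrFloat(ElementKind kind) {
  return isIntOrFloat(kind) || kind == ElementKind::Index;
}
```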

---
Full diff: https://github.com/llvm/llvm-project/pull/107110.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp (+3) 
- (modified) mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir (+10) 


``````````diff
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index 0fb58623bdafbe..03bd79c843158d 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -45,6 +45,9 @@ class TensorExtractPattern final
                   ConversionPatternRewriter &rewriter) const override {
     auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
 
+    if (!tensorType.getElementType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "only integers and floats supported");
     if (!tensorType.hasStaticShape())
       return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
 
diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 32d0fbea65b164..b1c2d5c2712eaa 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -29,6 +29,16 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
 
 // -----
 
+// CHECK-LABEL: tensor_extract_unsupported_type
+func.func @tensor_extract_unsupported_type(%a : index) {
+  %cst = arith.constant dense<[1, 2]> : tensor<2xindex>
+  // CHECK: tensor.extract
+  %extract = tensor.extract %cst[%a] : tensor<2xindex>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Type conversion
 //===----------------------------------------------------------------------===//

``````````
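
The new check runs before the existing static-shape check. As a result, an extract from a tensor with an unsupported element type is rejected for its type even if its shape is also dynamic. When `notifyMatchFailure` fires, the pattern does not rewrite the op, and the new test expects `tensor.extract` to remain in the output. Below is a minimal C++ model of that guard order. `TensorInfo`, `ElementKind`, and `extractMatchFailure` are hypothetical names, not part of MLIR. An empty return value stands in for "proceed with the rewrite".

```cpp
#include <string_view>

// Hypothetical element-type categories; a stand-in for mlir::Type.
enum class ElementKind { Integer, Float, Index };

// Hypothetical summary of the RankedTensorType properties the pattern reads.
struct TensorInfo {
  ElementKind element;
  bool hasStaticShape;
};

// Models the early exits at the top of TensorExtractPattern::matchAndRewrite
// after this patch. A non-empty result is the reason string that would be
// passed to notifyMatchFailure; an empty result means the rewrite continues.
constexpr std::string_view extractMatchFailure(TensorInfo t) {
  // New check, evaluated first: only integer and float element types.
  if (t.element != ElementKind::Integer && t.element != ElementKind::Float)
    return "only integers and floats supported";
  // Pre-existing check: the tensor shape must be fully static.
  if (!t.hasStaticShape)
    return "non-static tensor";
  return {};
}
```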

</details>


https://github.com/llvm/llvm-project/pull/107110

