[Mlir-commits] [mlir] [mlir][TensorToSPIRV] Add check for `tensor.extract` in TensorToSPIRV (PR #107110)
llvmlistbot at llvm.org
Tue Sep 3 06:45:49 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-spirv
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This patch adds a check to the `tensor.extract` pattern in TensorToSPIRV, which only supports integer and float element types. Fixes #74466.
---
Full diff: https://github.com/llvm/llvm-project/pull/107110.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp (+3)
- (modified) mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir (+10)
``````````diff
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index 0fb58623bdafbe..03bd79c843158d 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -45,6 +45,9 @@ class TensorExtractPattern final
ConversionPatternRewriter &rewriter) const override {
auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
+ if (!tensorType.getElementType().isIntOrFloat())
+ return rewriter.notifyMatchFailure(extractOp,
+ "only integers and floats supported");
if (!tensorType.hasStaticShape())
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 32d0fbea65b164..b1c2d5c2712eaa 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -29,6 +29,16 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// -----
+// CHECK-LABEL: tensor_extract_unsupported_type
+func.func @tensor_extract_unsupported_type(%a : index) {
+ %cst = arith.constant dense<[1, 2]> : tensor<2xindex>
+ // CHECK: tensor.extract
+ %extract = tensor.extract %cst[%a] : tensor<2xindex>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/107110
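
As a rough illustration of the guard order this patch introduces, here is a minimal standalone C++ sketch. It does not use the MLIR API: `ElemKind`, `TensorTypeModel`, and `checkExtractPreconditions` are hypothetical stand-ins invented for this example, and only the two failure messages are taken from the diff. It assumes, as in MLIR's `Type::isIntOrFloat()`, that `index` counts as neither an integer nor a float.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR element type; not the real MLIR API.
enum class ElemKind { Integer, Float, Index, Other };

// Hypothetical model of the two RankedTensorType properties the pattern checks.
struct TensorTypeModel {
  ElemKind elementKind;
  bool hasStaticShape;
};

// Models Type::isIntOrFloat(): `index` is neither an integer nor a float.
bool isIntOrFloat(ElemKind kind) {
  return kind == ElemKind::Integer || kind == ElemKind::Float;
}

// Returns the match-failure message the pattern would report, or
// std::nullopt when both guards pass and the conversion would proceed.
std::optional<std::string>
checkExtractPreconditions(const TensorTypeModel &type) {
  // Guard added by the patch: reject unsupported element types first.
  if (!isIntOrFloat(type.elementKind))
    return "only integers and floats supported";
  // Pre-existing guard: the tensor must have a static shape.
  if (!type.hasStaticShape)
    return "non-static tensor";
  return std::nullopt;
}
```

In this model, a `tensor<2xindex>` operand like the one in the new test corresponds to `{ElemKind::Index, true}` and is rejected by the first guard, which is why the test expects `tensor.extract` to remain in the output.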