[Mlir-commits] [mlir] [mlir][TensorToSPIRV] Add check for `tensor.extract` in TensorToSPIRV (PR #107110)
Longsheng Mou
llvmlistbot at llvm.org
Tue Sep 3 07:51:07 PDT 2024
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/107110
>From 71c240b816932959d0c19c967bce3d17f77c0db7 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Tue, 3 Sep 2024 21:39:54 +0800
Subject: [PATCH] [mlir][TensorToSPIRV] Add type check for `tensor.extract` in
TensorToSPIRV
This patch adds an element-type check for `tensor.extract` in TensorToSPIRV.
Only `tensor.extract` ops whose element type is a supported SPIR-V scalar type are converted.
---
.../Conversion/TensorToSPIRV/TensorToSPIRV.cpp | 2 ++
.../TensorToSPIRV/tensor-ops-to-spirv.mlir | 18 ++++++++++++++++++
2 files changed, 20 insertions(+)
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index 0fb58623bdafbe..468fffdd2df91b 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -45,6 +45,8 @@ class TensorExtractPattern final
ConversionPatternRewriter &rewriter) const override {
auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
+ if (!isa<spirv::ScalarType>(tensorType.getElementType()))
+ return rewriter.notifyMatchFailure(extractOp, "unsupported type");
if (!tensorType.hasStaticShape())
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 32d0fbea65b164..b77d75d27315ee 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -29,6 +29,24 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// -----
+// CHECK-LABEL: test_spirv_unsupported_type_index
+func.func @test_spirv_unsupported_type_index(%idx : index) {  // index-typed elements: extract is expected to stay unconverted
+ %vals = arith.constant dense<[1, 2]> : tensor<2xindex>
+ // CHECK: tensor.extract
+ %elem = tensor.extract %vals[%idx] : tensor<2xindex>
+ return
+}
+
+// CHECK-LABEL: test_spirv_unsupported_type_i128
+func.func @test_spirv_unsupported_type_i128(%a : index) {  // indices are always `index`; only the element type is i128
+ %cst = arith.constant dense<[1, 2]> : tensor<2xi128>
+ // CHECK: tensor.extract
+ %extract = tensor.extract %cst[%a] : tensor<2xi128>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list