[Mlir-commits] [mlir] [mlir][tensor] Fix a crash in `ExtractOp::fold` (PR #115001)
Longsheng Mou
llvmlistbot at llvm.org
Tue Nov 5 06:38:39 PST 2024
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/115001
This PR fixes a crash in `ExtractOp::fold` when the tensor operand of `tensor.extract` is a dense resource elements attribute (`dense_resource<...>`).
Fixes #114728.
>From bb0824ee483a0d146f094256d434ec84efe44084 Mon Sep 17 00:00:00 2001
From: jinzhi <jinzhi6 at huawei.com>
Date: Tue, 5 Nov 2024 22:29:24 +0800
Subject: [PATCH] [mlir][tensor] Fix a crash in `ExtractOp::fold`
This PR fixes a crash in `ExtractOp::fold` when the tensor operand
of `tensor.extract` is a dense resource elements attribute.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 11 ++++++++---
mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++++++
2 files changed, 22 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c2d6bc610cd92a..6fdafe2b3d71af 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1128,12 +1128,17 @@ LogicalResult ExtractOp::verify() {
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
- // If this is a splat elements attribute, simply return the value. All of
- // the elements of a splat attribute are the same.
- if (Attribute tensor = adaptor.getTensor())
+ if (Attribute tensor = adaptor.getTensor()) {
+ // If this is a splat elements attribute, simply return the value.
+ // All of the elements of a splat attribute are the same.
if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
return splatTensor.getSplatValue<Attribute>();
+ // If this is a dense resource elements attribute, return.
+ if (isa<DenseResourceElementsAttr>(tensor))
+ return {};
+ }
+
// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : adaptor.getIndices()) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 236d2a3e60eb2c..c0cecd3583d217 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -173,6 +173,20 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
// -----
+// Ensure tensor.extract from a dense_resource constant is left unfolded and does not crash.
+
+// CHECK-LABEL: func @extract_dense_resource_nofold
+func.func @extract_dense_resource_nofold() -> i64 {
+ // CHECK: %[[EXT:.+]] = tensor.extract
+ // CHECK-NEXT: return %[[EXT]]
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
+ %extracted = tensor.extract %cst[%c0] : tensor<1xi64>
+ return %extracted : i64
+}
+
+// -----
+
// CHECK-LABEL: func @fold_insert
func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
// Fold an insert into a splat.
More information about the Mlir-commits
mailing list