[Mlir-commits] [mlir] dc55d31 - [mlir][tensor] Fix a crash in `ExtractOp::fold` (#115001)

llvmlistbot at llvm.org
Wed Nov 6 00:09:46 PST 2024


Author: Longsheng Mou
Date: 2024-11-06T16:09:42+08:00
New Revision: dc55d31f4cf5c97b56f6b7e1c24b70674cc15a01

URL: https://github.com/llvm/llvm-project/commit/dc55d31f4cf5c97b56f6b7e1c24b70674cc15a01
DIFF: https://github.com/llvm/llvm-project/commit/dc55d31f4cf5c97b56f6b7e1c24b70674cc15a01.diff

LOG: [mlir][tensor] Fix a crash in `ExtractOp::fold` (#115001)

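This PR fixes a crash in `ExtractOp::fold` when the tensor operand of
`tensor.extract` is a constant dense resource elements attribute
(`DenseResourceElementsAttr`). The folder now bails out for such
attributes instead of attempting to fold.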
Fixes #114728.

Co-authored-by: jinzhi <jinzhi6 at huawei.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4249a88627588a..20480c6437c424 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1128,12 +1128,17 @@ LogicalResult ExtractOp::verify() {
 }
 
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
-  // If this is a splat elements attribute, simply return the value. All of
-  // the elements of a splat attribute are the same.
-  if (Attribute tensor = adaptor.getTensor())
+  if (Attribute tensor = adaptor.getTensor()) {
+    // If this is a splat elements attribute, simply return the value.
+    // All of the elements of a splat attribute are the same.
     if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
       return splatTensor.getSplatValue<Attribute>();
 
+    // If this is a dense resource elements attribute, return.
+    if (isa<DenseResourceElementsAttr>(tensor))
+      return {};
+  }
+
   // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
   for (Attribute indice : adaptor.getIndices()) {

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 2186aab9a527cc..2c826d7ae008d5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -173,6 +173,20 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
 
 // -----
 
+// Ensure extract dense resource elements not crash.
+
+// CHECK-LABEL: func @extract_dense_resource_nofold
+func.func @extract_dense_resource_nofold() -> i64 {
+  // CHECK:      %[[EXT:.+]] = tensor.extract
+  // CHECK-NEXT:   return %[[EXT]]
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
+  %extracted = tensor.extract %cst[%c0] : tensor<1xi64>
+  return %extracted : i64
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_insert
 func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   // Fold an insert into a splat.

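To make the order of the checks in the patched folder concrete, here is a minimal, self-contained C++ model of its decision logic. It is a sketch under stated assumptions: `SplatAttr`, `DenseAttr`, `DenseResourceAttr`, `TensorAttr`, and `foldExtract` are hypothetical stand-ins, not MLIR classes or APIs. The model also ignores the other folding paths of the real `ExtractOp::fold` (for example, folding through `tensor.from_elements`). It only shows that a splat constant still folds, while a resource-backed constant now makes the folder report "no fold" before any element lookup is attempted.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical stand-ins for the attribute kinds the folder distinguishes.
// These are NOT MLIR types; they only model the control flow of the fix.
struct SplatAttr { int64_t value; };                // every element is `value`
struct DenseAttr { std::vector<int64_t> values; };  // 1-D elements, for simplicity
struct DenseResourceAttr {};  // data lives in an external (possibly elided) blob

// std::monostate models "operand is not a constant".
using TensorAttr =
    std::variant<std::monostate, SplatAttr, DenseAttr, DenseResourceAttr>;

// Returns the folded element, or std::nullopt when folding is skipped
// (std::nullopt plays the role of an empty OpFoldResult).
std::optional<int64_t> foldExtract(const TensorAttr &tensor,
                                   std::optional<uint64_t> index) {
  // 1. A splat holds one value for all elements, so the index is irrelevant.
  if (const auto *splat = std::get_if<SplatAttr>(&tensor))
    return splat->value;

  // 2. The guard added by the patch: do not try to read elements out of a
  //    resource-backed attribute; leave the extract unfolded instead.
  if (std::holds_alternative<DenseResourceAttr>(tensor))
    return std::nullopt;

  // 3. The remaining lookup needs a constant index.
  if (!index)
    return std::nullopt;

  // 4. Ordinary dense constant: return the element if the index is in bounds.
  if (const auto *dense = std::get_if<DenseAttr>(&tensor))
    if (*index < dense->values.size())
      return dense->values[*index];

  return std::nullopt;
}
```

Returning `std::nullopt` for the resource case mirrors the `return {};` added in the patch: the `tensor.extract` is left in place, which is what the `@extract_dense_resource_nofold` test checks.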