[Mlir-commits] [mlir] [MLIR][Tensor] Fix out-of-bounds FoldEmptyTensorWithDimOp crash #111270 (PR #112196)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 17 23:06:40 PDT 2024


https://github.com/brod4910 updated https://github.com/llvm/llvm-project/pull/112196

>From af420ff571de25d7031ec464c5f6f92da0174d16 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 14 Oct 2024 07:08:51 -0600
Subject: [PATCH 1/2] Fix out-of-bounds FoldEmptyTensorWithDimOp crash #111270

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc3..b02cd3f3fa1973 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -979,7 +979,10 @@ struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
     auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
     if (!emptyTensorOp || !maybeConstantIndex)
       return failure();
-    if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
+    auto emptyTensorType = emptyTensorOp.getType();
+    if (*maybeConstantIndex < 0 ||
+        *maybeConstantIndex >= emptyTensorType.getRank() ||
+        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
       return failure();
     rewriter.replaceOp(dimOp,
                        emptyTensorOp.getDynamicSize(*maybeConstantIndex));

>From dac04c5dba74041067569107fa327bc4f3d0b3e1 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Fri, 18 Oct 2024 00:06:28 -0600
Subject: [PATCH 2/2] Add tests for folding and not folding tensor.dim ops

---
 .../Dialect/Linalg/drop-unit-extent-dims.mlir | 39 +++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 9a00b19aae400f..4f2d272cf725cb 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1130,3 +1130,42 @@ module {
     return %1 : tensor<?x1x61x1xf32>
   }
 }
+
+// -----
+
+func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %cst7 = arith.constant 7 : index
+  %dim = tensor.dim %arg0, %cst7 : tensor<1x?x10xf32>
+  %0 = tensor.empty(%dim) : tensor<1x?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+// CHECK-LABEL: func.func @no_fold_empty_tensor_dim_out_of_bounds
+//  CHECK-SAME:                 %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+//       CHECK:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//       CHECK:   %[[C7:.*]] = arith.constant 7
+//       CHECK:   %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C7]] : tensor<1x?x10xf32>
+//       CHECK:   %[[VAL_0:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?xf32>
+//       CHECK:   %[[VAL_1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_0]] : tensor<1x?xf32>) -> tensor<1x?xf32>
+//       CHECK:   return %[[VAL_1]] : tensor<1x?xf32>
+//       CHECK: }
+
+// -----
+
+func.func @fold_empty_tensor_dim_op(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %cst2 = arith.constant 2 : index
+  %dim10 = tensor.dim %arg0, %cst2 : tensor<1x?x10xf32>
+  %0 = tensor.empty(%dim10) : tensor<1x?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+// CHECK-LABEL: func.func @fold_empty_tensor_dim_op
+//  CHECK-SAME:                 %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+//       CHECK:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//       CHECK:   %[[VAL_0:.*]] = tensor.empty() : tensor<1x10xf32>
+//       CHECK:   %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1x10xf32> to tensor<1x?xf32>
+//       CHECK:   %[[VAL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_1]] : tensor<1x?xf32>) -> tensor<1x?xf32>
+//       CHECK:   return %[[VAL_2]] : tensor<1x?xf32>
+//       CHECK: }



More information about the Mlir-commits mailing list