[Mlir-commits] [mlir] 40556d0 - [MLIR][Tensor] Fix out-of-bounds FoldEmptyTensorWithDimOp crash (#112196)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 6 02:56:55 PST 2024
Author: brod4910
Date: 2024-11-06T11:56:51+01:00
New Revision: 40556d08491f530e03746fb188b38e7f9cb272c7
URL: https://github.com/llvm/llvm-project/commit/40556d08491f530e03746fb188b38e7f9cb272c7
DIFF: https://github.com/llvm/llvm-project/commit/40556d08491f530e03746fb188b38e7f9cb272c7.diff
LOG: [MLIR][Tensor] Fix out-of-bounds FoldEmptyTensorWithDimOp crash (#112196)
Fixes #111270
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 20480c6437c424..8e0d0104397468 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -980,7 +980,10 @@ struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
if (!emptyTensorOp || !maybeConstantIndex)
return failure();
- if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
+ auto emptyTensorType = emptyTensorOp.getType();
+ if (*maybeConstantIndex < 0 ||
+ *maybeConstantIndex >= emptyTensorType.getRank() ||
+ !emptyTensorType.isDynamicDim(*maybeConstantIndex))
return failure();
rewriter.replaceOp(dimOp,
emptyTensorOp.getDynamicSize(*maybeConstantIndex));
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 9a00b19aae400f..3256daa8e0b591 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1130,3 +1130,42 @@ module {
return %1 : tensor<?x1x61x1xf32>
}
}
+
+// -----
+
+func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %cst7 = arith.constant 7 : index
+ %dim = tensor.dim %arg0, %cst7 : tensor<1x?x10xf32>
+ %0 = tensor.empty(%dim) : tensor<1x?xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
+}
+// CHECK-LABEL: func.func @no_fold_empty_tensor_dim_out_of_bounds
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[C7:.*]] = arith.constant 7
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C7]] : tensor<1x?x10xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?xf32>
+// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_0]] : tensor<1x?xf32>) -> tensor<1x?xf32>
+// CHECK: return %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: }
+
+// -----
+
+func.func @fold_empty_tensor_dim_op(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %cst2 = index.constant 2
+ %dim10 = tensor.dim %arg0, %cst2 : tensor<1x?x10xf32>
+ %0 = tensor.empty(%dim10) : tensor<1x?xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
+}
+// CHECK-LABEL: func.func @fold_empty_tensor_dim_op
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x10xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1x10xf32> to tensor<1x?xf32>
+// CHECK: %[[VAL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_1]] : tensor<1x?xf32>) -> tensor<1x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x?xf32>
+// CHECK: }
More information about the Mlir-commits mailing list