[Mlir-commits] [mlir] [mlir][tensor] Check the EmptyOp's dynamicSize to be non-negative (PR #65577)
Kohei Yamaguchi
llvmlistbot at llvm.org
Thu Sep 7 00:48:06 PDT 2023
https://github.com/sott0n created https://github.com/llvm/llvm-project/pull/65577:
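This patch fixes a crash that occurs when a negative dynamic size is passed to `tensor.empty`. It does so in two ways:
- The op verifier now rejects dynamic sizes that are negative integer constants.
- The `ReplaceEmptyTensorStaticShapeDims` canonicalization no longer folds a negative constant size into the static result shape.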
Fixes #64064
>From 5c81c38650499b749223c4e54d32aac87378f48b Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Thu, 7 Sep 2023 16:18:23 +0000
Subject: [PATCH] [mlir][tensor] Check the EmptyOp's dynamicSize to be
non-negative
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 15 +++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 13 ++++++++++++-
mlir/test/Dialect/Tensor/invalid.mlir | 9 +++++++++
3 files changed, 36 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 42d89cd5a76208..7c3b641cea757c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -621,6 +621,18 @@ LogicalResult EmptyOp::verify() {
return emitOpError("incorrect number of dynamic sizes, has ")
<< getDynamicSizes().size() << ", expected "
<< getType().getNumDynamicDims();
+
+ if (getDynamicSizes().size() > 0) {
+ if (llvm::any_of(getDynamicSizes(), [](Value operand) {
+ APInt constSizeArg;
+ if (!matchPattern(operand, m_ConstantInt(&constSizeArg))) {
+ return false;
+ }
+ return constSizeArg.isNegative();
+ }))
+ return emitOpError("dynamic size must be non-negative");
+ }
+
return success();
}
@@ -691,6 +703,9 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
Value dynamicSize = op.getDynamicSizes()[ctr++];
std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
if (cst.has_value()) {
+ // dynamic size must be non-negative.
+ if (cst.value() < 0)
+ return failure();
staticShape[i] = *cst;
changedType = true;
} else {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4a757500920d50..70c274b6e0b63c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" -verify-diagnostics | FileCheck %s
// CHECK-LABEL: @tensor_bitcast_chain_ok
// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
@@ -1848,3 +1848,14 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
%packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
return %packed : tensor<?x?x?x?xf32>
}
+
+// -----
+
+func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = index.sub %c1, %c2 // Not a constant until canonicalization folds it to -1, so the error is expected on the folded IR.
+  // expected-error @+1 {{dynamic size must be non-negative}}
+  %1 = tensor.empty(%0) : tensor<4x5x?xf32>
+  return %1 : tensor<4x5x?xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 389e7e675c0eed..341be52b8e2b72 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -553,6 +553,15 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
// -----
+func.func @empty_negative_size(%sz : index) {
+  %0 = arith.constant -1 : index // A constant negative size is rejected directly by the tensor.empty verifier.
+  // expected-error @+1 {{dynamic size must be non-negative}}
+  %out = tensor.empty(%sz, %0) : tensor<2x?x?x5xf32>
+  return
+}
+
+// -----
+
func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
// expected-error at +1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
%0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32>
More information about the Mlir-commits
mailing list