[Mlir-commits] [mlir] 64bd5bb - [mlir] Avoid tensor canonicalizer crash on negative dimensions
Alex Zinenko
llvmlistbot at llvm.org
Wed May 31 05:14:37 PDT 2023
Author: rikhuijzer
Date: 2023-05-31T14:14:30+02:00
New Revision: 64bd5bbb9bbb72de5f59755c74dae4b4881d93d5
URL: https://github.com/llvm/llvm-project/commit/64bd5bbb9bbb72de5f59755c74dae4b4881d93d5
DIFF: https://github.com/llvm/llvm-project/commit/64bd5bbb9bbb72de5f59755c74dae4b4881d93d5.diff
LOG: [mlir] Avoid tensor canonicalizer crash on negative dimensions
Fixes #59703.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D151611
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1adb9c7f262fe..283f1be6aa793 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1111,14 +1111,43 @@ LogicalResult GenerateOp::reifyResultShapes(
return success();
}
+/// Extract operands and shape from a tensor with dynamic extents.
+static void operandsAndShape(TensorType resultType,
+ Operation::operand_range dynamicExtents,
+ SmallVectorImpl<Value> &newOperands,
+ SmallVectorImpl<int64_t> &newShape) {
+ auto operandsIt = dynamicExtents.begin();
+ for (int64_t dim : resultType.getShape()) {
+ if (!ShapedType::isDynamic(dim)) {
+ newShape.push_back(dim);
+ continue;
+ }
+ APInt index;
+ if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
+ newShape.push_back(ShapedType::kDynamic);
+ newOperands.push_back(*operandsIt++);
+ continue;
+ }
+ newShape.push_back(index.getSExtValue());
+ operandsIt++;
+ }
+}
+
LogicalResult GenerateOp::verify() {
// Ensure that the tensor type has as many dynamic dimensions as are
// specified by the operands.
- RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
- if (getNumOperands() != resultTy.getNumDynamicDims())
+ RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
+ if (getNumOperands() != resultType.getNumDynamicDims())
return emitError("must have as many index operands as dynamic extents "
"in the result type");
-
+ // Ensure operands are non-negative.
+ SmallVector<Value> newOperands;
+ SmallVector<int64_t> newShape;
+ operandsAndShape(resultType, getDynamicExtents(), newOperands, newShape);
+ for (int64_t newdim : newShape) {
+ if (newdim < 0 && !ShapedType::isDynamic(newdim))
+ return emitError("tensor dimensions must be non-negative");
+ }
return success();
}
@@ -1176,24 +1205,11 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
if (resultType.hasStaticShape())
return failure();
- SmallVector<Value, 4> newOperands;
- SmallVector<int64_t, 4> newShape;
- auto operandsIt = tensorFromElements.getDynamicExtents().begin();
-
- for (int64_t dim : resultType.getShape()) {
- if (!ShapedType::isDynamic(dim)) {
- newShape.push_back(dim);
- continue;
- }
- APInt index;
- if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
- newShape.push_back(ShapedType::kDynamic);
- newOperands.push_back(*operandsIt++);
- continue;
- }
- newShape.push_back(index.getSExtValue());
- operandsIt++;
- }
+ Operation::operand_range dynamicExtents =
+ tensorFromElements.getDynamicExtents();
+ SmallVector<Value> newOperands;
+ SmallVector<int64_t> newShape;
+ operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
return failure();
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 61f03f19de33b..389e7e675c0ee 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -112,6 +112,20 @@ func.func @tensor.generate(%m : index, %n : index)
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
+
+// -----
+
+func.func @generate_negative_size() -> tensor<?x8xi32> {
+ %cst = arith.constant 0 : i32
+ %size = index.constant -128
+ // expected-error@+1 {{tensor dimensions must be non-negative}}
+ %tensor = tensor.generate %size {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : i32
+ } : tensor<?x8xi32>
+ return %tensor : tensor<?x8xi32>
+}
+
// -----
func.func @tensor.reshape_element_type_mismatch(
More information about the Mlir-commits mailing list