[Mlir-commits] [mlir] [mlir][tensor] `tensor.generate`: do not verify dynamic sizes (PR #74568)
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 6 15:35:43 PST 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/74568
From 504272ebdf0e3751a6e2a83a30d7c3e0e362eb52 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 6 Dec 2023 16:15:29 +0900
Subject: [PATCH] [mlir][tensor] `tensor.generate`: do not verify dynamic sizes
Op verifiers should verify only local properties of an op. The values of the dynamic sizes of a `tensor.generate` op are not a local property and should not be verified: dynamic sizes that have a negative constant value no longer prevent the `tensor.generate` op from verifying.
Also share some code between the `tensor.empty` and `tensor.generate` "dynamic dim -> static dim" canonicalization patterns.
Also remove the `invalid-canonicalize.mlir` file and move its test case to `canonicalize.mlir`: canonicalization no longer produces IR that fails to verify; ops with negative constant sizes are simply left as is.
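For illustration, a minimal sketch (not part of this patch; the function name and the constant -8 are made up for the example): with this change, an op like the following verifies even though its dynamic extent is a negative constant, and canonicalization leaves it unchanged instead of folding the size into the result type.

  func.func @negative_size_now_verifies() -> tensor<?x8xi32> {
    %cst = arith.constant 0 : i32
    %size = arith.constant -8 : index
    // The dynamic extent is a negative constant. The verifier no longer
    // checks this non-local property, and the canonicalizer keeps the
    // size dynamic rather than promoting it to a static dimension.
    %t = tensor.generate %size {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : i32
    } : tensor<?x8xi32>
    return %t : tensor<?x8xi32>
  }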
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 133 +++++++-----------
mlir/test/Dialect/Tensor/canonicalize.mlir | 17 +++
.../Dialect/Tensor/invalid-canonicalize.mlir | 15 --
mlir/test/Dialect/Tensor/invalid.mlir | 13 --
4 files changed, 70 insertions(+), 108 deletions(-)
delete mode 100644 mlir/test/Dialect/Tensor/invalid-canonicalize.mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 55f813df78b85..3bb95a0141e04 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -168,6 +168,40 @@ static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
return droppedDims;
}
+/// Given a ranked tensor type and a range of values that defines its dynamic
+/// dimension sizes, turn all dynamic sizes that have a constant value into
+/// static dimension sizes.
+static RankedTensorType
+foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
+ SmallVector<Value> &foldedDynamicSizes) {
+ SmallVector<int64_t> staticShape(type.getShape().begin(),
+ type.getShape().end());
+ assert(type.getNumDynamicDims() == dynamicSizes.size() &&
+ "incorrect number of dynamic sizes");
+
+ // Compute new static and dynamic sizes.
+ unsigned ctr = 0;
+ for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
+ if (type.isDynamicDim(i)) {
+ Value dynamicSize = dynamicSizes[ctr++];
+ std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
+ if (cst.has_value()) {
+ // Dynamic size must be non-negative.
+ if (cst.value() < 0) {
+ foldedDynamicSizes.push_back(dynamicSize);
+ continue;
+ }
+ staticShape[i] = *cst;
+ } else {
+ foldedDynamicSizes.push_back(dynamicSize);
+ }
+ }
+ }
+
+ return RankedTensorType::get(staticShape, type.getElementType(),
+ type.getEncoding());
+}
+
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
@@ -889,37 +923,16 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
LogicalResult matchAndRewrite(EmptyOp op,
PatternRewriter &rewriter) const override {
- SmallVector<int64_t> staticShape(op.getType().getShape().begin(),
- op.getType().getShape().end());
- SmallVector<Value> dynamicSizes;
-
- // Compute new static and dynamic sizes.
- unsigned ctr = 0;
- bool changedType = false;
- for (int64_t i = 0; i < op.getType().getRank(); ++i) {
- if (op.getType().isDynamicDim(i)) {
- Value dynamicSize = op.getDynamicSizes()[ctr++];
- std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
- if (cst.has_value()) {
- // dynamic size must be non-negative.
- if (cst.value() < 0)
- return failure();
- staticShape[i] = *cst;
- changedType = true;
- } else {
- dynamicSizes.push_back(dynamicSize);
- }
- }
- }
+ SmallVector<Value> foldedDynamicSizes;
+ RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
+ op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
// Stop here if no dynamic size was promoted to static.
- if (!changedType)
+ if (foldedTensorType == op.getType())
return failure();
- auto tensorType = RankedTensorType::get(
- staticShape, op.getType().getElementType(), op.getType().getEncoding());
- auto newOp =
- rewriter.create<EmptyOp>(op.getLoc(), tensorType, dynamicSizes);
+ auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
+ foldedDynamicSizes);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
@@ -1347,28 +1360,6 @@ LogicalResult GenerateOp::reifyResultShapes(
return success();
}
-/// Extract operands and shape from a tensor with dynamic extents.
-static void operandsAndShape(TensorType resultType,
- Operation::operand_range dynamicExtents,
- SmallVectorImpl<Value> &newOperands,
- SmallVectorImpl<int64_t> &newShape) {
- auto operandsIt = dynamicExtents.begin();
- for (int64_t dim : resultType.getShape()) {
- if (!ShapedType::isDynamic(dim)) {
- newShape.push_back(dim);
- continue;
- }
- APInt index;
- if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
- newShape.push_back(ShapedType::kDynamic);
- newOperands.push_back(*operandsIt++);
- continue;
- }
- newShape.push_back(index.getSExtValue());
- operandsIt++;
- }
-}
-
LogicalResult GenerateOp::verify() {
// Ensure that the tensor type has as many dynamic dimensions as are
// specified by the operands.
@@ -1376,14 +1367,6 @@ LogicalResult GenerateOp::verify() {
if (getNumOperands() != resultType.getNumDynamicDims())
return emitError("must have as many index operands as dynamic extents "
"in the result type");
- // Ensure operands are non-negative.
- SmallVector<Value> newOperands;
- SmallVector<int64_t> newShape;
- operandsAndShape(resultType, getDynamicExtents(), newOperands, newShape);
- for (int64_t newdim : newShape) {
- if (newdim < 0 && !ShapedType::isDynamic(newdim))
- return emitError("tensor dimensions must be non-negative");
- }
return success();
}
@@ -1433,34 +1416,24 @@ namespace {
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
using OpRewritePattern<GenerateOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
+ LogicalResult matchAndRewrite(GenerateOp generateOp,
PatternRewriter &rewriter) const final {
- auto resultType =
- llvm::cast<RankedTensorType>(tensorFromElements.getResult().getType());
+ SmallVector<Value> foldedDynamicSizes;
+ RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
+ generateOp.getType(), generateOp.getDynamicExtents(),
+ foldedDynamicSizes);
- if (resultType.hasStaticShape())
- return failure();
-
- Operation::operand_range dynamicExtents =
- tensorFromElements.getDynamicExtents();
- SmallVector<Value> newOperands;
- SmallVector<int64_t> newShape;
- operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
-
- if (!hasValidSizesOffsets(newShape))
- return failure();
-
- if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
+ // Stop here if no dynamic size was promoted to static.
+ if (foldedTensorType == generateOp.getType())
return failure();
- auto loc = tensorFromElements.getLoc();
- auto newOp = rewriter.create<GenerateOp>(
- loc, RankedTensorType::get(newShape, resultType.getElementType()),
- newOperands);
- rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
+ auto loc = generateOp.getLoc();
+ auto newOp =
+ rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
+ rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
newOp.getBody().begin());
- rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
- newOp);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
+ generateOp.getType(), newOp);
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 84c44a09aa3dd..7d7d221c1e8e9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2018,3 +2018,20 @@ func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?x
%1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
return %1 : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @generate_negative_size_verifies(
+// CHECK: %[[c:.*]] = arith.constant -8 : index
+// CHECK: tensor.generate %[[c]]
+// CHECK: : tensor<?x8xi32>
+func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
+ %cst = arith.constant 0 : i32
+ %c0 = arith.constant 0 : index
+ %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
+ %tensor = tensor.generate %size {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : i32
+ } : tensor<?x8xi32>
+ return %tensor : tensor<?x8xi32>
+}
diff --git a/mlir/test/Dialect/Tensor/invalid-canonicalize.mlir b/mlir/test/Dialect/Tensor/invalid-canonicalize.mlir
deleted file mode 100644
index decfd55eacc95..0000000000000
--- a/mlir/test/Dialect/Tensor/invalid-canonicalize.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-opt <%s -split-input-file -verify-diagnostics -canonicalize
-
-// -----
-
-func.func @indirectly_generate_negative_size() -> tensor<?x8xi32> {
- %cst = arith.constant 0 : i32
- %c0 = arith.constant 0 : index
- %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
- // expected-error@+1 {{tensor dimensions must be non-negative}}
- %tensor = tensor.generate %size {
- ^bb0(%arg0: index, %arg1: index):
- tensor.yield %cst : i32
- } : tensor<?x8xi32>
- return %tensor : tensor<?x8xi32>
-}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 9b6c2327879cf..bdada43e325c5 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -163,19 +163,6 @@ func.func @tensor.generate(%m : index, %n : index)
// -----
-func.func @generate_negative_size() -> tensor<?x8xi32> {
- %cst = arith.constant 0 : i32
- %size = index.constant -128
- // expected-error@+1 {{tensor dimensions must be non-negative}}
- %tensor = tensor.generate %size {
- ^bb0(%arg0: index, %arg1: index):
- tensor.yield %cst : i32
- } : tensor<?x8xi32>
- return %tensor : tensor<?x8xi32>
-}
-
-// -----
-
func.func @tensor.reshape_element_type_mismatch(
%buf: tensor<*xf32>, %shape: tensor<1xi32>) {
// expected-error @+1 {{element types of source and destination tensor types should be the same}}