[Mlir-commits] [mlir] [mlir][tensor] `tensor.generate`: do not verify dynamic sizes (PR #74568)

Matthias Springer llvmlistbot at llvm.org
Wed Dec 6 15:35:43 PST 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/74568

>From 504272ebdf0e3751a6e2a83a30d7c3e0e362eb52 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 6 Dec 2023 16:15:29 +0900
Subject: [PATCH] [mlir][tensor] `tensor.generate`: do not verify dynamic sizes

Op verifiers should verify only local properties of an op. The dynamic sizes of a `tensor.generate` op should not be verified. Dynamic sizes that have a negative constant value should not prevent the `tensor.generate` op from verifying.

Also share some code between the `tensor.empty` and `tensor.generate` "dynamic dim -> static dim" canonicalization patterns.

Remove the `invalid-canonicalize.mlir` file and move its test case to `canonicalize.mlir`. Canonicalization no longer produces IR that fails verification: a `tensor.generate` op whose dynamic size is a negative constant is left as is.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 133 +++++++-----------
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  17 +++
 .../Dialect/Tensor/invalid-canonicalize.mlir  |  15 --
 mlir/test/Dialect/Tensor/invalid.mlir         |  13 --
 4 files changed, 70 insertions(+), 108 deletions(-)
 delete mode 100644 mlir/test/Dialect/Tensor/invalid-canonicalize.mlir

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 55f813df78b85..3bb95a0141e04 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -168,6 +168,40 @@ static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
   return droppedDims;
 }
 
+/// Fold each dynamic dimension of `type` whose size in `dynamicSizes` is a
+/// non-negative constant into a static dimension and return the new type. The
+/// sizes of the remaining dynamic dims are appended to `foldedDynamicSizes`.
+static RankedTensorType
+foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
+                            SmallVector<Value> &foldedDynamicSizes) {
+  SmallVector<int64_t> staticShape(type.getShape().begin(),
+                                   type.getShape().end());
+  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
+         "incorrect number of dynamic sizes");
+
+  // Walk all dims; `ctr` indexes the next unconsumed entry of `dynamicSizes`.
+  unsigned ctr = 0;
+  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
+    if (type.isDynamicDim(i)) {
+      Value dynamicSize = dynamicSizes[ctr++];
+      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
+      if (cst.has_value()) {
+        // A negative constant is not a valid static size: keep it dynamic.
+        if (cst.value() < 0) {
+          foldedDynamicSizes.push_back(dynamicSize);
+          continue;
+        }
+        staticShape[i] = *cst;
+      } else {
+        foldedDynamicSizes.push_back(dynamicSize);
+      }
+    }
+  }
+
+  return RankedTensorType::get(staticShape, type.getElementType(),
+                               type.getEncoding());
+}
+
 //===----------------------------------------------------------------------===//
 // BitcastOp
 //===----------------------------------------------------------------------===//
@@ -889,37 +923,16 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
 
   LogicalResult matchAndRewrite(EmptyOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> staticShape(op.getType().getShape().begin(),
-                                     op.getType().getShape().end());
-    SmallVector<Value> dynamicSizes;
-
-    // Compute new static and dynamic sizes.
-    unsigned ctr = 0;
-    bool changedType = false;
-    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
-      if (op.getType().isDynamicDim(i)) {
-        Value dynamicSize = op.getDynamicSizes()[ctr++];
-        std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
-        if (cst.has_value()) {
-          // dynamic size must be non-negative.
-          if (cst.value() < 0)
-            return failure();
-          staticShape[i] = *cst;
-          changedType = true;
-        } else {
-          dynamicSizes.push_back(dynamicSize);
-        }
-      }
-    }
+    SmallVector<Value> foldedDynamicSizes;
+    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
+        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
 
     // Stop here if no dynamic size was promoted to static.
-    if (!changedType)
+    if (foldedTensorType == op.getType())
       return failure();
 
-    auto tensorType = RankedTensorType::get(
-        staticShape, op.getType().getElementType(), op.getType().getEncoding());
-    auto newOp =
-        rewriter.create<EmptyOp>(op.getLoc(), tensorType, dynamicSizes);
+    auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
+                                          foldedDynamicSizes);
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
     return success();
   }
@@ -1347,28 +1360,6 @@ LogicalResult GenerateOp::reifyResultShapes(
   return success();
 }
 
-/// Extract operands and shape from a tensor with dynamic extents.
-static void operandsAndShape(TensorType resultType,
-                             Operation::operand_range dynamicExtents,
-                             SmallVectorImpl<Value> &newOperands,
-                             SmallVectorImpl<int64_t> &newShape) {
-  auto operandsIt = dynamicExtents.begin();
-  for (int64_t dim : resultType.getShape()) {
-    if (!ShapedType::isDynamic(dim)) {
-      newShape.push_back(dim);
-      continue;
-    }
-    APInt index;
-    if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
-      newShape.push_back(ShapedType::kDynamic);
-      newOperands.push_back(*operandsIt++);
-      continue;
-    }
-    newShape.push_back(index.getSExtValue());
-    operandsIt++;
-  }
-}
-
 LogicalResult GenerateOp::verify() {
   // Ensure that the tensor type has as many dynamic dimensions as are
   // specified by the operands.
@@ -1376,14 +1367,6 @@ LogicalResult GenerateOp::verify() {
   if (getNumOperands() != resultType.getNumDynamicDims())
     return emitError("must have as many index operands as dynamic extents "
                      "in the result type");
-  // Ensure operands are non-negative.
-  SmallVector<Value> newOperands;
-  SmallVector<int64_t> newShape;
-  operandsAndShape(resultType, getDynamicExtents(), newOperands, newShape);
-  for (int64_t newdim : newShape) {
-    if (newdim < 0 && !ShapedType::isDynamic(newdim))
-      return emitError("tensor dimensions must be non-negative");
-  }
   return success();
 }
 
@@ -1433,34 +1416,24 @@ namespace {
 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
   using OpRewritePattern<GenerateOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                 PatternRewriter &rewriter) const final {
-    auto resultType =
-        llvm::cast<RankedTensorType>(tensorFromElements.getResult().getType());
+    SmallVector<Value> foldedDynamicSizes;
+    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
+        generateOp.getType(), generateOp.getDynamicExtents(),
+        foldedDynamicSizes);
 
-    if (resultType.hasStaticShape())
-      return failure();
-
-    Operation::operand_range dynamicExtents =
-        tensorFromElements.getDynamicExtents();
-    SmallVector<Value> newOperands;
-    SmallVector<int64_t> newShape;
-    operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
-
-    if (!hasValidSizesOffsets(newShape))
-      return failure();
-
-    if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
+    // Stop here if no dynamic size was promoted to static.
+    if (foldedTensorType == generateOp.getType())
       return failure();
 
-    auto loc = tensorFromElements.getLoc();
-    auto newOp = rewriter.create<GenerateOp>(
-        loc, RankedTensorType::get(newShape, resultType.getElementType()),
-        newOperands);
-    rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
+    auto loc = generateOp.getLoc();
+    auto newOp =
+        rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
+    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                 newOp.getBody().begin());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
-                                                newOp);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
+                                                generateOp.getType(), newOp);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 84c44a09aa3dd..7d7d221c1e8e9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2018,3 +2018,20 @@ func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?x
   %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @generate_negative_size_verifies(
+//       CHECK:   %[[c:.*]] = arith.constant -8 : index
+//       CHECK:   tensor.generate %[[c]]
+//       CHECK:   : tensor<?x8xi32>
+func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
+  %cst = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
+  %tensor = tensor.generate %size {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : i32
+  } : tensor<?x8xi32>
+  return %tensor : tensor<?x8xi32>
+}
diff --git a/mlir/test/Dialect/Tensor/invalid-canonicalize.mlir b/mlir/test/Dialect/Tensor/invalid-canonicalize.mlir
deleted file mode 100644
index decfd55eacc95..0000000000000
--- a/mlir/test/Dialect/Tensor/invalid-canonicalize.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-opt <%s -split-input-file -verify-diagnostics -canonicalize
-
-// -----
-
-func.func @indirectly_generate_negative_size() -> tensor<?x8xi32> {
-  %cst = arith.constant 0 : i32
-  %c0 = arith.constant 0 : index
-  %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
-  // expected-error at +1 {{tensor dimensions must be non-negative}}
-  %tensor = tensor.generate %size {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst : i32
-  } : tensor<?x8xi32>
-  return %tensor : tensor<?x8xi32>
-}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 9b6c2327879cf..bdada43e325c5 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -163,19 +163,6 @@ func.func @tensor.generate(%m : index, %n : index)
 
 // -----
 
-func.func @generate_negative_size() -> tensor<?x8xi32> {
-  %cst = arith.constant 0 : i32
-  %size = index.constant -128
-  // expected-error at +1 {{tensor dimensions must be non-negative}}
-  %tensor = tensor.generate %size {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst : i32
-  } : tensor<?x8xi32>
-  return %tensor : tensor<?x8xi32>
-}
-
-// -----
-
 func.func @tensor.reshape_element_type_mismatch(
        %buf: tensor<*xf32>, %shape: tensor<1xi32>) {
   // expected-error @+1 {{element types of source and destination tensor types should be the same}}



More information about the Mlir-commits mailing list