[Mlir-commits] [mlir] Improve `tensor.pack` verifier to catch more UB cases (PR #77217)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 6 17:58:07 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: None (srcarroll)

<details>
<summary>Changes</summary>

Previously, the `tensor.pack` verifier detected undefined behavior (UB) only when tile sizes were static. Now dynamic tiles are also considered: to determine UB, we only require that the input size and either the corresponding tile size or the corresponding output size are static.

---
Full diff: https://github.com/llvm/llvm-project/pull/77217.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+4-3) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+4-2) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+23-7) 
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+9-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index eb0c79c01bee1a..1c61ece2676a90 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1943,11 +1943,12 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
 
     // Returns true if we have enough static information to catch undefined
     // behavior when the tile size does not divide perfectly the dimension of
-    // the input tensor. If a given dimension or a tile associated with it is
-    // dynamic, the dimension is not considered as we don't have enough static
-    // information to understand if the tile perfectly divides that dimension.
+    // the input tensor. Detecting UB requires that the input size and either
+    // corresponding tile or output size are static.
     static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
                                     ArrayRef<int64_t> innerDimsPos,
+                                    ArrayRef<int64_t> outputShape,
+                                    ArrayRef<int64_t> outerDimsPerm,
                                     ArrayRef<OpFoldResult> innerTiles);
 
     static Value createDestinationTensor(OpBuilder &b, Location loc,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..c7fed41d234fd0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -582,8 +582,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
             return getConstantIntValue(tile).has_value();
           });
       if (areConstantTiles && operandType.hasStaticShape() &&
-          !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
-                                               innerPackSizes)) {
+          !tensor::PackOp::requirePaddingValue(
+              operandType.getShape(), innerPos,
+              dest.getType().cast<ShapedType>().getShape(), {},
+              innerPackSizes)) {
         packOps.push_back(rewriter.create<tensor::PackOp>(
             loc, operand, dest, innerPos, innerPackSizes));
       } else {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 816e6ba8fed94e..4318e55fd213e3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3742,14 +3742,27 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
 
 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
+                                 ArrayRef<int64_t> outputShape,
+                                 ArrayRef<int64_t> outerDimsPerm,
                                  ArrayRef<OpFoldResult> innerTiles) {
+  SmallVector<int64_t> outputTileSizes(
+      outputShape.take_front(inputShape.size()));
+  if (!outerDimsPerm.empty()) {
+    assert(outerDimsPerm.size() == outputTileSizes.size() &&
+           "expected output and outer_dims_perm to have same size");
+    applyPermutationToVector(outputTileSizes,
+                             invertPermutationVector(outerDimsPerm));
+  }
   for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
     if (ShapedType::isDynamic(inputShape[pos]))
       continue;
     std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
-    if (!constantTile)
-      continue;
-    if (inputShape[pos] % (*constantTile) != 0)
+
+    if (!constantTile) {
+      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
+          (inputShape[pos] % outputTileSizes[pos] != 0))
+        return true;
+    } else if (inputShape[pos] % (*constantTile) != 0)
       return true;
   }
   return false;
@@ -3772,9 +3785,11 @@ LogicalResult PackOp::verify() {
 
   if (!paddingValue &&
       requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
+                          getDestType().getShape(), getOuterDimsPerm(),
                           getMixedTiles())) {
-    return emitOpError("invalid tile factor provided. Only full tiles are "
-                       "supported when padding_value is not set");
+    return emitOpError(
+        "invalid tile factor or output size provided. Only full tiles are "
+        "supported when padding_value is not set");
   }
   return success();
 }
@@ -3975,8 +3990,9 @@ static bool paddingIsNotNeeded(PackOp op) {
     return false;
   if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
     return false;
-  return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
-                                      op.getMixedTiles());
+  return !PackOp::requirePaddingValue(
+      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
+      op.getOuterDimsPerm(), op.getMixedTiles());
 }
 
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 735e5146e9dbc8..0eb8672d4d41b3 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -597,13 +597,21 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
 // -----
 
 func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
-  // expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
   %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32>  -> tensor<8x8x16x33xf32>
   return %0 : tensor<8x8x16x33xf32>
 }
 
 // -----
 
+func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32>  -> tensor<10x8x?x?xf32>
+  return %0 : tensor<10x8x?x?xf32>
+} 
+
+// -----
+
 func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
   // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
   %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/77217


More information about the Mlir-commits mailing list