[Mlir-commits] [mlir] 9466c4e - [MLIR][tensor] Improve `tensor.pack` verifier to catch more cases with unconditional runtime errors (#77217)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 19 10:27:28 PST 2024


Author: srcarroll
Date: 2024-02-19T12:27:24-06:00
New Revision: 9466c4e629ecff3060b7fef3cb189179e25c4f5f

URL: https://github.com/llvm/llvm-project/commit/9466c4e629ecff3060b7fef3cb189179e25c4f5f
DIFF: https://github.com/llvm/llvm-project/commit/9466c4e629ecff3060b7fef3cb189179e25c4f5f.diff

LOG: [MLIR][tensor] Improve `tensor.pack` verifier to catch more cases with unconditional runtime errors (#77217)

Previously, the `tensor.pack` verifier detected unconditional runtime
errors only when tile sizes were static. Now dynamic tiles are also
considered: the verifier only requires that the input size and either
the corresponding tile size or the corresponding output size be static
to determine whether the op will unconditionally produce an error at
runtime.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index eb0c79c01bee1a..1c61ece2676a90 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1943,11 +1943,12 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
 
     // Returns true if we have enough static information to catch undefined
     // behavior when the tile size does not divide perfectly the dimension of
-    // the input tensor. If a given dimension or a tile associated with it is
-    // dynamic, the dimension is not considered as we don't have enough static
-    // information to understand if the tile perfectly divides that dimension.
+    // the input tensor. Detecting UB requires that the input size and either
+    // corresponding tile or output size are static.
     static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
                                     ArrayRef<int64_t> innerDimsPos,
+                                    ArrayRef<int64_t> outputShape,
+                                    ArrayRef<int64_t> outerDimsPerm,
                                     ArrayRef<OpFoldResult> innerTiles);
 
     static Value createDestinationTensor(OpBuilder &b, Location loc,

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 596b7c50c1e4e4..01b393644679c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -563,8 +563,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
             return getConstantIntValue(tile).has_value();
           });
       if (areConstantTiles && operandType.hasStaticShape() &&
-          !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
-                                               innerPackSizes)) {
+          !tensor::PackOp::requirePaddingValue(
+              operandType.getShape(), innerPos,
+              dest.getType().cast<ShapedType>().getShape(), {},
+              innerPackSizes)) {
         packOps.push_back(rewriter.create<tensor::PackOp>(
             loc, operand, dest, innerPos, innerPackSizes));
       } else {

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 303b38c70d747f..62bf6b48f18b4a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3746,15 +3746,29 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
 
 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
+                                 ArrayRef<int64_t> outputShape,
+                                 ArrayRef<int64_t> outerDimsPerm,
                                  ArrayRef<OpFoldResult> innerTiles) {
+  SmallVector<int64_t> outputTileSizes(
+      outputShape.take_front(inputShape.size()));
+  if (!outerDimsPerm.empty()) {
+    assert(outerDimsPerm.size() == outputTileSizes.size() &&
+           "expected output and outer_dims_perm to have same size");
+    applyPermutationToVector(outputTileSizes,
+                             invertPermutationVector(outerDimsPerm));
+  }
   for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
     if (ShapedType::isDynamic(inputShape[pos]))
       continue;
     std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
-    if (!constantTile)
-      continue;
-    if (inputShape[pos] % (*constantTile) != 0)
+
+    if (!constantTile) {
+      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
+          (inputShape[pos] % outputTileSizes[pos] != 0))
+        return true;
+    } else if (inputShape[pos] % (*constantTile) != 0) {
       return true;
+    }
   }
   return false;
 }
@@ -3776,9 +3790,11 @@ LogicalResult PackOp::verify() {
 
   if (!paddingValue &&
       requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
+                          getDestType().getShape(), getOuterDimsPerm(),
                           getMixedTiles())) {
-    return emitOpError("invalid tile factor provided. Only full tiles are "
-                       "supported when padding_value is not set");
+    return emitOpError(
+        "invalid tile factor or output size provided. Only full tiles are "
+        "supported when padding_value is not set");
   }
   return success();
 }
@@ -3979,8 +3995,9 @@ static bool paddingIsNotNeeded(PackOp op) {
     return false;
   if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
     return false;
-  return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
-                                      op.getMixedTiles());
+  return !PackOp::requirePaddingValue(
+      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
+      op.getOuterDimsPerm(), op.getMixedTiles());
 }
 
 /// Returns true if the `srcShape` or `destShape` is different from the one in

diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 735e5146e9dbc8..4c534fe936e3de 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -597,13 +597,29 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
 // -----
 
 func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
-  // expected-error @+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
+  // expected-error @+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
   %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32>  -> tensor<8x8x16x33xf32>
   return %0 : tensor<8x8x16x33xf32>
 }
 
 // -----
 
+func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
+  // expected-error @+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32>  -> tensor<10x8x?x?xf32>
+  return %0 : tensor<10x8x?x?xf32>
+} 
+
+// -----
+
+func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles_outperm(%input: tensor<256x128xf32>, %output: tensor<8x10x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<8x10x?x?xf32> {
+  // expected-error @+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = tensor.pack %input outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32>  -> tensor<8x10x?x?xf32>
+  return %0 : tensor<8x10x?x?xf32>
+} 
+
+// -----
+
 func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
   // expected-error @+1 {{expected padding_value has 'f32' but got: 'i32'}}
   %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>


        


More information about the Mlir-commits mailing list