[Mlir-commits] [mlir] be76f6b - [mlir][tensor] Expose padding requirement of pack ops to a static method

Hanhan Wang llvmlistbot at llvm.org
Thu Mar 9 10:03:54 PST 2023


Author: Hanhan Wang
Date: 2023-03-09T10:03:46-08:00
New Revision: be76f6bef835c473a099482dd6c531fe7d7ededb

URL: https://github.com/llvm/llvm-project/commit/be76f6bef835c473a099482dd6c531fe7d7ededb
DIFF: https://github.com/llvm/llvm-project/commit/be76f6bef835c473a099482dd6c531fe7d7ededb.diff

LOG: [mlir][tensor] Expose padding requirement of pack ops to a static method

It also simplifies the implementation of the method: the dimension-to-tile map is no longer needed for the check.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D145522

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 80c0ba5e754a9..09b7775dcaae4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1788,6 +1788,13 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
+    // Returns true if the static information shows that some constant tile
+    // size does not evenly divide its static input dimension, in which case
+    // a padding value is required to avoid undefined behavior.
+    static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
+                                    ArrayRef<int64_t> innerDimsPos,
+                                    ArrayRef<OpFoldResult> innerTiles);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 22531338d30c6..baf213006a6dc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3422,22 +3422,16 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
-/// Check if we have enough static information to catch undefined behavior when
-/// the tile size does not divide perfectly the dimension of the input tensor.
-static bool
-areNotFullTiles(ArrayRef<int64_t> inputShape,
-                DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
-  int64_t rank = inputShape.size();
-  for (int64_t dim = 0; dim < rank; dim++) {
-    if (ShapedType::isDynamic(inputShape[dim]))
+bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
+                                 ArrayRef<int64_t> innerDimsPos,
+                                 ArrayRef<OpFoldResult> innerTiles) {
+  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
+    if (ShapedType::isDynamic(inputShape[pos]))
       continue;
-    auto it = dimAndTileMapping.find(dim);
-    if (it == dimAndTileMapping.end())
-      continue;
-    std::optional<int64_t> constantTile = getConstantIntValue(it->second);
+    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
     if (!constantTile)
       continue;
-    if (inputShape[dim] % (*constantTile) != 0)
+    if (inputShape[pos] % (*constantTile) != 0)
       return true;
   }
   return false;
@@ -3458,9 +3452,9 @@ LogicalResult PackOp::verify() {
            << " but got: " << paddingValue.getType();
   }
 
-  auto dimAndTileMapping = getDimAndTileMapping();
   if (!paddingValue &&
-      areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) {
+      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
+                          getMixedTiles())) {
     return emitOpError("invalid tile factor provided. Only full tiles are "
                        "supported when padding_value is not set");
   }


        


More information about the Mlir-commits mailing list