[Mlir-commits] [mlir] f71de25 - [mlir][tensor] Add shape inference methods to tensor::PackOp.

Hanhan Wang llvmlistbot at llvm.org
Thu Feb 9 17:24:51 PST 2023


Author: Hanhan Wang
Date: 2023-02-09T17:24:42-08:00
New Revision: f71de259c373cf91abf33c99f375fb9c64c3a441

URL: https://github.com/llvm/llvm-project/commit/f71de259c373cf91abf33c99f375fb9c64c3a441
DIFF: https://github.com/llvm/llvm-project/commit/f71de259c373cf91abf33c99f375fb9c64c3a441.diff

LOG: [mlir][tensor] Add shape inference methods to tensor::PackOp.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D143686

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e702189e78476..9e1c8bc3abf81 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1772,6 +1772,14 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    // Returns the shape of the result as a `SmallVector<OpFoldResult>`.
+    // The method is static so that callers can compute the expected
+    // destination shape while creating a `pack` op.
+    static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
+        Location loc, ArrayRef<OpFoldResult> sourceDims,
+        ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
     // Method to get the `ShapedType` of the result based on the inner tiles,
     // position of the inner tiles (innerDimsPos)  and interchange vector of
     // outer loops (outerDimsPerm).

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d35895a167558..74b3f9338aa75 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3479,14 +3479,29 @@ LogicalResult PackOp::verify() {
   return success();
 }
 
-/// Get the expected packed type based on source type, tile factors, position of
-/// the inner tiles and permutation of the outer tiled loop.
-ShapedType PackOp::inferPackedType(ShapedType sourceType,
-                                   ArrayRef<int64_t> innerTileSizes,
-                                   ArrayRef<int64_t> innerDimsPos,
-                                   ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = llvm::to_vector(sourceType.getShape());
-  for (const auto &tiledDim : llvm::enumerate(innerDimsPos)) {
+/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
+/// every Value to kDynamic, even if it is an arith.constant value.
+static SmallVector<int64_t>
+asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<int64_t> result;
+  for (auto o : ofrs) {
+    // Test for a Value first, as getConstantIntValue special-cases constants.
+    if (o.dyn_cast<Value>())
+      result.push_back(ShapedType::kDynamic);
+    else
+      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
+  }
+  return result;
+}
+
+/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
+/// the packed type. Having a shared helper helps implement these two methods in
+/// a way that ensures that they agree on which dimensions are dynamic.
+static SmallVector<int64_t> getPackOpResultTypeShape(
+    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
+    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
       continue;
     if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
@@ -3497,11 +3512,60 @@ ShapedType PackOp::inferPackedType(ShapedType sourceType,
                                             innerTileSizes[tiledDim.index()]);
   }
 
+  // Swap tile loops if outer_dims_perm is available.
   if (!outerDimsPerm.empty())
     applyPermutationToVector(resultShape, outerDimsPerm);
 
   // Append the inner tile dimensions.
   resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
+  return resultShape;
+}
+
+SmallVector<OpFoldResult> PackOp::getResultShape(
+    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
+    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+    ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
+  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
+    resultDims[tiledDim.value()] = makeComposedFoldedAffineApply(
+        builder, loc, ceilDivExpr,
+        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
+  }
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(resultDims, outerDimsPerm);
+  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
+
+  SmallVector<int64_t> resultTypeShape =
+      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
+                               asShapeWithAnyValueAsDynamic(innerTileSizes),
+                               innerDimsPos, outerDimsPerm);
+
+  // Fix up `resultDims` to ensure that they are Values if and only if the
+  // result type shape says it's a dynamic dim. This is needed as callers may
+  // use dispatchIndexOpFoldResults on the result, and rely on the exact number
+  // of dynamic dims returned by that.
+  for (unsigned i = 0; i < resultDims.size(); ++i) {
+    if (!ShapedType::isDynamic(resultTypeShape[i]))
+      continue;
+    resultDims[i] =
+        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
+  }
+
+  return resultDims;
+}
+
+/// Infers the packed RankedTensorType from the source type, the inner tile
+/// sizes, the inner tile positions and the permutation of the outer dims.
+ShapedType PackOp::inferPackedType(ShapedType sourceType,
+                                   ArrayRef<int64_t> innerTileSizes,
+                                   ArrayRef<int64_t> innerDimsPos,
+                                   ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
 


        


More information about the Mlir-commits mailing list