[Mlir-commits] [mlir] f71de25 - [mlir][tensor] Add shape inference methods to tensor::PackOp.
Hanhan Wang
llvmlistbot at llvm.org
Thu Feb 9 17:24:51 PST 2023
Author: Hanhan Wang
Date: 2023-02-09T17:24:42-08:00
New Revision: f71de259c373cf91abf33c99f375fb9c64c3a441
URL: https://github.com/llvm/llvm-project/commit/f71de259c373cf91abf33c99f375fb9c64c3a441
DIFF: https://github.com/llvm/llvm-project/commit/f71de259c373cf91abf33c99f375fb9c64c3a441.diff
LOG: [mlir][tensor] Add shape inference methods to tensor::PackOp.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D143686
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e702189e78476..9e1c8bc3abf81 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1772,6 +1772,14 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
+ // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
+ // This is a static method to allow getting the shape of the destination
+ // expected while creating a `pack` op.
+ static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
+ Location loc, ArrayRef<OpFoldResult> sourceDims,
+ ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm = {});
+
// Method to get the `ShapedType` of the result based on the inner tiles,
// position of the inner tiles (innerDimsPos) and interchange vector of
// outer loops (outerDimsPerm).
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d35895a167558..74b3f9338aa75 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3479,14 +3479,29 @@ LogicalResult PackOp::verify() {
return success();
}
-/// Get the expected packed type based on source type, tile factors, position of
-/// the inner tiles and permutation of the outer tiled loop.
-ShapedType PackOp::inferPackedType(ShapedType sourceType,
- ArrayRef<int64_t> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
- SmallVector<int64_t> resultShape = llvm::to_vector(sourceType.getShape());
- for (const auto &tiledDim : llvm::enumerate(innerDimsPos)) {
+/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
+/// Values to kDynamic, even if they are arith.constant values.
+static SmallVector<int64_t>
+asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
+ SmallVector<int64_t> result;
+ for (auto o : ofrs) {
+ // Test for Value first: getConstantIntValue would fold constant-op Values.
+ if (o.dyn_cast<Value>())
+ result.push_back(ShapedType::kDynamic);
+ else
+ result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
+ }
+ return result;
+}
+
+/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
+/// the packed type. Having a shared helper helps implement these two methods in
+/// a way that ensures that they agree on which dimensions are dynamic.
+static SmallVector<int64_t> getPackOpResultTypeShape(
+ ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+ for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
continue;
if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
@@ -3497,11 +3512,60 @@ ShapedType PackOp::inferPackedType(ShapedType sourceType,
innerTileSizes[tiledDim.index()]);
}
+ // Swap tile loops if outer_dims_perm is available.
if (!outerDimsPerm.empty())
applyPermutationToVector(resultShape, outerDimsPerm);
// Append the inner tile dimensions.
resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
+ return resultShape;
+}
+
+SmallVector<OpFoldResult> PackOp::getResultShape(
+ OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
+ ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
+
+ AffineExpr s0, s1;
+ bindSymbols(builder.getContext(), s0, s1);
+ AffineExpr ceilDivExpr = s0.ceilDiv(s1);
+ for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
+ resultDims[tiledDim.value()] = makeComposedFoldedAffineApply(
+ builder, loc, ceilDivExpr,
+ {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
+ }
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(resultDims, outerDimsPerm);
+ resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
+
+ SmallVector<int64_t> resultTypeShape =
+ getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
+ asShapeWithAnyValueAsDynamic(innerTileSizes),
+ innerDimsPos, outerDimsPerm);
+
+ // Fix-up `resultDims` to ensure that they are Values if and only if the
+ // result type shape says it's a dynamic dim. This is needed as callers may
+ // use dispatchIndexOpFoldResults on the result, and rely on the exact number
+ // of dynamic dims returned by that.
+ for (unsigned i = 0; i < resultDims.size(); ++i) {
+ if (!ShapedType::isDynamic(resultTypeShape[i]))
+ continue;
+ resultDims[i] =
+ getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
+ }
+
+ return resultDims;
+}
+
+/// Get the expected packed type based on source type, tile factors, position of
+/// the inner tiles and permutation of the outer tiled loop.
+ShapedType PackOp::inferPackedType(ShapedType sourceType,
+ ArrayRef<int64_t> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+ sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
return RankedTensorType::get(resultShape, sourceType.getElementType());
}
More information about the Mlir-commits
mailing list