[Mlir-commits] [mlir] 9aa505a - Introduce `tensor.pack` and `tensor.unpack` operations
Lorenzo Chelini
llvmlistbot@llvm.org
Tue Nov 22 00:12:06 PST 2022
Author: Lorenzo Chelini
Date: 2022-11-22T09:11:59+01:00
New Revision: 9aa505a28d827f13ac9c6268f5834592b1a150e3
URL: https://github.com/llvm/llvm-project/commit/9aa505a28d827f13ac9c6268f5834592b1a150e3
DIFF: https://github.com/llvm/llvm-project/commit/9aa505a28d827f13ac9c6268f5834592b1a150e3.diff
LOG: Introduce `tensor.pack` and `tensor.unpack` operations
Pack and unpack return new tensors whose individual elements are reshuffled
according to the packing specification. This changes the canonical order in
which a given operator (e.g., matmul) accesses the individual elements. After
bufferization, this typically translates to increased access locality and
improved cache behavior, e.g., by eliminating cache-line splitting.
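As a rough, standalone illustration (our own sketch, not part of this commit; the buffer layout, names, and concrete tile sizes are assumptions), the NC_to_NCnc re-layout that `tensor.pack` performs with `inner_dims_pos = [0, 1]` and `inner_tiles = [8, 32]` can be reproduced in plain C++:

```cpp
// Hedged sketch of the NC -> NCnc re-layout performed by tensor.pack with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 32] on a 128x256 input.
// Destination shape: (N/tileN) x (C/tileC) x tileN x tileC = 16x8x8x32.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int64_t N = 128, C = 256, tileN = 8, tileC = 32;
  std::vector<float> src(N * C);
  for (int64_t i = 0; i < N * C; ++i)
    src[i] = static_cast<float>(i);

  std::vector<float> packed(N * C);
  // Linearized index into the 16x8x8x32 destination.
  auto packedIndex = [&](int64_t no, int64_t co, int64_t ni, int64_t ci) {
    return ((no * (C / tileC) + co) * tileN + ni) * tileC + ci;
  };
  for (int64_t i = 0; i < N; ++i)
    for (int64_t j = 0; j < C; ++j)
      packed[packedIndex(i / tileN, j / tileC, i % tileN, j % tileC)] =
          src[i * C + j];

  // Each 8x32 tile is now contiguous in memory, which is where the access
  // locality described above comes from.
  assert(packed[packedIndex(3, 5, 2, 7)] ==
         src[(3 * tileN + 2) * C + (5 * tileC + 7)]);
  return 0;
}
```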
Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
Co-authored-by: Han-Chung Wang <hanchung@google.com>
RFC: https://discourse.llvm.org/t/rfc-tensor-pack-and-tensor-unpack/66408/1
Reviewed By: nicolasvasilache, rengolin, hanchung
Differential Revision: https://reviews.llvm.org/D138119
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
mlir/test/Transforms/loop-invariant-code-motion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 352002b20a738..661a8f8a6e850 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1663,6 +1663,170 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
+class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
+ Tensor_Op<mnemonic, !listconcat(traits, [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DestinationStyleOpInterface,
+ ConditionallySpeculatable, NoMemoryEffect,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ TypesMatchWith<"result type matches type of dest",
+ "dest", "result",
+ "$_self">])> {
+
+ code commonExtraClassDeclaration = [{
+ int64_t getSourceRank() { return getSource().getType().getRank(); };
+ int64_t getDestRank() { return getDest().getType().getRank(); };
+ RankedTensorType getSourceType() {
+ return getSource().getType().cast<RankedTensorType>(); };
+ RankedTensorType getDestType() {
+ return getDest().getType().cast<RankedTensorType>(); };
+
+ /// Return the position range of the init operand (`dest`).
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ return {1, 2}; // `dest` operand
+ }
+
+ /// Interface method for ConditionallySpeculatable.
+ Speculation::Speculatability getSpeculatability();
+
+ /// Return a mapping from the positions in `inner_dims_pos` to their
+ /// tile factors.
+ DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
+
+ /// Return the tile sizes as OpFoldResult.
+ SmallVector<OpFoldResult> getMixedTiles();
+
+ /// Return the tile sizes as `int64_t`. If a tile size is dynamic,
+ /// the sentinel `kDynamic` is used at that position in the returned
+ /// vector.
+ SmallVector<int64_t> getStaticTiles();
+ }];
+
+ let hasVerifier = 1;
+}
+
+def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
+ AttrSizedOperandSegments]> {
+ let summary = "tensor pack operation";
+ let description = [{
+ The pack operation converts an input tensor to a higher-dimensional tensor
+ with a tiled and packed layout. The mandatory `inner_dims_pos` attribute
+ specifies a permutation for the original dimensions, while `inner_tiles` is
+ the tiling factor for each of them. The optional attribute `outer_dims_perm`
+ specifies the order of the outer (tiled) data dimensions, while the attribute
+ `padding_value` specifies the value used to pad dimensions that are not
+ evenly divisible by their tile size. Padding is optional:
+ - If absent, it is UB if the tile does not perfectly divide the dimension.
+ - If present, it will pad along high dimensions (high-padding) to make the
+ tile complete.
+
+ Example NC_to_NCnc:
+
+ ```mlir
+ tensor.pack %source inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+ ```
+ Example CK_to_KCck:
+
+ ```mlir
+ tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<8x16x8x32xf32>
+ ```
+
+ In both examples, the dimension at position 0 in the input tensor (128) is
+ tiled with a factor of 8, while the dimension at position 1 (256) is tiled
+ with a factor of 32. In the second example, the outer data dimensions are
+ additionally interchanged according to `outer_dims_perm`.
+
+ Example NC_to_NCnc with padding:
+
+ ```mlir
+ tensor.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1]
+ inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+ ```
+
+ }];
+ let arguments = (ins AnyRankedTensor:$source,
+ AnyRankedTensor:$dest,
+ Optional<AnyType>:$padding_value,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+ DenseI64ArrayAttr:$inner_dims_pos,
+ Variadic<Index>:$inner_tiles,
+ I64ArrayAttr:$static_inner_tiles);
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ $source
+ (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
+ (`outer_dims_perm` `=` $outer_dims_perm^)?
+ `inner_dims_pos` `=` $inner_dims_pos
+ `inner_tiles` `=`
+ custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
+ "ShapedType::kDynamic")
+ `into` $dest attr-dict `:` type($source) `->` type($dest)
+ }];
+
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ // Returns the `ShapedType` of the result based on the inner tile sizes,
+ // the positions of the inner tiles (`innerDimsPos`), and the interchange
+ // vector of the outer dimensions (`outerDimsPerm`).
+ static ShapedType inferPackedType(ShapedType sourceType,
+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm = {});
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// UnPackOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
+ let summary = "tensor unpack operation";
+ let description = [{
+ The unpack operation converts a tensor with a tiled and packed layout to a
+ lower-dimensional tensor. Similar to `pack`, the mandatory attribute
+ `inner_dims_pos` specifies a permutation for the inner data dimensions, while
+ `inner_tiles` is the tiling factor for each of them. The attribute
+ `outer_dims_perm` has exactly the same behavior as described for `pack`. In
+ `unpack`, it is UB if the tile does not perfectly divide the dimension.
+
+ Example NCnc_to_NC:
+
+ ```mlir
+ tensor.unpack %source inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+ ```
+
+ Example KCck_to_CK:
+
+ ```mlir
+ tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
+ ```
+ }];
+ let arguments = (ins AnyRankedTensor:$source,
+ AnyRankedTensor:$dest,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+ DenseI64ArrayAttr:$inner_dims_pos,
+ Variadic<Index>:$inner_tiles,
+ I64ArrayAttr:$static_inner_tiles);
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ $source
+ (`outer_dims_perm` `=` $outer_dims_perm^)?
+ `inner_dims_pos` `=` $inner_dims_pos
+ `inner_tiles` `=`
+ custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
+ "ShapedType::kDynamic")
+ `into` $dest attr-dict `:` type($source) `->` type($dest)
+ }];
+
+ let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 019cffe34f900..95dbd77c4abc2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -2944,6 +2945,369 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+//===----------------------------------------------------------------------===//
+// PackOp/UnPackOp Common
+//===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+static LogicalResult
+reifyResultShapesImpl(OpTy op, OpBuilder &builder,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ int64_t destRank = op.getDestRank();
+ reifiedReturnShapes.resize(1, SmallVector<Value>(destRank));
+ for (auto dim : llvm::seq<int64_t>(0, destRank)) {
+ reifiedReturnShapes[0][dim] =
+ builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
+ }
+ return success();
+}
+
+template <typename OpTy>
+static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
+ ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
+ SmallVector<OpFoldResult> tiles = op.getMixedTiles();
+ assert(tiles.size() == dimsToTile.size() &&
+ "tiles must match indices of dimension to block");
+ // Bind each dimension in `dimsToTile` to its tile factor.
+ for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
+ dimAndTileMapping[dimsToTile[i]] = tiles[i];
+ return dimAndTileMapping;
+}
+
+template <typename OpTy>
+static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ SmallVector<OpFoldResult> mixedInnerTiles;
+ unsigned dynamicValIndex = 0;
+ for (Attribute attr : op.getStaticInnerTiles()) {
+ auto tileAttr = attr.cast<IntegerAttr>();
+ if (!ShapedType::isDynamic(tileAttr.getInt()))
+ mixedInnerTiles.push_back(tileAttr);
+ else
+ mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
+ }
+ return mixedInnerTiles;
+}
+
+template <typename OpTy>
+static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ SmallVector<Value> dynamicTiles;
+ SmallVector<int64_t> staticTiles;
+ dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles,
+ ShapedType::kDynamic);
+ return staticTiles;
+}
+
+/// Returns true if `dimsPos` is invalid. It is invalid when:
+/// a) it contains duplicates;
+/// b) at least one dimension is out of bounds (a valid `dimPos` satisfies
+///    0 <= dimPos < rank);
+/// c) `dimsPos` has more elements than `rank`.
+static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
+ size_t rank) {
+ size_t dimsPosSize = dimsPos.size();
+ if (dimsPosSize > rank)
+ return true;
+ DenseSet<int64_t> uniqued;
+ for (int64_t dim : dimsPos)
+ uniqued.insert(dim);
+ if (dimsPosSize != uniqued.size())
+ return true;
+ return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
+ return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
+ });
+}
+
+/// Returns true if each dimension of `sourceShape` is statically known to be
+/// no larger than the corresponding dimension of `limitShape`; dynamic
+/// extents are conservatively accepted.
+static bool areAllInBound(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> limitShape) {
+ assert(
+ sourceShape.size() == limitShape.size() &&
+ "expected source shape rank, and limit of the shape to have same rank");
+ return llvm::all_of(
+ llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
+ int64_t sourceExtent = std::get<0>(it);
+ int64_t limit = std::get<1>(it);
+ return ShapedType::isDynamic(sourceExtent) ||
+ ShapedType::isDynamic(limit) || sourceExtent <= limit;
+ });
+}
+
+template <typename OpTy>
+static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ Operation *op = packOrUnPack.getOperation();
+
+ // Return true if we have a zero-value tile.
+ auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
+ return llvm::any_of(
+ tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+ };
+
+ // Verify tiles. Do not allow zero tiles.
+ SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
+ if (hasZeros(mixedTiles))
+ return op->emitError("invalid zero tile factor");
+
+ // Verify inner_dims_pos and outer_dims_perm.
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
+ size_t unpackedRank = unpackedType.getRank();
+ ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
+ ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
+ if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
+ return op->emitError("invalid inner_dims_pos vector");
+ if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
+ return op->emitError("invalid outer_dims_perm vector");
+
+ // Tiling factors must be less than or equal to the input rank for pack (or
+ // output rank for unpack), and must match the number of `inner_dims_pos`.
+ if (mixedTiles.size() > unpackedRank) {
+ return op->emitError("tiling factors must be less than or equal to the "
+ "input rank for pack or output rank for unpack");
+ }
+ if (mixedTiles.size() != innerDimsPos.size()) {
+ return op->emitError(
+ "tiling factors must equal the number of dimensions to tile");
+ }
+
+ ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
+ size_t packedRank = packedType.getRank();
+ // Require output rank to match input rank + number of blocking factors.
+ if (unpackedRank + mixedTiles.size() != packedRank) {
+ return op->emitError(
+ "packed rank must equal unpacked rank + tiling factors");
+ }
+
+ // Verify that the result shape is at least as large as the packed shape
+ // inferred from the source, and that the trailing dimensions of the packed
+ // type match the specified tile sizes.
+ ShapedType expectedPackedType = PackOp::inferPackedType(
+ unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
+ if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
+ return op->emitError("the shape of output is not large enough to hold the "
+ "packed data. Expected at least ")
+ << expectedPackedType << ", got " << packedType;
+ }
+ if (!llvm::all_of(
+ llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
+ mixedTiles),
+ [](std::tuple<int64_t, OpFoldResult> it) {
+ Optional<int64_t> constTileSize =
+ getConstantIntValue(std::get<1>(it));
+ int64_t shape = std::get<0>(it);
+ if (!constTileSize) {
+ // If specified tile size is dynamic, output shape should
+ // be dynamic too.
+ return ShapedType::isDynamic(shape);
+ } else {
+ if (ShapedType::isDynamic(shape)) {
+ // If the tile size is static but the shape is dynamic, accept
+ // it: in canonical form a constant tile size leads to a constant
+ // extent for the tiled dimension, but canonicality is not
+ // required for verification.
+ return true;
+ }
+ return shape == constTileSize.value();
+ }
+ })) {
+ return op->emitError("mismatch in inner tile sizes specified and shaped of "
+ "tiled dimension in the packed type");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
+void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "pack");
+}
+
+LogicalResult
+PackOp::reifyResultShapes(OpBuilder &builder,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
+}
+
+DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
+ return getDimAndTileMappingImpl(*this);
+}
+
+SmallVector<OpFoldResult> PackOp::getMixedTiles() {
+ return getMixedTilesImpl(*this);
+}
+
+SmallVector<int64_t> PackOp::getStaticTiles() {
+ return getStaticTilesImpl(*this);
+}
+
+/// Check if we have enough static information to catch undefined behavior
+/// when a tile size does not perfectly divide the corresponding dimension of
+/// the input tensor.
+static bool
+areNotFullTiles(ArrayRef<int64_t> inputShape,
+ DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
+ int64_t rank = inputShape.size();
+ for (int64_t dim = 0; dim < rank; dim++) {
+ if (ShapedType::isDynamic(inputShape[dim]))
+ continue;
+ auto it = dimAndTileMapping.find(dim);
+ if (it == dimAndTileMapping.end())
+ continue;
+ Optional<int64_t> constantTile = getConstantIntValue(it->second);
+ if (!constantTile)
+ continue;
+ if (inputShape[dim] % (*constantTile) != 0)
+ return true;
+ }
+ return false;
+}
+
+LogicalResult PackOp::verify() {
+ if (failed(commonVerifierPackAndUnPackOp(*this)))
+ return failure();
+
+ // Verify the padding value, and bail out if a tile does not fully divide
+ // its dimension while no padding value is set. In the case of dynamic tile
+ // factors or dimensions, having a partial tile is undefined behavior.
+ auto paddingValue = getPaddingValue();
+ if (paddingValue &&
+ paddingValue.getType() != getSourceType().getElementType()) {
+ return emitOpError("expected padding_value has ")
+ << getSourceType().getElementType()
+ << " but got: " << paddingValue.getType();
+ }
+
+ auto dimAndTileMapping = getDimAndTileMapping();
+ if (!paddingValue &&
+ areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) {
+ return emitOpError("invalid tile factor provided. Only full tiles are "
+ "supported when padding_value is not set");
+ }
+ return success();
+}
+
+/// Returns a vector that interchanges `elements` starting at offset `offset`
+/// based on the indices in `interchangeVector`.
+template <typename T>
+SmallVector<T> interchange(ArrayRef<T> elements,
+ ArrayRef<int64_t> interchangeVector,
+ int offset = 0) {
+ SmallVector<T> vec = llvm::to_vector(elements);
+ for (auto en : llvm::enumerate(interchangeVector))
+ vec[en.index() + offset] = elements[en.value() + offset];
+
+ return vec;
+}
+
+/// Get the expected packed type based on the source type, the tile factors,
+/// the positions of the inner tiles, and the permutation of the outer tiled
+/// dimensions.
+ShapedType PackOp::inferPackedType(ShapedType sourceType,
+ ArrayRef<int64_t> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> resultShape = llvm::to_vector(sourceType.getShape());
+ for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
+ if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
+ continue;
+ if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
+ resultShape[tiledDim.value()] = ShapedType::kDynamic;
+ continue;
+ }
+ resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
+ innerTileSizes[tiledDim.index()]);
+ }
+
+ resultShape = interchange<int64_t>(resultShape, outerDimsPerm);
+
+ // Append the inner tile dimensions.
+ resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
+ return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+/// Returns true if the tiles and the tiled dims are constant.
+template <typename OpTy>
+bool areTilesAndTiledDimsAllConstant(OpTy op) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? op.getDestType()
+ : op.getSourceType();
+ SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
+ for (auto [dimDest, tile] : llvm::zip(
+ packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
+ Optional<int64_t> constTileSize = getConstantIntValue(tile);
+ if (!constTileSize || ShapedType::isDynamic(dimDest))
+ return false;
+ }
+ return true;
+}
+
+Speculation::Speculatability PackOp::getSpeculatability() {
+ if (auto paddingValue = getPaddingValue())
+ return Speculation::Speculatable;
+
+ // The verifier already rejects operations where we can statically prove
+ // that the tile sizes do not perfectly divide the dimension; thus, only
+ // check that the tiles and the tiled inner dimensions are constant.
+ if (!areTilesAndTiledDimsAllConstant(*this))
+ return Speculation::NotSpeculatable;
+
+ return Speculation::Speculatable;
+}
+
+//===----------------------------------------------------------------------===//
+// UnPackOp
+//===----------------------------------------------------------------------===//
+
+void UnPackOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "unpack");
+}
+
+LogicalResult
+UnPackOp::reifyResultShapes(OpBuilder &builder,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
+}
+
+DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
+ return getDimAndTileMappingImpl(*this);
+}
+
+SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
+ return getMixedTilesImpl(*this);
+}
+
+SmallVector<int64_t> UnPackOp::getStaticTiles() {
+ return getStaticTilesImpl(*this);
+}
+
+LogicalResult UnPackOp::verify() {
+ return commonVerifierPackAndUnPackOp(*this);
+}
+
+Speculation::Speculatability UnPackOp::getSpeculatability() {
+ // See PackOp::getSpeculatability.
+ if (!areTilesAndTiledDimsAllConstant(*this))
+ return Speculation::NotSpeculatable;
+
+ return Speculation::Speculatable;
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
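As a hedged usage sketch between the two files of this diff (the helper `printTiledDims` below is ours and hypothetical; only `getDimAndTileMapping`, `getSourceRank`, and `getConstantIntValue` are APIs that actually appear in this commit), a client pass could consume the dimension-to-tile mapping like so:

```cpp
// Hypothetical client of the PackOp API introduced above: report, for every
// source dimension, whether it is tiled and by which (static) factor.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

static void printTiledDims(tensor::PackOp packOp) {
  DenseMap<int64_t, OpFoldResult> mapping = packOp.getDimAndTileMapping();
  for (int64_t dim = 0, rank = packOp.getSourceRank(); dim < rank; ++dim) {
    auto it = mapping.find(dim);
    if (it == mapping.end())
      continue; // This source dimension is not tiled.
    if (Optional<int64_t> cst = getConstantIntValue(it->second))
      llvm::outs() << "dim " << dim << " tiled by " << *cst << "\n";
    else
      llvm::outs() << "dim " << dim << " tiled by a dynamic SSA value\n";
  }
}
```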
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index d1d1c350de26e..b085053296ca4 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -522,3 +522,92 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
%out = tensor.empty(%sz) : tensor<2x?x?x5xf32>
return
}
+
+// -----
+
+func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
+ // expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
+ %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32>
+ return %0 : tensor<8x8x16x33xf32>
+}
+
+// -----
+
+func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
+ // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
+ %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+ return %0 : tensor<2x8x8x2xf32>
+}
+
+// -----
+
+func.func @pack_invalid_inner_dims_pos_vector(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // expected-error@+1 {{invalid inner_dims_pos vector}}
+ %0 = tensor.pack %input inner_dims_pos = [2, 0] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+ return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid_duplicate_element_in_inner_dims(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // expected-error@+1 {{invalid inner_dims_pos vector}}
+ %0 = tensor.pack %input inner_dims_pos = [1, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+ return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid_duplicate_element_in_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // expected-error@+1 {{invalid outer_dims_perm vector}}
+ %0 = tensor.pack %input outer_dims_perm = [1, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+ return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @unpack_invalid_out_of_bound_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // expected-error@+1 {{invalid outer_dims_perm vector}}
+ %0 = tensor.unpack %output outer_dims_perm = [2, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %input : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+}
+
+// -----
+
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+ %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+ return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
+ // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+ %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+}
+
+// -----
+
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // expected-error@+1 {{invalid zero tile factor}}
+ %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [0, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+ return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+func.func @pack_mismatch_inner_tile_size_and_output_shape(
+ %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+ // expected-error@+1 {{mismatch in inner tile sizes specified and shape of tiled dimension in the packed type}}
+ %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+ return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @unpack_mismatch_inner_tile_size_and_output_shape(
+ %input : tensor<?x?x8x8xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error@+1 {{mismatch in inner tile sizes specified and shape of tiled dimension in the packed type}}
+ %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
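For intuition about the "not large enough" diagnostics in the tests above, here is a standalone sketch (ours, not part of the commit) of the shape arithmetic that `PackOp::inferPackedType` applies to the first `pack_invalid` case:

```cpp
// Reproduce the expected packed type for: source tensor<256x128xf32>,
// inner_dims_pos = [1, 0], inner_tiles = [16, 32].
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors mlir::ceilDiv for the positive operands used here.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  std::vector<int64_t> shape = {256, 128};
  const std::vector<int64_t> innerDimsPos = {1, 0};
  const std::vector<int64_t> innerTiles = {16, 32};
  // Divide each tiled dimension by its tile size...
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    shape[innerDimsPos[i]] = ceilDiv(shape[innerDimsPos[i]], innerTiles[i]);
  // ...then append the tile sizes as the innermost dimensions.
  shape.insert(shape.end(), innerTiles.begin(), innerTiles.end());
  // Hence the verifier expects at least tensor<8x8x16x32xf32>.
  assert((shape == std::vector<int64_t>{8, 8, 16, 32}));
  return 0;
}
```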
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index a561726d96c9d..3bb62354102e0 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt --split-input-file %s | mlir-opt | FileCheck %s
// CHECK-LABEL: func @cast(
func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
@@ -13,6 +13,8 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
return
}
+// -----
+
// CHECK-LABEL: func @empty(
// CHECK-SAME: %[[sz:.*]]: index
func.func @empty(%sz: index) -> tensor<5x?x6xf32> {
@@ -21,6 +23,8 @@ func.func @empty(%sz: index) -> tensor<5x?x6xf32> {
return %0 : tensor<5x?x6xf32>
}
+// -----
+
// CHECK-LABEL: func @empty_with_encoding(
// CHECK-SAME: %[[sz:.*]]: index
func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
@@ -29,6 +33,8 @@ func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
return %0 : tensor<5x?x6xf32, "foo">
}
+// -----
+
// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[INDEX:.*]]: index) {
@@ -38,6 +44,8 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
return
}
+// -----
+
// CHECK-LABEL: func @insert(
// CHECK-SAME: %[[SCALAR:.*]]: f32
// CHECK-SAME: %[[INDEX:.*]]: index
@@ -48,6 +56,8 @@ func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
return
}
+// -----
+
// CHECK-LABEL: func @tensor.from_elements() {
func.func @tensor.from_elements() {
%c0 = "arith.constant"() {value = 0: index} : () -> index
@@ -74,6 +84,8 @@ func.func @tensor.from_elements() {
return
}
+// -----
+
// CHECK-LABEL: @tensor.generate
func.func @tensor.generate(%m : index, %n : index)
-> tensor<?x3x?xf32> {
@@ -85,6 +97,8 @@ func.func @tensor.generate(%m : index, %n : index)
return %tnsr : tensor<?x3x?xf32>
}
+// -----
+
// CHECK-LABEL: func @tensor_reshape
func.func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
%shape2: tensor<2xi32>, %shape3: tensor<?xi32>) -> tensor<*xf32> {
@@ -97,6 +111,8 @@ func.func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
return %new_unranked : tensor<*xf32>
}
+// -----
+
// CHECK-LABEL: func @slice({{.*}}) {
func.func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
%c0 = arith.constant 0 : index
@@ -120,6 +136,8 @@ func.func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
return
}
+// -----
+
// CHECK-LABEL: func @insert_slice({{.*}}) {
func.func @insert_slice(
%t: tensor<8x16x4xf32>,
@@ -154,6 +172,8 @@ func.func @insert_slice(
return
}
+// -----
+
func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
-> (tensor<f32>, tensor<1x1xf32>) {
%0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
@@ -164,6 +184,8 @@ func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
// CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
// CHECK: tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+// -----
+
func.func @legal_collapsing_reshape_dynamic_tensor
(%arg0: tensor<?x?x?x4x?xf32>) -> tensor<?x?x?xf32>
{
@@ -175,6 +197,8 @@ func.func @legal_collapsing_reshape_dynamic_tensor
// CHECK: tensor.collapse_shape
// CHECK-SAME: [0], [1], [2, 3, 4]
+// -----
+
func.func @rank(%t : tensor<4x4x?xf32>) {
// CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32>
%0 = "tensor.rank"(%t) : (tensor<4x4x?xf32>) -> index
@@ -184,6 +208,8 @@ func.func @rank(%t : tensor<4x4x?xf32>) {
return
}
+// -----
+
func.func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
%pad_value: f32) -> tensor<6x?x?x?xf32> {
%0 = tensor.pad %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
@@ -201,6 +227,8 @@ func.func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
// CHECK-SAME: high[3, 3, %[[HIGH]], 2]
// CHECK: : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+// -----
+
func.func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
%0 = tensor.pad %arg0 low[1, 2] high[2, 3] {
^bb0(%arg1 : index, %arg2 : index):
@@ -213,6 +241,8 @@ func.func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32
// CHECK: tensor.pad %[[ARG0]] low[1, 2] high[2, 3]
// CHECK: : tensor<3x4xf32> to tensor<6x9xf32>
+// -----
+
func.func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
%pad_value: f32) -> tensor<?x?xf32> {
%0 = tensor.pad %arg0 low[0, 0] high[%ub0, %ub1] {
@@ -230,6 +260,8 @@ func.func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
// CHECK-SAME: high[%[[UB0]], %[[UB1]]]
// CHECK: : tensor<2x3xf32> to tensor<?x?xf32>
+// -----
+
func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
%pad_value: f32) -> tensor<2x3xf32> {
%0 = tensor.pad %arg0 low[0, 0] high[%ub0, %ub1] {
@@ -247,6 +279,8 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
// CHECK-SAME: high[%[[UB0]], %[[UB1]]]
// CHECK: : tensor<?x?xf32> to tensor<2x3xf32>
+// -----
+
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
func.func @test_splat_op(%s : f32) {
@@ -258,6 +292,8 @@ func.func @test_splat_op(%s : f32) {
return
}
+// -----
+
// CHECK-LABEL: func.func @gather_scatter(
// CHECK-SAME: %[[ARG0:.*]]: tensor<4x5x6xf32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<1x3x2xindex>,
@@ -281,3 +317,106 @@ func.func @gather_scatter(
(tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32>
return
}
+
+// -----
+
+func.func @pack_nc_to_ncnc(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) -> tensor<128x256xf32> {
+ %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+ %1 = tensor.empty() : tensor<128x256xf32>
+ %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+ return %2 : tensor<128x256xf32>
+}
+
+// CHECK-LABEL: func.func @pack_nc_to_ncnc(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<128x256xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<4x16x32x16xf32>)
+// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32>
+// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+
+// -----
+
+func.func @pack_nc_to_ncnc_with_padding(%source: tensor<13x15xf32>, %dest: tensor<2x8x8x2xf32>, %padding: f32) -> tensor<13x15xf32> {
+ %0 = tensor.pack %source padding_value(%padding : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+ %1 = tensor.empty() : tensor<13x15xf32>
+ %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+ return %2 : tensor<13x15xf32>
+}
+
+// CHECK-LABEL: func.func @pack_nc_to_ncnc_with_padding(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<13x15xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<2x8x8x2xf32>,
+// CHECK-SAME: %[[PADDING:.*]]: f32)
+// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] padding_value(%[[PADDING]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<13x15xf32>
+// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[BUFF]] : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+
+// -----
+
+func.func @pack_ck_to_kcck(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<128x256xf32> {
+ %0 = tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<16x4x32x16xf32>
+ %1 = tensor.empty() : tensor<128x256xf32>
+ %2 = tensor.unpack %0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<16x4x32x16xf32> -> tensor<128x256xf32>
+ return %2 : tensor<128x256xf32>
+}
+
+// CHECK-LABEL: func.func @pack_ck_to_kcck(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<128x256xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<16x4x32x16xf32>)
+// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<16x4x32x16xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32>
+// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<16x4x32x16xf32> -> tensor<128x256xf32>
+
+// -----
+
+func.func @pad_and_pack_fully_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x?x?xf32>, %pad: f32, %tile_n : index, %tile_m : index) -> tensor<?x?x?x?xf32> {
+ %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32,
+// CHECK-SAME: %[[TILE_N:.*]]: index,
+// CHECK-SAME: %[[TILE_M:.*]]: index)
+// CHECK: %{{.*}} = tensor.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+
+// -----
+
+func.func @pad_and_pack_partially_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>, %pad: f32) -> tensor<?x?x8x2xf32> {
+ %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+ return %0 : tensor<?x?x8x2xf32>
+}
+
+// CHECK-LABEL: func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32)
+// CHECK: %{{.*}} = tensor.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+
+// -----
+
+func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_fully_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TILE_N:.*]]: index,
+// CHECK-SAME: %[[TILE_M:.*]]: index)
+// CHECK: %{{.*}} = tensor.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+
+// -----
+
+func.func @unpack_partially_dynamic(%source: tensor<?x?x8x2xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<?x?x8x2xf32> -> tensor<?x?xf32>
+ return %0: tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_partially_dynamic(
+// CHECK-SAME: %[[SOURCE:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>)
+// CHECK: %{{.*}} = tensor.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<?x?x8x2xf32> -> tensor<?x?xf32>
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 4d2b5b778a71d..090962db4c1ec 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -874,3 +874,58 @@ func.func @speculate_ceildivsi_const(
return
}
+
+// -----
+
+func.func @speculate_static_pack_and_unpack(%source: tensor<128x256xf32>,
+ %dest: tensor<4x16x32x16xf32>, %lb: index, %ub: index, %step: index) {
+
+ // CHECK: tensor.pack
+ // CHECK-NEXT: scf.for
+ scf.for %i = %lb to %ub step %step {
+ %packed = tensor.pack %source
+ inner_dims_pos = [0, 1]
+ inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+ }
+
+ // CHECK: tensor.unpack
+ // CHECK-NEXT: scf.for
+ scf.for %i = %lb to %ub step %step {
+ %unpacked = tensor.unpack %dest
+ inner_dims_pos = [0, 1]
+ inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+ }
+ return
+}
+
+// -----
+
+func.func @speculate_dynamic_pack_and_unpack(%source: tensor<?x?xf32>,
+ %dest: tensor<?x?x?x?xf32>, %lb: index, %ub: index, %step: index,
+ %tile_m: index, %tile_n: index, %pad: f32) {
+
+ // CHECK: scf.for
+ // CHECK-NEXT: tensor.pack
+ scf.for %i = %lb to %ub step %step {
+ %packed = tensor.pack %source
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ }
+
+ // CHECK: scf.for
+ // CHECK-NEXT: tensor.unpack
+ scf.for %i = %lb to %ub step %step {
+ %unpacked = tensor.unpack %dest
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%tile_n, %tile_m] into %source : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+ }
+
+ // CHECK: tensor.pack
+ // CHECK-NEXT: scf.for
+ scf.for %i = %lb to %ub step %step {
+ %packed = tensor.pack %source padding_value(%pad : f32)
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ }
+ return
+}