[Mlir-commits] [mlir] 4b066c7 - [mlir][linalg] Extend linalg.pack and linalg.unpack to accept memref (#167675)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 19 07:42:34 PST 2026
Author: Ryutaro Okada
Date: 2026-01-19T16:42:27+01:00
New Revision: 4b066c7fff3455dc547fabb676583391febe41e9
URL: https://github.com/llvm/llvm-project/commit/4b066c7fff3455dc547fabb676583391febe41e9
DIFF: https://github.com/llvm/llvm-project/commit/4b066c7fff3455dc547fabb676583391febe41e9.diff
LOG: [mlir][linalg] Extend linalg.pack and linalg.unpack to accept memref (#167675)
Extend linalg.pack and linalg.unpack to accept memref operands in
addition to tensors. As part of this change, we now disable all
transformations when these ops have memref semantics.
Closes https://github.com/llvm/llvm-project/issues/129004
---------
Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
Co-authored-by: Hyunsung Lee <ita9naiwa at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/python/mlir/dialects/linalg/__init__.py
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/python/dialects/linalg/ops.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f80c302ba7c51..95383e6262f71 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -7,11 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file defines Pack + Unpack Ops that have been moved from the Tensor
-// dialect. As such, these are defined as memory-effect-free and only accept
-// "tensors" as inputs.
-//
-// TODO: Once a good motivating example is identified, relax these
-// restrictions.
+// dialect.
//
//===----------------------------------------------------------------------===//
@@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td"
// RelayoutOp
//===----------------------------------------------------------------------===//
-class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
- Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DestinationStyleOpInterface, LinalgRelayoutOpInterface,
- ConditionallySpeculatable, NoMemoryEffect,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
+ : Op<Linalg_Dialect, mnemonic,
+ !listconcat(
+ traits, [DeclareOpInterfaceMethods<
+ OpAsmOpInterface, ["getAsmResultNames"]>,
+ DestinationStyleOpInterface, LinalgRelayoutOpInterface,
+ ConditionallySpeculatable,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<
+ ReifyRankedShapedTypeOpInterface, [
"reifyResultShapes"]>,
- TypesMatchWith<"result type matches type of dest",
- "dest", "result",
- "$_self">])> {
+ OptionalTypesMatchWith<"result type matches type of dest",
+ "dest", "result", "$_self">])> {
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
- RankedTensorType getSourceType() {
- return ::llvm::cast<RankedTensorType>(getSource().getType()); };
- RankedTensorType getDestType() {
- return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+ ShapedType getSourceType() {
+ return ::llvm::cast<ShapedType>(getSource().getType()); };
+ ShapedType getDestType() {
+ return ::llvm::cast<ShapedType>(getDest().getType()); };
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
@@ -195,23 +194,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
- Optional<AnyType>:$padding_value,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
- DenseI64ArrayAttr:$inner_dims_pos,
- Variadic<Index>:$inner_tiles,
- DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
- let assemblyFormat = [{
- $source
- (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
- (`outer_dims_perm` `=` $outer_dims_perm^)?
- `inner_dims_pos` `=` $inner_dims_pos
- `inner_tiles` `=`
- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
- `into` $dest attr-dict `:` type($source) `->` type($dest)
- }];
+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
+ TensorOrMemRef<[AnyType]>:$dest,
+ Optional<AnyType>:$padding_value,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+ DenseI64ArrayAttr:$inner_dims_pos,
+ Variadic<Index>:$inner_tiles,
+ DenseI64ArrayAttr:$static_inner_tiles);
+ let results = (outs Optional<AnyRankedTensor>:$result);
let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -233,7 +223,21 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
- static RankedTensorType inferPackedType(RankedTensorType sourceType,
+ static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm = {});
+
+ // Method to get the packed `MemRefType` (shape and element type only) based
+ // on the inner tiles, position of the inner tiles (innerDimsPos) and
+ // interchange vector of outer loops (outerDimsPerm).
+ static MemRefType inferPackedMemRefType(MemRefType sourceType,
+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm = {});
+
+ // Returns the shape of the packed type. This helper is shared by the type
+ // inference methods (and `getResultShape`) so that they all agree on which
+ // dimensions are dynamic.
+ static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
@@ -285,6 +289,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -352,21 +358,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
// Outer Dims: 9x3x8 Inner Dims: 4x2
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
- DenseI64ArrayAttr:$inner_dims_pos,
- Variadic<Index>:$inner_tiles,
- DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
- let assemblyFormat = [{
- $source
- (`outer_dims_perm` `=` $outer_dims_perm^)?
- `inner_dims_pos` `=` $inner_dims_pos
- `inner_tiles` `=`
- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
- `into` $dest attr-dict `:` type($source) `->` type($dest)
- }];
+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
+ TensorOrMemRef<[AnyType]>:$dest,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+ DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
+ DenseI64ArrayAttr:$static_inner_tiles);
+ let results = (outs Optional<AnyRankedTensor>:$result);
let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -409,6 +406,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+
+ let hasCustomAssemblyFormat = 1;
}
#endif // LINALG_RELEAYOUT_OPS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9f617a785d22..40cabe20d1a4b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4970,12 +4970,12 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
template <typename OpTy, typename>
SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
- RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getDestType()
- : packOrUnPack.getSourceType();
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
SmallVector<int64_t> result(
packedType.getShape().take_front(unpackedType.getRank()));
if (!packOrUnPack.getOuterDimsPerm().empty()) {
@@ -5109,15 +5109,30 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return llvm::any_of(tiles, isZeroInteger);
};
+ // Verify that the source and destination are ranked types.
+ if (!packOrUnPack.getSourceType().hasRank() ||
+ !packOrUnPack.getDestType().hasRank())
+ return op->emitError("expected both source and destination to have rank");
+
+ // Verify that the Operation does not have mixed tensor/buffer semantics.
+ if (!packOrUnPack.hasPureBufferSemantics() &&
+ !packOrUnPack.hasPureTensorSemantics())
+ return op->emitError("mixing tensor and buffer semantics is not allowed");
+ const unsigned numResults = packOrUnPack.getNumResults();
+ if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
+ return op->emitError("expected 1 result, got ") << numResults;
+ if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
+ return op->emitError("expected 0 results, got ") << numResults;
+
// Verify tiles. Do not allow zero tiles.
SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
if (hasZeros(mixedTiles))
return op->emitError("invalid zero tile factor");
// Verify inner_dims_pos and outer_dims_perm.
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -5154,8 +5169,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Verify result shape is greater than the minimum expected
// by the pack operation, and that the output shape
// represents full tiles.
- RankedTensorType expectedPackedType = PackOp::inferPackedType(
- unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
+ SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
+ unpackedType.getShape(), packOrUnPack.getStaticTiles(),
+ packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -5172,11 +5188,20 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
- if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
- packedType.getShape()))) {
+ if (failed(
+ verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
+ auto elementType = unpackedType.getElementType();
+ Type expectedType, actualType;
+ if (packOrUnPack.hasPureTensorSemantics()) {
+ expectedType = RankedTensorType::get(expectedPackedShape, elementType);
+ actualType = RankedTensorType::get(packedType.getShape(), elementType);
+ } else {
+ expectedType = MemRefType::get(expectedPackedShape, elementType);
+ actualType = MemRefType::get(packedType.getShape(), elementType);
+ }
return op->emitError("expected ")
- << expectedPackedType << " for the packed domain value, got "
- << packedType;
+ << expectedType << " for the packed domain value, got "
+ << actualType;
}
return success();
}
@@ -5237,7 +5262,154 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
//===----------------------------------------------------------------------===//
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getResult(), "pack");
+ if (!getResults().empty())
+ setNameFn(getResult(), "pack");
+}
+
+ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::UnresolvedOperand source, dest;
+ SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
+ SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
+ SmallVector<Type> paddingValueType;
+ SmallVector<int64_t> staticTiles;
+ DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+ Type sourceType, destType, resultType;
+
+ if (parser.parseOperand(source))
+ return failure();
+
+ if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
+ if (parser.parseLParen() ||
+ parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
+ parser.parseColon() || parser.parseTypeList(paddingValueType) ||
+ parser.parseRParen())
+ return failure();
+ }
+
+ if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ SmallVector<int64_t> outerDimsPermVec;
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t value;
+ if (parser.parseInteger(value))
+ return failure();
+ outerDimsPermVec.push_back(value);
+ return success();
+ }))
+ return failure();
+ outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
+ }
+
+ if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
+ return failure();
+
+ SmallVector<int64_t> innerDimsPosVec;
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t value;
+ if (parser.parseInteger(value))
+ return failure();
+ innerDimsPosVec.push_back(value);
+ return success();
+ }))
+ return failure();
+ innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
+
+ if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
+ return failure();
+
+ DenseI64ArrayAttr staticTilesAttr;
+ if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
+ return failure();
+ for (auto val : staticTilesAttr.asArrayRef())
+ staticTiles.push_back(val);
+
+ if (parser.parseKeyword("into") || parser.parseOperand(dest))
+ return failure();
+
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ if (parser.parseColon() || parser.parseType(sourceType))
+ return failure();
+
+ bool hasArrow = succeeded(parser.parseOptionalArrow());
+ if (hasArrow) {
+ if (parser.parseType(destType))
+ return failure();
+ }
+
+ bool isMemRef = llvm::isa<MemRefType>(sourceType);
+ if (!hasArrow) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "pack/unpack requires '->' and destination type");
+ }
+
+ if (!isMemRef)
+ resultType = destType;
+
+ if (parser.resolveOperand(source, sourceType, result.operands) ||
+ parser.resolveOperand(dest, destType, result.operands))
+ return failure();
+
+ if (!paddingValue.empty() &&
+ parser.resolveOperands(paddingValue, paddingValueType[0],
+ result.operands))
+ return failure();
+
+ if (!dynamicTiles.empty() &&
+ parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
+ result.operands))
+ return failure();
+
+ result.addAttribute("static_inner_tiles",
+ parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
+ result.addAttribute("inner_dims_pos", innerDimsPos);
+ if (outerDimsPerm)
+ result.addAttribute("outer_dims_perm", outerDimsPerm);
+
+ SmallVector<int32_t> segmentSizes = {
+ 1, 1, static_cast<int32_t>(paddingValue.size()),
+ static_cast<int32_t>(dynamicTiles.size())};
+ result.addAttribute("operandSegmentSizes",
+ parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
+
+ if (!isMemRef)
+ result.addTypes(resultType);
+
+ return success();
+}
+
+void PackOp::print(OpAsmPrinter &p) {
+ p << " " << getSource();
+
+ if (getPaddingValue()) {
+ p << " padding_value(" << getPaddingValue() << " : "
+ << getPaddingValue().getType() << ")";
+ }
+
+ if (!getOuterDimsPerm().empty()) {
+ p << " outer_dims_perm = [";
+ llvm::interleaveComma(getOuterDimsPerm(), p);
+ p << "]";
+ }
+
+ p << " inner_dims_pos = [";
+ llvm::interleaveComma(getInnerDimsPos(), p);
+ p << "]";
+
+ p << " inner_tiles = ";
+ printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+
+ p << " into " << getDest();
+
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ {"static_inner_tiles", "inner_dims_pos",
+ "outer_dims_perm", "operandSegmentSizes"});
+
+ p << " : " << getSource().getType();
+ p << " -> " << getDest().getType();
}
void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
@@ -5262,6 +5434,8 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ if (!hasPureTensorSemantics())
+ return failure();
return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
@@ -5397,13 +5571,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
return result;
}
-/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
-/// the packed type. Having a shared helper helps implement these two methods in
-/// a way that ensures that they agree on which dimensions are dynamic.
-static SmallVector<int64_t> getPackOpResultTypeShape(
- ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
- SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
+ ArrayRef<int64_t> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
continue;
@@ -5443,9 +5615,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
SmallVector<int64_t> resultTypeShape =
- getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
- asShapeWithAnyValueAsDynamic(innerTileSizes),
- innerDimsPos, outerDimsPerm);
+ inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
+ asShapeWithAnyValueAsDynamic(innerTileSizes),
+ innerDimsPos, outerDimsPerm);
// Fix-up `resultDims` to ensure that they are Value's if and only if the
// result type shape says it's a dynamic dim. This is needed as callers may
@@ -5461,15 +5633,21 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
return resultDims;
}
-/// Get the expected packed type based on source type, tile factors, position of
-/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedTensorType(
+ RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> resultShape = inferPackedShape(
+ sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+ return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
- SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+ SmallVector<int64_t> resultShape = inferPackedShape(
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
- return RankedTensorType::get(resultShape, sourceType.getElementType());
+ return MemRefType::get(resultShape, sourceType.getElementType());
}
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
@@ -5518,6 +5696,45 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
getPaddingValue(), metadata.outerDimsPerm);
}
+template <typename OpTy>
+static void getPackUnPackEffectsImpl(
+ OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // No memory effects for pure tensor semantics
+ if (op.hasPureTensorSemantics())
+ return;
+
+ for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
+ if (!llvm::isa<MemRefType>(opOperand.get().getType()))
+ continue;
+
+ if (&opOperand == &op.getSourceMutable()) {
+ effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ } else if (&opOperand == &op.getDestMutable()) {
+ effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
+ }
+}
+
+void PackOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getPackUnPackEffectsImpl(*this, effects);
+}
+
+void UnPackOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getPackUnPackEffectsImpl(*this, effects);
+}
+
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
static bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -5537,6 +5754,8 @@ static bool areTilesAndTiledDimsAllConstant(OpTy op) {
}
Speculation::Speculatability PackOp::getSpeculatability() {
+ if (!hasPureTensorSemantics())
+ return Speculation::NotSpeculatable;
if (getPaddingValue())
return Speculation::Speculatable;
@@ -5627,6 +5846,10 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
}
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
// Fold an pack(unpack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
if (unPackOp.getSourceType() == packOp.getDestType() &&
@@ -5657,7 +5880,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- RankedTensorType originalResultType = packOp.getDestType();
+ ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
@@ -5672,9 +5895,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// Insert a cast if needed
if (needUpdateDestType) {
rewriter.setInsertionPointAfter(packOp);
- auto castOp =
- tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
- rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+ auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
+ packOp.getResult());
+ rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
}
return success();
}
@@ -5683,8 +5906,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
- RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
@@ -5717,19 +5939,25 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
bool PackOp::isLikePad() {
auto packedTensorType =
- llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
+ getDestType(); // Packed type; unlike the result, dest exists for memrefs too.
return isLikePadUnPad(*this, packedTensorType);
}
-OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+::mlir::LogicalResult
+PackOp::fold(FoldAdaptor adaptor,
+ ::llvm::SmallVectorImpl<OpFoldResult> &results) {
+ if (!hasPureTensorSemantics())
+ return failure();
std::optional<Attribute> paddingValue;
if (auto pad = adaptor.getPaddingValue())
paddingValue = pad;
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
- getDestType(), paddingValue))
- return reshapedSource;
- return {};
+ cast<TensorType>(getDestType()), paddingValue)) {
+ results.push_back(reshapedSource);
+ return success();
+ }
+ return failure();
}
/// Folds a tensor.cast op into a consuming PackOp op if the
@@ -5751,6 +5979,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
LogicalResult matchAndRewrite(PackOp op,
PatternRewriter &rewriter) const override {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!op.hasPureTensorSemantics())
+ return failure();
+
if (!tensor::hasFoldableTensorCastOperand(op))
return failure();
@@ -5793,12 +6025,141 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
void UnPackOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getResult(), "unpack");
+ if (!getResults().empty())
+ setNameFn(getResult(), "unpack");
+}
+
+// Custom parser for UnPackOp; a result type is only added for non-memref sources.
+ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::UnresolvedOperand source, dest;
+ SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
+ SmallVector<int64_t> staticTiles;
+ DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+ Type sourceType, destType, resultType;
+
+ if (parser.parseOperand(source))
+ return failure();
+
+ if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ SmallVector<int64_t> outerDimsPermVec;
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t value;
+ if (parser.parseInteger(value))
+ return failure();
+ outerDimsPermVec.push_back(value);
+ return success();
+ }))
+ return failure();
+ outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
+ }
+
+ if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
+ return failure();
+
+ SmallVector<int64_t> innerDimsPosVec;
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t value;
+ if (parser.parseInteger(value))
+ return failure();
+ innerDimsPosVec.push_back(value);
+ return success();
+ }))
+ return failure();
+ innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
+
+ if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
+ return failure();
+
+ DenseI64ArrayAttr staticTilesAttr;
+ if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
+ return failure();
+ for (auto val : staticTilesAttr.asArrayRef())
+ staticTiles.push_back(val);
+
+ if (parser.parseKeyword("into") || parser.parseOperand(dest))
+ return failure();
+
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ if (parser.parseColon() || parser.parseType(sourceType))
+ return failure();
+
+ bool hasArrow = succeeded(parser.parseOptionalArrow());
+ if (hasArrow) {
+ if (parser.parseType(destType))
+ return failure();
+ }
+
+ bool isMemRef = llvm::isa<MemRefType>(sourceType);
+ if (!hasArrow) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "pack/unpack requires '->' and destination type");
+ }
+
+ if (!isMemRef)
+ resultType = destType;
+
+ if (parser.resolveOperand(source, sourceType, result.operands) ||
+ parser.resolveOperand(dest, destType, result.operands))
+ return failure();
+
+ if (!dynamicTiles.empty() &&
+ parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
+ result.operands))
+ return failure();
+
+ result.addAttribute("static_inner_tiles",
+ parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
+ result.addAttribute("inner_dims_pos", innerDimsPos);
+ if (outerDimsPerm)
+ result.addAttribute("outer_dims_perm", outerDimsPerm);
+
+ // Unlike PackOp, UnPackOp has only three operand groups (source, dest,
+ // inner_tiles) with a single variadic one and no AttrSizedOperandSegments
+ // trait, so no `operandSegmentSizes` attribute is attached here; adding one
+ // would leave a stray attribute on every parsed op.
+
+ if (!isMemRef)
+ result.addTypes(resultType);
+
+ return success();
+}
+
+void UnPackOp::print(OpAsmPrinter &p) {
+ p << " " << getSource();
+
+ if (!getOuterDimsPerm().empty()) {
+ p << " outer_dims_perm = [";
+ llvm::interleaveComma(getOuterDimsPerm(), p);
+ p << "]";
+ }
+
+ p << " inner_dims_pos = [";
+ llvm::interleaveComma(getInnerDimsPos(), p);
+ p << "]";
+
+ p << " inner_tiles = ";
+ printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+
+ p << " into " << getDest();
+
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ {"static_inner_tiles", "inner_dims_pos",
+ "outer_dims_perm", "operandSegmentSizes"});
+
+ p << " : " << getSource().getType();
+ p << " -> " << getDest().getType();
}
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ if (!hasPureTensorSemantics())
+ return failure();
return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
@@ -5843,6 +6204,8 @@ LogicalResult UnPackOp::verify() {
}
Speculation::Speculatability UnPackOp::getSpeculatability() {
+ if (!hasPureTensorSemantics())
+ return Speculation::NotSpeculatable;
// See PackOp::getSpeculatability.
if (!areTilesAndTiledDimsAllConstant(*this))
return Speculation::NotSpeculatable;
@@ -5949,6 +6312,10 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
/// unpack(pack(x)) -> x
if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
if (packOp.getSourceType() != unPackOp.getDestType())
@@ -6005,11 +6372,11 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
dest = tensor::CastOp::create(rewriter, loc, newDestType,
unPackOp.getDest());
}
- Value newOp = UnPackOp::create(
+ UnPackOp newOp = UnPackOp::create(
rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
- unPackOp, unPackOp.getResult().getType(), newOp);
+ unPackOp, unPackOp.getResult().getType(), newOp.getResult());
return success();
}
@@ -6043,16 +6410,24 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
}
bool UnPackOp::isLikeUnPad() {
- RankedTensorType packedTensorType = getSourceType();
+ ShapedType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
-OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+::mlir::LogicalResult
+UnPackOp::fold(FoldAdaptor adaptor,
+ ::llvm::SmallVectorImpl<OpFoldResult> &results) {
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!hasPureTensorSemantics())
+ return failure();
+
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
- getResult().getType()))
- return reshapedSource;
- return {};
+ cast<TensorType>(getResult().getType()))) {
+ results.push_back(reshapedSource);
+ return success();
+ }
+ return failure();
}
/// Folds a tensor.cast op into a consuming UnPackOp op if the
@@ -6074,6 +6449,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp op,
PatternRewriter &rewriter) const override {
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!op.hasPureTensorSemantics())
+ return failure();
+
if (!tensor::hasFoldableTensorCastOperand(op))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 6912da3ffbc83..6ea1eb50b13ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -90,6 +90,10 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
linalg::PackOp packOp, AffineMap operandMap,
ArrayRef<unsigned> blocksStartDimPos,
bool transposeOuterBlocks, bool transposeInnerBlocks) {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
assert(operandMap.getNumDims() >= 4 &&
"expected at least 4D prepacked matmul");
assert(blocksStartDimPos.size() >= 2 &&
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 419f6a0d3c010..d36ca43a6cbb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -282,8 +282,8 @@ static bool getPackedOperandDetails(
});
bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
inputType.getShape(), innerDimsPos,
- linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
- innerDimsPos, outerDimsPerm)
+ linalg::PackOp::inferPackedTensorType(inputType, maybeIntInnerTileSizes,
+ innerDimsPos, outerDimsPerm)
.getShape(),
outerDimsPerm, innerTileSizes);
currOperandDetails.innerDimsPos = innerDimsPos;
@@ -341,10 +341,11 @@ static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
auto poison = ub::PoisonOp::create(
b, loc, getElementTypeOrSelf(opOperand->get().getType()));
- Value packedOperand =
+ PackOp packedOperand =
linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
innerTileSizes, poison, outerDimsPerm);
- return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
+ return std::make_tuple(packedOperand.getResult(),
+ currOperandDetails.indexingMap);
}
/// This function is a helper subroutine to pack a genericOp and return it. It
@@ -571,6 +572,9 @@ struct BubbleUpPackOpThroughGenericOpPattern
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
poisonPaddingOk);
if (failed(genericOp))
@@ -594,6 +598,9 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();
@@ -653,19 +660,19 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
- auto newPadOp =
- tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
- lowPad, highPad, paddingVal, padOp.getNofold());
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, /*result=*/Type(), sourcePack.getResult(), lowPad,
+ highPad, paddingVal, padOp.getNofold());
// If the pad has more than one user, create an unpack on the new pad to
// replace the other uses.
if (!padOp->hasOneUse()) {
auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
- Value unpackedPad =
+ UnPackOp unpackedPad =
linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
innerDimsPos, mixedTiles, outerDimsPerm);
- rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+ rewriter.replaceAllUsesExcept(padOp, unpackedPad.getResult(), sourcePack);
}
// Replace the pack with the new pad.
@@ -763,6 +770,9 @@ static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
linalg::PackOp packOp,
PatternRewriter &rewriter) {
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -812,8 +822,8 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
}
auto newCollapseOp = tensor::CollapseShapeOp::create(
- rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
- newReassocIndices);
+ rewriter, collapseOp.getLoc(), packOp.getResult().getType(),
+ newPackOp.getResult(), newReassocIndices);
rewriter.replaceOp(packOp, newCollapseOp);
return success();
@@ -868,6 +878,9 @@ static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
linalg::PackOp packOp,
PatternRewriter &rewriter) {
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
// Outer dimensions permutation is not supported currently.
// TODO: Handle outer_dims_perm variants.
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -918,7 +931,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
// If reassociation is not possible, then reordering cannot happen.
// This can be caused by pack padding affecting previously expanded
// dimensions or packing extending dimensions.
- RankedTensorType newPackType = linalg::PackOp::inferPackedType(
+ RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
expandOp.getSrcType(), packOp.getStaticInnerTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
auto reassocExpand =
@@ -930,14 +943,14 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
Value destTensor = linalg::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
- Value packedVal = linalg::PackOp::create(
+ PackOp packedVal = linalg::PackOp::create(
rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
/*outerDimsPerm=*/SmallVector<int64_t>{});
- Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
- packOp.getDestType(),
- packedVal, *reassocExpand);
+ Value newExpandOp = tensor::ExpandShapeOp::create(
+ rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(),
+ *reassocExpand);
rewriter.replaceOp(packOp, newExpandOp);
return success();
@@ -951,6 +964,9 @@ class BubbleUpPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
Operation *srcOp = packOp.getSource().getDefiningOp();
// Currently only support when the pack op is the only user.
if (!srcOp || !(srcOp->getNumResults() == 1) ||
@@ -1001,6 +1017,9 @@ class BubbleUpPackOpThroughReshapeOp final
static LogicalResult pushDownUnPackOpThroughExpandShape(
linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
// User controlled propagation function.
if (!controlFn(&expandOp.getSrcMutable()))
return failure();
@@ -1048,7 +1067,7 @@ static LogicalResult pushDownUnPackOpThroughExpandShape(
nextPos += 1;
}
- RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
+ RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
auto newExpandOp =
tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
@@ -1075,6 +1094,9 @@ class PushDownUnPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
Value result = unPackOp.getResult();
// Currently only support unpack op with the single user.
if (!result.hasOneUse()) {
@@ -1274,6 +1296,9 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
if (!unpackOp)
return failure();
+ if (!unpackOp.hasPureTensorSemantics())
+ return failure();
+
if (!controlFn(&padOp.getSourceMutable()))
return failure();
@@ -1313,7 +1338,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
padOp.getResultType().getElementType());
- Value replacement = linalg::UnPackOp::create(
+ UnPackOp replacement = linalg::UnPackOp::create(
rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
unpackOp.getMixedTiles(), outerDimsPerm);
rewriter.replaceOp(padOp, replacement);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 1d4c11e418006..993eae62535c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -110,8 +111,11 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
PatternRewriter &rewriter) const override {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
- RankedTensorType sourceType = packOp.getSourceType();
+ ShapedType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
packOp.getStaticTiles())) &&
@@ -119,7 +123,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
return failure();
}
- RankedTensorType destType = packOp.getDestType();
+ ShapedType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
@@ -157,8 +161,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
"expects outer_dims_perm is empty or an identity permutation");
}
- RankedTensorType sourceType = unpackOp.getSourceType();
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType sourceType = unpackOp.getSourceType();
+ ShapedType destType = unpackOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
@@ -173,7 +177,11 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
- RankedTensorType destType = unpackOp.getDestType();
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics())
+ return failure();
+
+ ShapedType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
unpackOp.getStaticTiles())) &&
@@ -181,7 +189,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
return failure();
}
- RankedTensorType sourceType = unpackOp.getSourceType();
+ ShapedType sourceType = unpackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
@@ -225,7 +233,7 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
// sizes - that is because it would be impossible to compute the padding
// size and hence to establish whether "artificial" padding would be
// created.
- RankedTensorType unpackedType = packOp.getSourceType();
+ ShapedType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
for (auto [pos, tileSize, high] :
@@ -274,6 +282,10 @@ struct FoldUnpackWithExtractSliceOp
if (!unpackOp)
return failure();
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics())
+ return failure();
+
// User controlled folding function.
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();
@@ -336,6 +348,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!packOp)
return failure();
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
// User controlled folding function.
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
return failure();
@@ -395,6 +411,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();
@@ -456,6 +476,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
if (!unPackOp)
return failure();
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
// User controlled folding function.
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
return failure();
@@ -504,6 +528,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();
@@ -568,6 +596,10 @@ struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
// Check for tensor.empty source.
auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
if (!emptyOp)
@@ -592,6 +624,10 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
// Check for tensor.empty source.
auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
if (!emptyOp)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 50a84ace09258..5d39c4731dd1b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -766,6 +766,10 @@ struct PackOpTiling
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
auto packOp = cast<PackOp>(op);
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
Location loc = packOp.getLoc();
// The tiling is applied on interchanged dimensions. We have to undo the
@@ -1010,6 +1014,10 @@ struct PackOpTiling
ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
Location loc = packOp.getLoc();
int64_t inputRank = packOp.getSourceRank();
@@ -1189,6 +1197,10 @@ struct UnPackOpTiling
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
auto unpackOp = cast<UnPackOp>(op);
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics())
+ return failure();
+
int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
int64_t numInnerTiles = srcRank - destRank;
@@ -1359,6 +1371,10 @@ struct UnPackOpTiling
return failure();
}
auto unPackOp = cast<UnPackOp>(op);
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index fc7cdad0ee33d..48ebd1644bbef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -26,6 +26,8 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
@@ -218,6 +220,10 @@ struct PackedOperandsDimList {
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
linalg::PackOp packOp,
bool lowerPadLikeWithInsertSlice) {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
// 1. Filter out NYI cases.
auto packedTensorType =
cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -345,11 +351,15 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
bool lowerUnpadLikeWithExtractSlice) {
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+
Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- RankedTensorType packedTensorType = unPackOp.getSourceType();
+ auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -549,7 +559,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
packOps.push_back(linalg::PackOp::create(
rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
}
- inputsAndInits.push_back(packOps.back());
+ inputsAndInits.push_back(packOps.back().getResult());
}
}
@@ -576,7 +586,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
unPackOps.push_back(linalg::UnPackOp::create(
rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
- results.push_back(unPackOps.back());
+ results.push_back(unPackOps.back().getResult());
}
// Step 5. Replace `linalgOp`.
@@ -666,7 +676,7 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
linalg::PackOp transposedPackOp =
packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
- if (!packOp.getResult().hasOneUse())
+ if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
OpOperand &packUse = *packOp->getUses().begin();
@@ -727,7 +737,10 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
}
// Step 4. Finally, replace packOp now that we don't need it anymore.
- rewriter.replaceOp(packOp, transposedPackOp->getResults());
+ if (packOp.hasPureTensorSemantics())
+ rewriter.replaceOp(packOp, transposedPackOp->getResults());
+ else
+ rewriter.eraseOp(packOp);
return PackTransposeResult{transposedPackOp, transposedLinalgOp,
transposedUnPackOp};
@@ -1019,6 +1032,10 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
linalg::PackOp packOp) {
Value input = packOp.getSource();
+ // TODO: Support Memref PackOp. Temporarily return the op's source unchanged.
+ if (!packOp.hasPureTensorSemantics())
+ return input;
+
if (!packOp.getPaddingValue()) {
return input;
}
@@ -1135,6 +1152,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
@@ -1156,8 +1177,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// Check whether this dim has been permuted. Permuting unit dims is fine
// as that's effectively a no-op.
- if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
- packOp.getType().getShape()[dim] != 1))
+ if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
+ packOp.getResult().getType().getShape()[dim] != 1))
return false;
prev = dim;
@@ -1274,6 +1295,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
+ if (!unpackOp.hasPureTensorSemantics())
+ return failure();
+
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c7d5dff74c5a9..6f73f4f57e50d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1877,8 +1877,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes);
auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation);
- auto preTransposeWriteVecType = VectorType::get(
- preTransposeWriteVecSizses, packOp.getType().getElementType());
+ auto preTransposeWriteVecType =
+ VectorType::get(preTransposeWriteVecSizses,
+ packOp.getResult().getType().getElementType());
// Compute vector type for the _read_ opeartion. This is simply
// pre-transpose-write-vector-type with the dimensions collapsed
@@ -1954,7 +1955,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ ShapedType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
bool useInBoundsInsteadOfMasking = false;
@@ -2117,6 +2118,10 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
+ // TODO: Support Memref UnPackOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics())
+ return failure();
+
// If there are no input vector sizes and all shapes are static, there is
// nothing left to check.
if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
@@ -2454,6 +2459,10 @@ static LogicalResult vectorizeLinalgOpPrecondition(
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
+ // TODO: Support Memref PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics())
+ return failure();
+
auto padValue = packOp.getPaddingValue();
Attribute cstAttr;
// TODO: Relax this condiiton
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 0a97bc03f584b..1f6d1a68fbbb8 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -308,9 +308,12 @@ def pack(
_inner_tiles,
static_inner_tiles,
) = _dispatch_mixed_values(inner_tiles)
+ dest = _get_op_result_or_value(dest)
+ result_type = dest.type if isinstance(dest.type, RankedTensorType) else None
return _get_op_result_or_op_results(
PackOp(
+ result=result_type,
source=source,
dest=dest,
inner_dims_pos=inner_dims_pos,
@@ -340,9 +343,11 @@ def unpack(
_inner_tiles,
static_inner_tiles,
) = _dispatch_mixed_values(inner_tiles)
-
+ dest = _get_op_result_or_value(dest)
+ result_type = dest.type if isinstance(dest.type, RankedTensorType) else None
return _get_op_result_or_op_results(
UnPackOp(
+ result=result_type,
source=source,
dest=dest,
inner_dims_pos=inner_dims_pos,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f4020ede4854e..08eb40d4cb442 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2058,3 +2058,59 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
// CHECK-SAME: into %[[DEST]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
// CHECK: return %[[SLICE]]
+
+// -----
+
+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK: %[[RES:.*]] = linalg.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+ %src: tensor<1x1x8x1xi32>,
+ %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+ %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+ %c8 = arith.constant 8 : index
+ %unpack = linalg.unpack %cast
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%c8, 1]
+ into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+ return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_pack_unpack_tensor
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK: return %[[ARG0]] : tensor<2x3xf32>
+func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
+ into %x : tensor<2x3xf32> -> tensor<2x3xf32>
+ %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
+ into %x : tensor<2x3xf32> -> tensor<2x3xf32>
+ return %packed : tensor<2x3xf32>
+}
+
+// -----
+
+// Test that pack/unpack canonicalization is disabled for memref versions.
+// CHECK-LABEL: func.func @pack_unpack_memref_no_canonicalization
+// CHECK: linalg.pack
+// CHECK: linalg.unpack
+func.func @pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, %packed: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
+ linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %packed : memref<128x256xf32> -> memref<16x8x8x32xf32>
+ linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+ return
+}
+
+// -----
+
+// Test that unpack/pack canonicalization is disabled for memref versions.
+// CHECK-LABEL: func.func @unpack_pack_memref_no_canonicalization
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>, %unpacked: memref<128x256xf32>, %dest: memref<16x8x8x32xf32>) {
+ linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %unpacked : memref<16x8x8x32xf32> -> memref<128x256xf32>
+ linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 74928920c695a..bfb92c3289a49 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -755,3 +755,77 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// CHECK-LABEL: func @conv2d_channel_first_q_promote(
// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) {
+ linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
+ return
+}
+
+// CHECK-LABEL: func @pack_memref(
+// CHECK: %[[source:[a-zA-z0-9]*]]: memref<128x256xf32>, %[[dest:[a-zA-z0-9]*]]: memref<8x16x8x32xf32>) {
+// CHECK: linalg.pack %[[source]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<128x256xf32> -> memref<8x16x8x32xf32>
+// CHECK: return
+// CHECK: }
+// -----
+
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
+ linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+ return
+}
+
+// CHECK-LABEL: func @unpack_memref(
+// CHECK: %[[source:[a-zA-z0-9]*]]: memref<16x8x8x32xf32>, %[[dest:[a-zA-z0-9]*]]: memref<128x256xf32>) {
+// CHECK: linalg.unpack %[[source]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<16x8x8x32xf32> -> memref<128x256xf32>
+// CHECK: return
+
+// -----
+
+// CHECK-LABEL: func @test_pack_memref
+func.func @test_pack_memref(%arg0: memref<128x256xf32>, %arg1: memref<16x8x8x32xf32>) {
+ // CHECK-NOT: %{{.*}} = linalg.pack
+ // CHECK: linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<128x256xf32> -> memref<16x8x8x32xf32>
+ linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<128x256xf32> -> memref<16x8x8x32xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_unpack_memref
+func.func @test_unpack_memref(%arg0: memref<16x8x8x32xf32>, %arg1: memref<128x256xf32>) {
+ // CHECK: linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<16x8x8x32xf32>
+ linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<16x8x8x32xf32> -> memref<128x256xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_memref_with_padding
+func.func @test_pack_memref_with_padding(%arg0: memref<127x255xf32>, %arg1: memref<16x8x8x32xf32>, %pad: f32) {
+ // CHECK: linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<127x255xf32>
+ linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<127x255xf32> -> memref<16x8x8x32xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_tensor
+func.func @test_pack_tensor(%arg0: tensor<128x256xf32>, %arg1: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+ // CHECK: %[[RESULT:.*]] = linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+ %0 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+ // CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
+ return %0 : tensor<16x8x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @test_unpack_tensor
+func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
+ // CHECK: %[[RESULT:.*]] = linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+ %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+ // CHECK: return %[[RESULT]] : tensor<128x256xf32>
+ return %0 : tensor<128x256xf32>
+}
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 92591cd59fb40..68b32098b7782 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -664,6 +664,20 @@ def tensor_pack(src, dst):
return unpacked
+ @func.FuncOp.from_py_func(
+ MemRefType.get((128, 128), f32),
+ MemRefType.get((16, 16, 8, 8), f32),
+ )
+ def memref_pack(src, dst):
+ linalg.pack(src, dst, inner_dims_pos=[1, 0], inner_tiles=[8, 8])
+
+ linalg.unpack(
+ dst,
+ src,
+ inner_dims_pos=[0, 1],
+ inner_tiles=[8, 8],
+ )
+
# CHECK-LABEL: func.func @tensor_pack(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
@@ -671,6 +685,12 @@ def tensor_pack(src, dst):
# CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
+ # CHECK-LABEL: func.func @memref_pack(
+ # CHECK-SAME: %[[VAL_0:.*]]: memref<128x128xf32>, %[[VAL_1:.*]]: memref<16x16x8x8xf32>) {
+ # CHECK: linalg.pack %[[VAL_0]] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : memref<128x128xf32> -> memref<16x16x8x8xf32>
+ # CHECK: linalg.unpack %[[VAL_1]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : memref<16x16x8x8xf32> -> memref<128x128xf32>
+ # CHECK: return
+ # CHECK: }
print(module)
More information about the Mlir-commits
mailing list