[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)
Hyunsung Lee
llvmlistbot at llvm.org
Sun Mar 30 04:38:55 PDT 2025
https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/129036
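For context, this series teaches linalg.pack and linalg.unpack to accept memref operands in addition to ranked tensors. A minimal sketch of the new memref form, adapted from the round-trip tests added later in the series (the shapes and tile sizes are simply the ones used there):

  func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
    // Destination-style form on buffers: no SSA result is produced, the packed
    // data is written into %dest in place.
    %dest = memref.alloc() : memref<8x16x8x32xf32>
    linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
      into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
    return %dest : memref<8x16x8x32xf32>
  }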
>From 4d523adc3cf5eb581c43395e66aaa0012dbc179b Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 27 Feb 2025 19:54:30 +0900
Subject: [PATCH 01/23] draft
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 72 +++++++++++++++++--
.../Dialect/Linalg/IR/RelayoutOpInterface.td | 1 +
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 4 +-
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 2 +-
4 files changed, 69 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..f8a4657c564ce 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,7 +77,20 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// with `inner_dims_pos` rather than the packed tensor.
SmallVector<int64_t> getTiledOuterDims();
}];
-
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ ShapedType getInputType() {
+ return cast<ShapedType>(getInput().getType());
+ }
+ ShapedType getOutputType() {
+ return cast<ShapedType>(getOutput().getType());
+ }
+ int64_t getInputRank() {
+ return getInputType().getRank();
+ }
+ int64_t getOutputRank() {
+ return getOutputType().getRank();
+ }
+ }];
let hasVerifier = 1;
}
@@ -152,14 +165,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -179,6 +192,28 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
+ Value getOutput() {
+ return getDpsInitOperand(0)->get();
+ }
+
+ // Return the input operand.
+ Value getInput() {
+ return getDpsInputOperand(0)->get();
+ }
+ ShapedType getInputType() {
+ return cast<ShapedType>(getInput().getType());
+ }
+ ShapedType getOutputType() {
+ return cast<ShapedType>(getDest().getType());
+ }
+ int64_t getInputRank() {
+ return getInputType().getRank();
+ }
+ int64_t getOutputRank() {
+ return getOutputType().getRank();
+ }
+
+ LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
// This is a static method to allow getting the shape of the destination
// expected while creating a `pack` op.
@@ -229,6 +264,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
/// 2. pads the other ones, and
/// 3. doesn't shuffle the dimensions
bool isLikePad();
+
}];
let hasCanonicalizeMethod = 1;
@@ -279,13 +315,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
: tensor<8x16x8x32xf32> -> tensor<128x256xf32>
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
@@ -303,6 +339,28 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
+ Value getOutput() {
+ return getDpsInitOperand(0)->get();
+ }
+
+ // Return the input operand.
+ Value getInput() {
+ return getDpsInputOperand(0)->get();
+ }
+ ShapedType getInputType() {
+ return cast<ShapedType>(getInput().getType());
+ }
+ ShapedType getOutputType() {
+ return cast<ShapedType>(getDest().getType()); // use getDest()
+ }
+ int64_t getInputRank() {
+ return getInputType().getRank();
+ }
+ int64_t getOutputRank() {
+ return getOutputType().getRank();
+ }
+ LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
+
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
#define LINALG_IR_RELAYOUTOPINTERFACE
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/OpBase.td"
def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
/// Removes the op and replaces the constant with a new constant of the result
/// shape. When an optional cst attribute is passed, it is reshaped only if the
/// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
std::optional<Attribute> cst = std::nullopt);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..4267732571801 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- TensorType result,
+ ShapedType result,
std::optional<Attribute> cst) {
if (source && source.isSplat() && result.hasStaticShape() &&
(!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
>From 4f2dbf4848092942a7932387e39d3c1220d78923 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 08:00:32 +0900
Subject: [PATCH 02/23] draft
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 44 -------------------
1 file changed, 44 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f8a4657c564ce..6e2c6171132f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -192,28 +192,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
- Value getOutput() {
- return getDpsInitOperand(0)->get();
- }
-
- // Return the input operand.
- Value getInput() {
- return getDpsInputOperand(0)->get();
- }
- ShapedType getInputType() {
- return cast<ShapedType>(getInput().getType());
- }
- ShapedType getOutputType() {
- return cast<ShapedType>(getDest().getType());
- }
- int64_t getInputRank() {
- return getInputType().getRank();
- }
- int64_t getOutputRank() {
- return getOutputType().getRank();
- }
-
- LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
// This is a static method to allow getting the shape of the destination
// expected while creating a `pack` op.
@@ -339,28 +317,6 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
- Value getOutput() {
- return getDpsInitOperand(0)->get();
- }
-
- // Return the input operand.
- Value getInput() {
- return getDpsInputOperand(0)->get();
- }
- ShapedType getInputType() {
- return cast<ShapedType>(getInput().getType());
- }
- ShapedType getOutputType() {
- return cast<ShapedType>(getDest().getType()); // use getDest()
- }
- int64_t getInputRank() {
- return getInputType().getRank();
- }
- int64_t getOutputRank() {
- return getOutputType().getRank();
- }
- LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
-
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
>From 226230c9445084671531d755d5c3f5612bed7d67 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 08:01:05 +0900
Subject: [PATCH 03/23] draft
---
.../mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td | 15 +--------------
1 file changed, 1 insertion(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6e2c6171132f5..c68c395fc6337 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,20 +77,7 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// with `inner_dims_pos` rather than the packed tensor.
SmallVector<int64_t> getTiledOuterDims();
}];
- let extraClassDeclaration = commonExtraClassDeclaration # [{
- ShapedType getInputType() {
- return cast<ShapedType>(getInput().getType());
- }
- ShapedType getOutputType() {
- return cast<ShapedType>(getOutput().getType());
- }
- int64_t getInputRank() {
- return getInputType().getRank();
- }
- int64_t getOutputRank() {
- return getOutputType().getRank();
- }
- }];
+
let hasVerifier = 1;
}
>From 0c184dfc85cdb0d89d62aa8cafc4f752e1acc654 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 09:44:08 +0900
Subject: [PATCH 04/23] init
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 10 +++---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 14 ++++-----
.../Transforms/PackAndUnpackPatterns.cpp | 12 +++----
.../Dialect/Linalg/Transforms/Transforms.cpp | 31 +++++++++++++++----
.../Linalg/Transforms/Vectorization.cpp | 2 +-
mlir/lib/Tools/mlir-opt/launch.json | 13 ++++++++
6 files changed, 57 insertions(+), 25 deletions(-)
create mode 100644 mlir/lib/Tools/mlir-opt/launch.json
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c68c395fc6337..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
- RankedTensorType getSourceType() {
- return ::llvm::cast<RankedTensorType>(getSource().getType()); };
- RankedTensorType getDestType() {
- return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+ ShapedType getSourceType() {
+ return ::llvm::cast<ShapedType>(getSource().getType()); };
+ ShapedType getDestType() {
+ return ::llvm::cast<ShapedType>(getDest().getType()); };
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
- static RankedTensorType inferPackedType(RankedTensorType sourceType,
+ static RankedTensorType inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f4f08d9d4acf7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
}
- RankedTensorType srcPadType = srcPadOp.getSourceType();
+ ShapedType srcPadType = srcPadOp.getSourceType();
SmallVector<OpFoldResult, 4> newSizes;
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
if (srcPadType.isDynamicDim(i)) {
@@ -4433,7 +4433,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("invalid zero tile factor");
// Verify inner_dims_pos and outer_dims_perm.
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- RankedTensorType originalResultType = packOp.getDestType();
+ ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
- packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ packOp.getResult().setType(cast<ShapedType>(dest.getType()));
});
// Insert a cast if needed
if (needUpdateDestType) {
@@ -4970,7 +4970,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
- RankedTensorType packedTensorType) {
+ ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
@@ -5274,7 +5274,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
}
bool UnPackOp::isLikeUnPad() {
- RankedTensorType packedTensorType = getSourceType();
+ ShapedType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
- RankedTensorType sourceType = packOp.getSourceType();
+ ShapedType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
return failure();
}
- RankedTensorType destType = packOp.getDestType();
+ ShapedType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
"expects outer_dims_perm is empty or an identity permutation");
}
- RankedTensorType sourceType = unpackOp.getSourceType();
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType sourceType = unpackOp.getSourceType();
+ ShapedType destType = unpackOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
return failure();
}
- RankedTensorType sourceType = unpackOp.getSourceType();
+ ShapedType sourceType = unpackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..7ed211841c53f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -359,7 +359,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- RankedTensorType packedTensorType = unPackOp.getSourceType();
+ ShapedType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +396,29 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose packedShape to stripMinedShape.
- RankedTensorType stripMinedTensorType =
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedTensorType, packingMetadata.reassociations);
+ ShapedType stripMinedType;
+ if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+ stripMinedType =
+ RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+ } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+ stripMinedType =
+ MemRefType::get(stripMinedShape, memrefType.getElementType());
+ }
+ ShapedType collapsedType;
+ if (stripMinedType.isa<TensorType>()) {
+ collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMinedType.cast<RankedTensorType>(),
+ packingMetadata.reassociations);
+ } else if (stripMinedType.isa<MemRefType>()) {
+ auto memrefTy = stripMinedType.cast<MemRefType>();
+ auto tensorTy =
+ RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
+ auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
+ tensorTy, packingMetadata.reassociations);
+ // Rebuild the collapsed tensor type as a memref (keep the same memory space)
+ collapsedType = MemRefType::get(collapsedTensorType.getShape(),
+ collapsedTensorType.getElementType());
+ }
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
@@ -407,7 +426,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, dims, stripMinedTensorType.getElementType());
+ loc, dims, stripMinedType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ ShapedType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Tools/mlir-opt/launch.json b/mlir/lib/Tools/mlir-opt/launch.json
new file mode 100644
index 0000000000000..5a686d02e2dfb
--- /dev/null
+++ b/mlir/lib/Tools/mlir-opt/launch.json
@@ -0,0 +1,13 @@
+{
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "ma",
+ "type": "lldb",
+ "request": "launch",
+ "program": "/Users/ita/src/iree-build/tools/iree-opt --show-dialects",
+ "args": [],
+ "cwd": "${workspaceFolder}"
+ }
+ ]
+}
>From 19201c69e23578a69583bb98415f9c9583cb5c41 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 14:50:46 +0900
Subject: [PATCH 05/23] lint
---
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 1 -
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 12 ++++++------
mlir/lib/Tools/mlir-opt/launch.json | 13 -------------
3 files changed, 6 insertions(+), 20 deletions(-)
delete mode 100644 mlir/lib/Tools/mlir-opt/launch.json
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7ed211841c53f..36e01ef46b30b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -415,7 +415,6 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
tensorTy, packingMetadata.reassociations);
- // Rebuild the collapsed tensor type as a memref (keep the same memory space)
collapsedType = MemRefType::get(collapsedTensorType.getShape(),
collapsedTensorType.getElementType());
}
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 4267732571801..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- ShapedType result,
+ ShapedType result,
std::optional<Attribute> cst) {
if (source && source.isSplat() && result.hasStaticShape() &&
(!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
diff --git a/mlir/lib/Tools/mlir-opt/launch.json b/mlir/lib/Tools/mlir-opt/launch.json
deleted file mode 100644
index 5a686d02e2dfb..0000000000000
--- a/mlir/lib/Tools/mlir-opt/launch.json
+++ /dev/null
@@ -1,13 +0,0 @@
-{
- "version": "0.2.0",
- "configurations": [
- {
- "name": "ma",
- "type": "lldb",
- "request": "launch",
- "program": "/Users/ita/src/iree-build/tools/iree-opt --show-dialects",
- "args": [],
- "cwd": "${workspaceFolder}"
- }
- ]
-}
>From b99b92030f2f664607f43554d2b7bc722c98c2c1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 15:26:13 +0900
Subject: [PATCH 06/23] lint
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f4f08d9d4acf7..eca8cea3e6323 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4434,8 +4434,8 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Verify inner_dims_pos and outer_dims_perm.
ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
- ShapedType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
>From be6a1193579633d7b678a30a9a80e5dee89a51e1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 16:19:20 +0900
Subject: [PATCH 07/23] add
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eca8cea3e6323..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5001,9 +5001,12 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
}
bool PackOp::isLikePad() {
- auto packedTensorType =
- llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
- return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
>From eee8805c351e7b8100d3e73d1e67c1c06e065962 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 1 Mar 2025 09:04:26 +0900
Subject: [PATCH 08/23] remove tensor casting
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 5 +++
.../Dialect/Linalg/Transforms/Transforms.cpp | 10 ++----
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 32 ++++++++++++++++++-
3 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4c8a214049ea9..8bcc1882b454d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
static MemRefType computeCollapsedType(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type,
+ SmallVector<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 36e01ef46b30b..efa0453dda036 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -410,13 +411,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
stripMinedType.cast<RankedTensorType>(),
packingMetadata.reassociations);
} else if (stripMinedType.isa<MemRefType>()) {
- auto memrefTy = stripMinedType.cast<MemRefType>();
- auto tensorTy =
- RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
- auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
- tensorTy, packingMetadata.reassociations);
- collapsedType = MemRefType::get(collapsedTensorType.getShape(),
- collapsedTensorType.getElementType());
+ collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+ stripMinedType.cast<MemRefType>(), packingMetadata.reassociations);
}
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto shape = type.getShape();
+ SmallVector<int64_t, 4> newShape;
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ auto band = shape.slice(currentDim, dim);
+ int64_t size = 1;
+ if (llvm::is_contained(band, ShapedType::kDynamic))
+ size = ShapedType::kDynamic;
+ else
+ for (unsigned d = 0; d < dim; ++d)
+ size *= shape[currentDim + d];
+ newShape.push_back(size);
+ currentDim += dim;
+ }
+ return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+ MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+ return inferCollapsedType(
+ type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ type.getContext(), reassociation)));
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
>From c5b3c3955321ef0e9211226c8fea017bd4b591bf Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 1 Mar 2025 09:39:30 +0900
Subject: [PATCH 09/23] add test
---
.../lib/Dialect/Linalg/Transforms/Transforms.cpp | 5 ++---
mlir/test/Dialect/Linalg/loops.mlir | 16 ++++++++++++++++
2 files changed, 18 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index efa0453dda036..98dab332b2f40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -408,11 +408,10 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
ShapedType collapsedType;
if (stripMinedType.isa<TensorType>()) {
collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedType.cast<RankedTensorType>(),
- packingMetadata.reassociations);
+ cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
} else if (stripMinedType.isa<MemRefType>()) {
collapsedType = memref::CollapseShapeOp::inferCollapsedType(
- stripMinedType.cast<MemRefType>(), packingMetadata.reassociations);
+ cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
}
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index efe8010cffc91..767f593329f52 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -942,3 +942,19 @@ func.func @transpose(%input: memref<?xf32>,
// CHECKPARALLEL: }
// CHECKPARALLEL: return
// CHECKPARALLEL: }
+
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
+ %dest = memref.alloc() : memref<8x16x8x32xf32>
+ %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
+ return %packed : memref<8x16x8x32xf32>
+}
+
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
+ %dest = memref.alloc() : memref<128x256xf32>
+ %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+ return %unpacked : memref<128x256xf32>
+}
\ No newline at end of file
>From a5d01dffda768947463451af6cab1cf6e282114e Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 16 Mar 2025 21:21:41 +0900
Subject: [PATCH 10/23] fix upon review
---
.../Dialect/Linalg/IR/RelayoutOpInterface.td | 1 -
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 7 +--
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 14 +++--
.../Transforms/PackAndUnpackPatterns.cpp | 24 +++++---
.../Dialect/Linalg/Transforms/Transforms.cpp | 2 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 56 +++++++++----------
mlir/test/Dialect/Linalg/loops.mlir | 16 ------
mlir/test/Dialect/Linalg/roundtrip.mlir | 18 ++++++
8 files changed, 71 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 467d862d277eb..2dec2fc4396f4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,7 +10,6 @@
#define LINALG_IR_RELAYOUTOPINTERFACE
include "mlir/Interfaces/DestinationStyleOpInterface.td"
-include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/OpBase.td"
def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 87564066d309d..93449766aca4e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1782,11 +1782,6 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
static MemRefType computeCollapsedType(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
- static MemRefType
- inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
- static MemRefType
- inferCollapsedType(MemRefType type,
- SmallVector<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
@@ -1806,7 +1801,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let summary = "store operation";
let description = [{
The `store` op stores an element into a memref at the specified indices.
-
+
The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a19039fbca67d..b4cbc7c6ad8e9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5001,12 +5001,8 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
}
bool PackOp::isLikePad() {
- if (auto packedTensorType =
- llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
- return isLikePadUnPad(*this, packedTensorType);
- if (auto packedTensorType =
- llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
- return isLikePadUnPad(*this, packedTensorType);
+ auto packedTensorType = llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
+ return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
@@ -5042,6 +5038,9 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
if (!tensor::hasFoldableTensorCastOperand(op))
return failure();
+ if (!op.hasPureTensorSemantics())
+ return failure();
+
SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands =
tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
@@ -5310,6 +5309,9 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
if (!tensor::hasFoldableTensorCastOperand(op))
return failure();
+ if (!op.hasPureTensorSemantics())
+ return failure();
+
SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands =
tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 599aa3b6668df..59e4b2ff634c2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -171,25 +171,27 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
return success();
}
- LogicalResult matchAndRewrite(UnPackOp unpackOp,
+ LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
- ShapedType destType = unpackOp.getDestType();
- if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
- failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
- unpackOp.getStaticTiles())) &&
- !unpackOp.isLikeUnPad()) {
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
+ ShapedType destType = unPackOp.getDestType();
+ if (failed(isUnpackOnInnerMostDim(rewriter, unPackOp)) &&
+ failed(isPackOn1D(rewriter, unPackOp, destType.getShape(),
+ unPackOp.getStaticTiles())) &&
+ !unPackOp.isLikeUnPad()) {
return failure();
}
- ShapedType sourceType = unpackOp.getSourceType();
+ ShapedType sourceType = unPackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
return failure();
Value collapsed = insertCollapse(
- rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+ rewriter, unPackOp.getLoc(), unPackOp.getSource(), destType,
getReassociationIndicesAttribute(rewriter, *reassociation));
- rewriter.replaceOp(unpackOp, collapsed);
+ rewriter.replaceOp(unPackOp, collapsed);
return success();
}
};
@@ -426,6 +428,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();
@@ -507,6 +511,8 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
+ if (!unPackOp.hasPureTensorSemantics())
+ return failure();
// Check for tensor.empty source.
auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
if (!emptyOp)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 98dab332b2f40..105831a3d9259 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -410,7 +410,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
} else if (stripMinedType.isa<MemRefType>()) {
- collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+ collapsedType = memref::CollapseShapeOp::computeCollapsedType(
cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ba12cc34d6457..03c08756d110b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2526,34 +2526,34 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
-MemRefType
-CollapseShapeOp::inferCollapsedType(MemRefType type,
- ArrayRef<AffineMap> reassociation) {
- auto shape = type.getShape();
- SmallVector<int64_t, 4> newShape;
- assert(isReassociationValid(reassociation) && "invalid reassociation");
- unsigned currentDim = 0;
- for (AffineMap m : reassociation) {
- unsigned dim = m.getNumResults();
- auto band = shape.slice(currentDim, dim);
- int64_t size = 1;
- if (llvm::is_contained(band, ShapedType::kDynamic))
- size = ShapedType::kDynamic;
- else
- for (unsigned d = 0; d < dim; ++d)
- size *= shape[currentDim + d];
- newShape.push_back(size);
- currentDim += dim;
- }
- return MemRefType::get(newShape, type.getElementType());
-}
-
-MemRefType CollapseShapeOp::inferCollapsedType(
- MemRefType type, SmallVector<ReassociationIndices> reassociation) {
- return inferCollapsedType(
- type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
- type.getContext(), reassociation)));
-}
+// MemRefType
+// CollapseShapeOp::inferCollapsedType(MemRefType type,
+// ArrayRef<AffineMap> reassociation) {
+// auto shape = type.getShape();
+// SmallVector<int64_t, 4> newShape;
+// assert(isReassociationValid(reassociation) && "invalid reassociation");
+// unsigned currentDim = 0;
+// for (AffineMap m : reassociation) {
+// unsigned dim = m.getNumResults();
+// auto band = shape.slice(currentDim, dim);
+// int64_t size = 1;
+// if (llvm::is_contained(band, ShapedType::kDynamic))
+// size = ShapedType::kDynamic;
+// else
+// for (unsigned d = 0; d < dim; ++d)
+// size *= shape[currentDim + d];
+// newShape.push_back(size);
+// currentDim += dim;
+// }
+// return MemRefType::get(newShape, type.getElementType());
+// }
+
+// MemRefType CollapseShapeOp::inferCollapsedType(
+// MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+// return inferCollapsedType(
+// type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+// type.getContext(), reassociation)));
+// }
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 767f593329f52..efe8010cffc91 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -942,19 +942,3 @@ func.func @transpose(%input: memref<?xf32>,
// CHECKPARALLEL: }
// CHECKPARALLEL: return
// CHECKPARALLEL: }
-
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
- %dest = memref.alloc() : memref<8x16x8x32xf32>
- %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
- into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
- return %packed : memref<8x16x8x32xf32>
-}
-
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
- %dest = memref.alloc() : memref<128x256xf32>
- %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
- into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
- return %unpacked : memref<128x256xf32>
-}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index dc556761b09e5..7f7aa12534a9b 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -706,3 +706,21 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// CHECK-LABEL: func @conv2d_channel_first_q_promote(
// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
+ %dest = memref.alloc() : memref<8x16x8x32xf32>
+ %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
+ return %packed : memref<8x16x8x32xf32>
+}
+
+// -----
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
+ %dest = memref.alloc() : memref<128x256xf32>
+ %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+ return %unpacked : memref<128x256xf32>
+}
>From 2480616ebfbb968d83ab119bf7d6a84897f482e5 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 23 Mar 2025 15:09:40 +0900
Subject: [PATCH 11/23] lint
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 ++-
mlir/test/Dialect/Linalg/roundtrip.mlir | 8 ++++----
2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b4cbc7c6ad8e9..8d71cc0142556 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5001,7 +5001,8 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
}
bool PackOp::isLikePad() {
- auto packedTensorType = llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
+ auto packedTensorType =
+ llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
return isLikePadUnPad(*this, packedTensorType);
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 7f7aa12534a9b..c2e9e3fbd5423 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -711,16 +711,16 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// Test that we can lower all the way to LLVM without crashing, don't check results here.
func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
%dest = memref.alloc() : memref<8x16x8x32xf32>
- %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
- return %packed : memref<8x16x8x32xf32>
+ return %dest : memref<8x16x8x32xf32>
}
// -----
// Test that we can lower all the way to LLVM without crashing, don't check results here.
func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
%dest = memref.alloc() : memref<128x256xf32>
- %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
- return %unpacked : memref<128x256xf32>
+ return %dest : memref<128x256xf32>
}
>From 7b92a4ee2af6c15035dbb5824f23f2524c7aa1a3 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Mon, 24 Mar 2025 10:37:02 +0900
Subject: [PATCH 12/23] format fix
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 1 -
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 29 -------------------
2 files changed, 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 785c7cc924159..63d36ec1fd3d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -229,7 +229,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
/// 2. pads the other ones, and
/// 3. doesn't shuffle the dimensions
bool isLikePad();
-
}];
let hasCanonicalizeMethod = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 70d44904788b1..dbd3f6d631a8a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2526,35 +2526,6 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
-// MemRefType
-// CollapseShapeOp::inferCollapsedType(MemRefType type,
-// ArrayRef<AffineMap> reassociation) {
-// auto shape = type.getShape();
-// SmallVector<int64_t, 4> newShape;
-// assert(isReassociationValid(reassociation) && "invalid reassociation");
-// unsigned currentDim = 0;
-// for (AffineMap m : reassociation) {
-// unsigned dim = m.getNumResults();
-// auto band = shape.slice(currentDim, dim);
-// int64_t size = 1;
-// if (llvm::is_contained(band, ShapedType::kDynamic))
-// size = ShapedType::kDynamic;
-// else
-// for (unsigned d = 0; d < dim; ++d)
-// size *= shape[currentDim + d];
-// newShape.push_back(size);
-// currentDim += dim;
-// }
-// return MemRefType::get(newShape, type.getElementType());
-// }
-
-// MemRefType CollapseShapeOp::inferCollapsedType(
-// MemRefType type, SmallVector<ReassociationIndices> reassociation) {
-// return inferCollapsedType(
-// type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
-// type.getContext(), reassociation)));
-// }
-
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
>From 6dc08ae1628ab2c5795f17af1a3b1ff682e5d861 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Tue, 25 Mar 2025 14:58:06 +0900
Subject: [PATCH 13/23] revert changes
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 9 ++++++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 +++++
.../Dialect/Linalg/Transforms/Transforms.cpp | 27 +++++--------------
.../Linalg/Transforms/Vectorization.cpp | 2 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 3 +--
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 10 +++----
6 files changed, 28 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 63d36ec1fd3d6..03da3d38ef4c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -34,7 +34,7 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
- ConditionallySpeculatable, NoMemoryEffect,
+ ConditionallySpeculatable, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
@@ -76,6 +76,13 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// have been tiled. Also, the order of the output dimensions is consistent
/// with `inner_dims_pos` rather than the packed tensor.
SmallVector<int64_t> getTiledOuterDims();
+
+ void $cppClass::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+ }
+
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9766c6e56fb7c..1515d648bddca 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4822,6 +4822,9 @@ bool areTilesAndTiledDimsAllConstant(OpTy op) {
}
Speculation::Speculatability PackOp::getSpeculatability() {
+ if (!hasPureTensorSemantics())
+ return Speculation::NotSpeculatable;
+
if (getPaddingValue())
return Speculation::Speculatable;
@@ -5122,6 +5125,9 @@ LogicalResult UnPackOp::verify() {
}
Speculation::Speculatability UnPackOp::getSpeculatability() {
+ if (!hasPureTensorSemantics())
+ return Speculation::NotSpeculatable;
+
// See PackOp::getSpeculatability.
if (!areTilesAndTiledDimsAllConstant(*this))
return Speculation::NotSpeculatable;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 105831a3d9259..085d6e44d854d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -360,7 +359,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- ShapedType packedTensorType = unPackOp.getSourceType();
+ RankedTensorType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -397,22 +396,10 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose packedShape to stripMinedShape.
- ShapedType stripMinedType;
- if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
- stripMinedType =
- RankedTensorType::get(stripMinedShape, tensorType.getElementType());
- } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
- stripMinedType =
- MemRefType::get(stripMinedShape, memrefType.getElementType());
- }
- ShapedType collapsedType;
- if (stripMinedType.isa<TensorType>()) {
- collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
- } else if (stripMinedType.isa<MemRefType>()) {
- collapsedType = memref::CollapseShapeOp::computeCollapsedType(
- cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
- }
+ RankedTensorType stripMinedTensorType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMinedTensorType, packingMetadata.reassociations);
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
@@ -420,7 +407,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, dims, stripMinedType.getElementType());
+ loc, dims, stripMinedTensorType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
@@ -1675,4 +1662,4 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
patterns.add<DecomposePadOpPattern>(patterns.getContext());
-}
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index dfb3f0c90595d..2dcd897330d1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- ShapedType unpackTensorType = unpackOp.getSourceType();
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dbd3f6d631a8a..1a584a387f2a5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,7 +9,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1125,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 9a2bd3493f6af..cd0cdd378c352 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx],
- oneAttr};
- }));
+ llvm::append_range(
+ offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
+ }));
continue;
}
>From cf7be5780250547577c8eca7c0c021f9590516a9 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Tue, 25 Mar 2025 15:03:54 +0900
Subject: [PATCH 14/23] revert changes
---
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 2 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 085d6e44d854d..dcd50cc44f81b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1662,4 +1662,4 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
patterns.add<DecomposePadOpPattern>(patterns.getContext());
-}
\ No newline at end of file
+}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1a584a387f2a5..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
>From 4e2f00de633fbde83d6cc967c442c75d809f0536 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Tue, 25 Mar 2025 15:45:58 +0900
Subject: [PATCH 15/23] nit
---
mlir/test/Dialect/Linalg/roundtrip.mlir | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index c2e9e3fbd5423..d8e11d03bedd4 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -709,7 +709,7 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// -----
// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
+func.func @pack_memref(%source: memref<128x256xf32>, memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
%dest = memref.alloc() : memref<8x16x8x32xf32>
linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
@@ -718,8 +718,7 @@ func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
// -----
// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
- %dest = memref.alloc() : memref<128x256xf32>
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) -> memref<128x256xf32> {
linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
return %dest : memref<128x256xf32>
>From ee7a42a0c739bd4c56d0ce82318199ea01874491 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Thu, 27 Mar 2025 14:09:08 +0900
Subject: [PATCH 16/23] fix upon review: Add getEffects for PackOp and UnPackOp
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 7 ---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 54 +++++++++++++++++++
.../Dialect/Linalg/Transforms/Transforms.cpp | 3 +-
.../Linalg/Transforms/Vectorization.cpp | 3 +-
mlir/test/Dialect/Linalg/roundtrip.mlir | 3 +-
5 files changed, 59 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 03da3d38ef4c5..980e99872b9a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -76,13 +76,6 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// have been tiled. Also, the order of the output dimensions is consistent
/// with `inner_dims_pos` rather than the packed tensor.
SmallVector<int64_t> getTiledOuterDims();
-
- void $cppClass::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
- }
-
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1515d648bddca..93ca2581f2a3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4803,6 +4803,60 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
getPaddingValue(), metadata.outerDimsPerm);
}
+void PackOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // No memory effects for pure tensor semantics
+ if (hasPureTensorSemantics())
+ return;
+
+ for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+ if (!llvm::isa<MemRefType>(opOperand.get().getType()))
+ continue;
+
+ if (&opOperand == &getSourceMutable()) {
+ effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
+ else if (&opOperand == &getDestMutable()) {
+ effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
+ }
+}
+
+void UnPackOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // No memory effects for pure tensor semantics
+ if (hasPureTensorSemantics())
+ return;
+
+ for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+ if (!llvm::isa<MemRefType>(opOperand.get().getType()))
+ continue;
+
+ if (&opOperand == &getSourceMutable()) {
+ effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
+ else if (&opOperand == &getDestMutable()) {
+ effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
+ }
+}
+
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
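With these effects in place, the memref form reports a read of the source and a read plus write of the destination, while the tensor form stays effect-free. A minimal sketch of how a client pass might query that (standard MemoryEffectOpInterface usage; `op` is an assumed `Operation *`):
  SmallVector<MemoryEffects::EffectInstance> effects;
  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
    iface.getEffects(effects);
    for (MemoryEffects::EffectInstance &e : effects)
      if (isa<MemoryEffects::Write>(e.getEffect()))
        ; // the destination buffer is written in place
  }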
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..2ae6474cf3a2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -359,7 +359,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- RankedTensorType packedTensorType = unPackOp.getSourceType();
+ // TODO: support non-ranked tensor types. ShapedType
+ RankedTensorType packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2dcd897330d1e..3b91b897bcfd4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ // TODO: support non-ranked tensor types. ShapedType
+ RankedTensorType unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d8e11d03bedd4..7ca20f684583a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -709,8 +709,7 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// -----
// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>, memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
- %dest = memref.alloc() : memref<8x16x8x32xf32>
+func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
return %dest : memref<8x16x8x32xf32>
>From 5b95ee88d4bd1e4304c73383c3c03308598d0ae6 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Thu, 27 Mar 2025 14:15:52 +0900
Subject: [PATCH 17/23] make clang-format happy
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 30 +++++++++----------
.../Dialect/Linalg/Transforms/Transforms.cpp | 3 +-
2 files changed, 16 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 93ca2581f2a3d..7587178dd94d2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4816,16 +4816,15 @@ void PackOp::getEffects(
if (&opOperand == &getSourceMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
- }
- else if (&opOperand == &getDestMutable()) {
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ } else if (&opOperand == &getDestMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
}
}
}
@@ -4843,16 +4842,15 @@ void UnPackOp::getEffects(
if (&opOperand == &getSourceMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
- }
- else if (&opOperand == &getDestMutable()) {
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ } else if (&opOperand == &getDestMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
}
}
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2ae6474cf3a2f..75afcb1fec332 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -360,7 +360,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
rewriter.setInsertionPoint(unPackOp);
// TODO: support non-ranked tensor types. ShapedType
- RankedTensorType packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
+ RankedTensorType packedTensorType =
+ dyn_cast<RankedTensorType>(unPackOp.getSourceType());
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
>From 8b5ac5abd85b35ced34839b955247103341dd9a0 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Thu, 27 Mar 2025 14:21:30 +0900
Subject: [PATCH 18/23] make clang-format happy
---
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3b91b897bcfd4..f716ff97f7cf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1670,7 +1670,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
rewriter.setInsertionPoint(unpackOp);
// TODO: support non-ranked tensor types. ShapedType
- RankedTensorType unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
+ RankedTensorType unpackTensorType =
+ dyn_cast<RankedTensorType>(unpackOp.getSourceType());
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
>From c955d2137b454af779dedb12cd933da529140846 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Fri, 28 Mar 2025 07:34:26 +0900
Subject: [PATCH 19/23] wrap getEffects function
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 43 +++++++++---------------
1 file changed, 15 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7587178dd94d2..63977d7165e36 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4803,22 +4803,23 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
getPaddingValue(), metadata.outerDimsPerm);
}
-void PackOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
+template <typename OpTy>
+static void getEffectsImpl(
+ OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
// No memory effects for pure tensor semantics
- if (hasPureTensorSemantics())
+ if (op.hasPureTensorSemantics())
return;
- for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+ for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
if (!llvm::isa<MemRefType>(opOperand.get().getType()))
continue;
- if (&opOperand == &getSourceMutable()) {
+ if (&opOperand == &op.getSourceMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
- } else if (&opOperand == &getDestMutable()) {
+ } else if (&opOperand == &op.getDestMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
@@ -4829,30 +4830,16 @@ void PackOp::getEffects(
}
}
-void UnPackOp::getEffects(
+void PackOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- // No memory effects for pure tensor semantics
- if (hasPureTensorSemantics())
- return;
-
- for (OpOperand &opOperand : getOperation()->getOpOperands()) {
- if (!llvm::isa<MemRefType>(opOperand.get().getType()))
- continue;
+ getEffectsImpl(*this, effects);
+}
- if (&opOperand == &getSourceMutable()) {
- effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
- } else if (&opOperand == &getDestMutable()) {
- effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
- /*effectOnFullRegion=*/true,
- SideEffects::DefaultResource::get());
- }
- }
+void UnPackOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getEffectsImpl(*this, effects);
}
/// Returns true if the tiles and the tiled dims are constant.
>From 276069d36b4bb88b628d2b29f20f6c85e76aa931 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 08:47:51 +0900
Subject: [PATCH 20/23] fix upon review
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 9 +-
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 2 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 101 +++++++++++++-----
.../Transforms/DataLayoutPropagation.cpp | 4 +-
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 12 +--
mlir/test/Dialect/Linalg/roundtrip.mlir | 10 +-
6 files changed, 96 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 980e99872b9a6..bd9caa3f6b1a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -190,7 +190,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
- static RankedTensorType inferPackedType(ShapedType sourceType,
+ static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm = {});
+
+ // Method to get the `MemRefType` of the result based on the inner
+ // tiles, position of the inner tiles (innerDimsPos) and interchange vector
+ // of outer loops (outerDimsPerm).
+ static MemRefType inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index a86bf74a7b6a1..99c80a2196567 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
/// Removes the op and replaces the constant with a new constant of the result
/// shape. When an optional cst attribute is passed, it is reshaped only if the
/// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
std::optional<Attribute> cst = std::nullopt);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index afff911168324..0af14b12da040 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -9,8 +9,8 @@
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include <iostream>
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -29,6 +29,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
@@ -45,6 +46,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
@@ -4426,15 +4428,30 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
};
+ // Verify that the source and destination are ranked types.
+ if (!packOrUnPack.getSourceType().hasRank() ||
+ !packOrUnPack.getDestType().hasRank()) {
+ return op->emitError(
+ "expected both source and destination to be shaped types");
+ }
+
// Verify tiles. Do not allow zero tiles.
SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
if (hasZeros(mixedTiles))
return op->emitError("invalid zero tile factor");
+ // Verify that the Operation does not have mixed tensor/buffer semantics.
+ if (!packOrUnPack.hasPureBufferSemantics() &&
+ !packOrUnPack.hasPureTensorSemantics()) {
+ return op->emitError("mixing tensor and buffer semantics is not allowed");
+ }
+ bool hasTensorSemantics = packOrUnPack.hasPureTensorSemantics();
+
// Verify inner_dims_pos and outer_dims_perm.
ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();
+
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
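With AnyShaped operands, tensor/memref mixing now parses, so the verifier has to reject it explicitly. A hedged example of IR the new check is intended to reject (hypothetical %src and %dst values):
  // Expected diagnostic: "mixing tensor and buffer semantics is not allowed"
  linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst : tensor<128x256xf32> -> memref<8x16x8x32xf32>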
@@ -4471,12 +4488,17 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Verify result shape is greater than the minimum expected
// by the pack operation, and that the output shape
// represents full tiles.
- RankedTensorType expectedPackedType = PackOp::inferPackedType(
- unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
+ if (hasTensorSemantics) {
+ RankedTensorType expectedPackedType = PackOp::inferPackedTensorType(
+ cast<RankedTensorType>(unpackedType), packOrUnPack.getStaticTiles(),
+ innerDimsPos, outerDimPerm);
+ if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
+ return op->emitError(
+ "the shape of output is not large enough to hold the "
+ "packed data. Expected at least ")
+ << expectedPackedType << ", got " << packedType;
+ }
+ } else {
}
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
@@ -4680,9 +4702,9 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
return result;
}
-/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
-/// the packed type. Having a shared helper helps implement these two methods in
-/// a way that ensures that they agree on which dimensions are dynamic.
+/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
+/// of the packed type. Having a shared helper helps implement these two methods
+/// in a way that ensures that they agree on which dimensions are dynamic.
static SmallVector<int64_t> getPackOpResultTypeShape(
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
@@ -4746,13 +4768,21 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
+RankedTensorType PackOp::inferPackedTensorType(
+ RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+ sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+ return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
- return RankedTensorType::get(resultShape, sourceType.getElementType());
+ return MemRefType::get(resultShape, sourceType.getElementType());
}
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
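A small usage sketch of the new memref inference path, mirroring the shapes used in the roundtrip tests (the builder `b` is assumed; tile lists are written inline for illustration):
  // 128x256 source, tiles [8, 32] on dims [0, 1], outer_dims_perm = [1, 0]:
  // outer dims [16, 8] are permuted to [8, 16] and the inner tiles appended,
  // i.e. memref<8x16x8x32xf32>.
  MemRefType src = MemRefType::get({128, 256}, b.getF32Type());
  MemRefType packed = PackOp::inferPackedMemRefType(
      src, /*innerTileSizes=*/{8, 32}, /*innerDimsPos=*/{0, 1},
      /*outerDimsPerm=*/{1, 0});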
@@ -4802,7 +4832,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
}
template <typename OpTy>
-static void getEffectsImpl(
+static void getPackUnPackEffectsImpl(
OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
// No memory effects for pure tensor semantics
@@ -4831,13 +4861,13 @@ static void getEffectsImpl(
void PackOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getEffectsImpl(*this, effects);
+ getPackUnPackEffectsImpl(*this, effects);
}
void UnPackOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getEffectsImpl(*this, effects);
+ getPackUnPackEffectsImpl(*this, effects);
}
/// Returns true if the tiles and the tiled dims are constant.
@@ -4972,35 +5002,49 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
return success();
}
- // Insert tensor.cast ops if static shape inference is available..
+ // Insert either tensor.cast or memref.cast ops
+ // if static shape inference is available..
+ bool hasTensorSemantics = packOp.hasPureTensorSemantics();
+
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
Location loc = packOp.getLoc();
Value source = packOp.getSource();
if (srcShape != packOp.getSourceType().getShape()) {
auto newSrcType = packOp.getSourceType().clone(srcShape);
- source =
- rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+ if (hasTensorSemantics)
+ source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+ packOp.getSource());
+ else
+ source = rewriter.create<memref::CastOp>(loc, newSrcType,
+ packOp.getSource());
}
Value dest = packOp.getDest();
ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
- dest =
- rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+ if (hasTensorSemantics)
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
}
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
- packOp.getResult().setType(cast<ShapedType>(dest.getType()));
+ packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
});
// Insert a cast if needed
if (needUpdateDestType) {
rewriter.setInsertionPointAfter(packOp);
- auto castOp =
- rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
- rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+ if (hasTensorSemantics) {
+ auto castOp =
+ rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+ rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+ } else {
+ auto castOp =
+ rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
+ rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+ }
}
return success();
}
@@ -5047,12 +5091,15 @@ bool PackOp::isLikePad() {
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+ if (!hasPureTensorSemantics())
+ return {};
+
std::optional<Attribute> paddingValue;
if (auto pad = adaptor.getPaddingValue())
paddingValue = pad;
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
- getDestType(), paddingValue))
+ cast<TensorType>(getDestType()), paddingValue))
return reshapedSource;
return {};
}
@@ -5324,9 +5371,11 @@ bool UnPackOp::isLikeUnPad() {
}
OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+ if (!hasPureTensorSemantics())
+ return {};
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
- getResult().getType()))
+ cast<TensorType>(getResult().getType())))
return reshapedSource;
return {};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 9f5000b70b6f6..22bd5a8b38862 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -808,7 +808,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
// If reassociation is not possible, then reordering cannot happen.
// This can be caused by pack padding affecting previously expanded
// dimensions or packing extending dimensions.
- RankedTensorType newPackType = linalg::PackOp::inferPackedType(
+ RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
expandOp.getSrcType(), packOp.getStaticInnerTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
auto reassocExpand =
@@ -943,7 +943,7 @@ static LogicalResult pushDownUnPackOpThroughExpandShape(
nextPos += 1;
}
- RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
+ RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
expandOp.getLoc(), newExpandType, unPackOp.getSource(),
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index cd0cdd378c352..86a1fb12f2b26 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- ShapedType result,
+ TensorType result,
std::optional<Attribute> cst) {
if (source && source.isSplat() && result.hasStaticShape() &&
(!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
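Narrowing the result parameter to TensorType matches the only way the helper is used: folding a splat constant into a new constant of the (tensor) result type. A hedged call-site sketch (`splatAttr` and `resultTy` are assumed values):
  // Returns the splat reshaped to `resultTy` when the shape is static and the
  // optional expected value (here: none) matches; otherwise a null result.
  OpFoldResult folded = reshapeConstantSource(
      splatAttr, cast<TensorType>(resultTy), /*cst=*/std::nullopt);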
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 7ca20f684583a..550d717570e69 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -708,17 +708,15 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
// -----
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
+func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) {
linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
- return %dest : memref<8x16x8x32xf32>
+ return
}
// -----
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) -> memref<128x256xf32> {
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
- return %dest : memref<128x256xf32>
+ return
}
>From 790e974e544fd8552cc668a621795f661b292247 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 18:01:35 +0900
Subject: [PATCH 21/23] bail out transforms using PackOp, UnPackOp
---
.../Linalg/Transforms/BlockPackMatmul.cpp | 5 ++
.../Transforms/DataLayoutPropagation.cpp | 52 +++++++++++++++++++
.../Linalg/Transforms/Vectorization.cpp | 25 +++++++++
3 files changed, 82 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 81842e4bea631..0b3d86d51ca0a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -91,6 +91,11 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
linalg::PackOp packOp, AffineMap operandMap,
ArrayRef<unsigned> blocksStartDimPos,
bool transposeOuterBlocks, bool transposeInnerBlocks) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
assert(operandMap.getNumDims() >= 4 &&
"expected at least 4D prepacked matmul");
assert(blocksStartDimPos.size() >= 2 &&
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 22bd5a8b38862..ced3719ff8c3e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -63,6 +63,12 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
OpTy packOrUnPackOp) {
static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
"applies to only pack or unpack operations");
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (isa<linalg::LinalgOp>(packOrUnPackOp)) {
+ if (!packOrUnPackOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+ }
LLVM_DEBUG(
{ llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
@@ -373,6 +379,11 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
const ControlPropagationFn &controlFn) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
if (!genericOp)
return failure();
@@ -461,6 +472,11 @@ struct BubbleUpPackOpThroughGenericOpPattern
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
auto genericOp =
bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
if (failed(genericOp))
@@ -483,6 +499,11 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();
@@ -651,6 +672,11 @@ static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
linalg::PackOp packOp,
PatternRewriter &rewriter) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -757,6 +783,11 @@ static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
linalg::PackOp packOp,
PatternRewriter &rewriter) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
// Outer dimensions permutation is not supported currently.
// TODO: Handle outer_dims_perm variants.
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -840,6 +871,11 @@ class BubbleUpPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
Operation *srcOp = packOp.getSource().getDefiningOp();
// Currently only support when the pack op is the only user.
if (!srcOp || !(srcOp->getNumResults() == 1) ||
@@ -893,6 +929,11 @@ class BubbleUpPackOpThroughReshapeOp final
static LogicalResult pushDownUnPackOpThroughExpandShape(
linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
// User controlled propagation function.
if (!controlFn(&expandOp.getSrcMutable()))
return failure();
@@ -970,6 +1011,11 @@ class PushDownUnPackOpThroughReshapeOp final
LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
+ // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
Value result = unPackOp.getResult();
// Currently only support unpack op with the single user.
if (!result.hasOneUse()) {
@@ -1146,11 +1192,17 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
+
linalg::UnPackOp unpackOp =
padOp.getSource().getDefiningOp<linalg::UnPackOp>();
+
if (!unpackOp)
return failure();
+ // TODO(issues/129004): Support MemRef PadOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics())
+ return failure();
+
if (!controlFn(&padOp.getSourceMutable()))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f716ff97f7cf3..aba729ec3f5cd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1588,6 +1588,11 @@ static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
@@ -1664,6 +1669,10 @@ static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics()) {
+ return failure();
+ }
// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
@@ -1891,6 +1900,10 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics()) {
+ return failure();
+ }
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
@@ -2136,6 +2149,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
auto padValue = packOp.getPaddingValue();
Attribute cstAttr;
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
@@ -2358,6 +2376,13 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return false.
+ // Actually do we need this?
+ if (isa<linalg::PackOp, linalg::UnPackOp>(op)) {
+ if (!cast<LinalgOp>(op).hasPureTensorSemantics()) {
+ return false;
+ }
+ }
return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
tensor::InsertSliceOp>(op);
}
>From 820e40b994b9b26b92c7f184b2b9a01c1328d489 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 19:23:21 +0900
Subject: [PATCH 22/23] fix build error
---
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ced3719ff8c3e..199011ac901ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -64,8 +64,9 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
"applies to only pack or unpack operations");
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
- if (isa<linalg::LinalgOp>(packOrUnPackOp)) {
- if (!packOrUnPackOp.hasPureTensorSemantics()) {
+ if (auto linalgOp =
+ dyn_cast<linalg::LinalgOp>(packOrUnPackOp.getOperation())) {
+ if (!linalgOp.hasPureTensorSemantics()) {
return failure();
}
}
>From 43a64b912adaa2eed85d9715c13c3057c2c4b53e Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 20:38:34 +0900
Subject: [PATCH 23/23] fix build error
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 25 +++++++++++++++++++
.../Linalg/Transforms/Vectorization.cpp | 7 ------
2 files changed, 25 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 75afcb1fec332..63c0e4d126c9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -219,6 +219,11 @@ struct PackedOperandsDimList {
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
linalg::PackOp packOp,
bool lowerPadLikeWithInsertSlice) {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
// 1. Filter out NYI cases.
auto packedTensorType =
cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -355,6 +360,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
bool lowerUnpadLikeWithExtractSlice) {
+ // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+ if (!unPackOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
@@ -1032,6 +1042,11 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
return input;
}
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return packOp.getSource();
+ }
+
assert(llvm::all_of(packOp.getAllOuterDims(),
[](int64_t val) { return val == 1; }) &&
"some outer dims are != 1");
@@ -1144,6 +1159,11 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+ if (!packOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
if (llvm::any_of(packOp.getAllOuterDims(),
@@ -1245,6 +1265,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
+ // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+ if (!unpackOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+
int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index aba729ec3f5cd..8936f9d9e389e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2376,13 +2376,6 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
- // TODO(issues/129004): Support MemRef PackOp. Temporarily return false.
- // Actually do we need this?
- if (isa<linalg::PackOp, linalg::UnPackOp>(op)) {
- if (!cast<LinalgOp>(op).hasPureTensorSemantics()) {
- return false;
- }
- }
return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
tensor::InsertSliceOp>(op);
}