[Mlir-commits] [mlir] [mlir][tensor] Add new helper hooks for RelayoutOp (PR #109642)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Sep 23 10:39:58 PDT 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/109642
>From 469274f99c2128d2fd606b7ab9642c79714c56eb Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 21 Sep 2024 16:14:25 +0100
Subject: [PATCH 1/3] [mlir][tensor] Add new helper hooks to RelayoutOp
Implements two helper hooks for PackOp and UnPackOp, `getAllOuterDims`
and `getTiledOuterDims`, and adds them to RelayoutOp (which both PackOp
and UnPackOp inherit from).
This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".
---
.../mlir/Dialect/Tensor/IR/TensorOps.td | 19 ++++++++++++++-
.../Dialect/Linalg/Transforms/Transforms.cpp | 23 ++++++++-----------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 22 ++++++++++++++++++
3 files changed, 50 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
}
//===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
//===----------------------------------------------------------------------===//
class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// a sentinel `kDynamic` is introduced at that position in
/// the returned vector.
SmallVector<int64_t> getStaticTiles();
+
+ /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+ /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+ /// that this will include both tiled and non-tiled dimensions.
+ ArrayRef<int64_t> getAllOuterDims() {
+ ShapedType inputType = getSourceType();
+ int64_t inputRank = inputType.getRank();
+ return getDestType().getShape().take_front(inputRank);
+ }
+
+ /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
+ /// have been tiled.
+ SmallVector<int64_t> getTiledOuterDims();
}];
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
AttrSizedOperandSegments]> {
let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
return input;
}
+ assert(llvm::all_of(packOp.getAllOuterDims(),
+ [](int64_t val) { return val == 1; }) &&
+ "some outer dims are != 1");
+
Location loc = packOp.getLoc();
ShapedType inputType = packOp.getSourceType();
int64_t inputRank = inputType.getRank();
- assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
- [](int64_t val) { return val == 1; }));
SmallVector<int64_t> paddedShape;
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
- auto innerDimsPos = packOp.getInnerDimsPos();
- int64_t srcRank = packOp.getSourceRank();
- auto destShape = packOp.getDestType().getShape();
- if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
- return destShape[index] != 1;
- })) {
+ if (llvm::any_of(packOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+ int64_t srcRank = packOp.getSourceRank();
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
loc, readType, input, readOffsets, readSizes, readStrides);
// 2. Transpose the tile to match the inner tile order.
-
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
- inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+ inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
- if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
- return srcShape[index] != 1;
- })) {
+ if (llvm::any_of(unpackOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
unpackOp,
"require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+ auto innerDimsPos = getInnerDimsPos();
+ auto destShape = getDestType().getShape();
+ SmallVector<int64_t> res;
+
+ for (auto index : innerDimsPos)
+ res.push_back(destShape[index]);
+
+ return res;
+}
+
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+ auto innerDimsPos = getInnerDimsPos();
+ auto destShape = getSourceType().getShape();
+ SmallVector<int64_t> res;
+
+ for (auto index : innerDimsPos)
+ res.push_back(destShape[index]);
+
+ return res;
+}
+
LogicalResult UnPackOp::verify() {
return commonVerifierPackAndUnPackOp(*this);
}
>From d6bc07fa1f7c67f4873df0986aedc28fd3f26f1a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 23 Sep 2024 14:58:20 +0100
Subject: [PATCH 2/3] fixup! [mlir][tensor] Add new helper hooks to RelayoutOp
Remove trailing whitespace
---
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9fee75c6a2ca3d..7b57b503ea56f0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1861,7 +1861,7 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
return getDestType().getShape().take_front(inputRank);
}
- /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
+ /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
/// have been tiled.
SmallVector<int64_t> getTiledOuterDims();
}];
>From 534e096b789c78410ad1e15f777e06d4e6a59d4c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 23 Sep 2024 18:21:08 +0100
Subject: [PATCH 3/3] fixup! fixup! [mlir][tensor] Add new helper hooks to
RelayoutOp
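Add comments and specialize getAllOuterDims for PackOp and UnPackOp (the
shared version read the outer dims from the destination tensor, which is
the packed tensor only for PackOp)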
---
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 13 ++++++-------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 ++++++++++++
2 files changed, 18 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7b57b503ea56f0..3170115883e2be 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1854,15 +1854,14 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
/// dims excluding the trailing dims corresponding to `innerTiles`. Note
- /// that this will include both tiled and non-tiled dimensions.
- ArrayRef<int64_t> getAllOuterDims() {
- ShapedType inputType = getSourceType();
- int64_t inputRank = inputType.getRank();
- return getDestType().getShape().take_front(inputRank);
- }
+ /// that this will include both tiled and non-tiled dimensions. The order
+ /// of the output dimensions is consistent with the shape of the packed
+ /// tensor.
+ ArrayRef<int64_t> getAllOuterDims();
/// Similar to `getAllOuterDims`, but only retrieve the outer dims that
- /// have been tiled.
+ /// have been tiled. Also, the order of the output dimensions is consistent
+ /// with `inner_dims_pos` rather than the packed tensor.
SmallVector<int64_t> getTiledOuterDims();
}];
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bc7deb1614d18d..d0ddde96b0b231 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,12 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+ArrayRef<int64_t> PackOp::getAllOuterDims() {
+ ShapedType inputType = getSourceType();
+ int64_t inputRank = inputType.getRank();
+ return getDestType().getShape().take_front(inputRank);
+}
+
SmallVector<int64_t> PackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
auto destShape = getDestType().getShape();
@@ -4422,6 +4428,12 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
+ ShapedType destType = getDestType();
+ int64_t destRank = destType.getRank();
+ return getSourceType().getShape().take_front(destRank);
+}
+
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
auto destShape = getSourceType().getShape();
More information about the Mlir-commits
mailing list