[Mlir-commits] [mlir] [mlir][tensor] Add new helper hooks for RelayoutOp (PR #109642)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 23 02:42:36 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Implements two helper hooks for PackOp and UnPackOp, `getAllOuterDims`
and `getTiledOuterDims`, and adds them to RelayoutOp (which both PackOp
and UnPackOp inherit from).
This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".
---
Full diff: https://github.com/llvm/llvm-project/pull/109642.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+18-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+10-13)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+22)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
}
//===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
//===----------------------------------------------------------------------===//
class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// a sentinel `kDynamic` is introduced at that position in
/// the returned vector.
SmallVector<int64_t> getStaticTiles();
+
+ /// Retrieve all outer dims for this Pack/UnPack Op, i.e. the leading dims
+ /// of the packed tensor (the dest of PackOp, the source of UnPackOp),
+ /// excluding the trailing dims corresponding to `innerTiles`. Note that
+ /// this will include both tiled and non-tiled dimensions.
+ ArrayRef<int64_t> getAllOuterDims() {
+   ShapedType srcType = getSourceType();
+   ShapedType destType = getDestType();
+   // The packed tensor is the higher-rank side: the dest for PackOp and the
+   // source for UnPackOp. Unconditionally using `dest.take_front(srcRank)`
+   // is only correct for PackOp. For UnPackOp the dest is the lower-rank
+   // tensor, and `take_front` clamps to the array size, so it would
+   // silently return the whole unpacked shape.
+   bool destIsPacked = destType.getRank() >= srcType.getRank();
+   ShapedType packedType = destIsPacked ? destType : srcType;
+   int64_t unpackedRank =
+       destIsPacked ? srcType.getRank() : destType.getRank();
+   return packedType.getShape().take_front(unpackedRank);
+ }
+
+ /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
+ /// have been tiled: one entry per element of `innerDimsPos`, in the same
+ /// order.
+ SmallVector<int64_t> getTiledOuterDims();
}];
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
AttrSizedOperandSegments]> {
let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
return input;
}
+ assert(llvm::all_of(packOp.getAllOuterDims(),
+ [](int64_t val) { return val == 1; }) &&
+ "some outer dims are != 1");
+
Location loc = packOp.getLoc();
ShapedType inputType = packOp.getSourceType();
int64_t inputRank = inputType.getRank();
- assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
- [](int64_t val) { return val == 1; }));
SmallVector<int64_t> paddedShape;
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
- auto innerDimsPos = packOp.getInnerDimsPos();
- int64_t srcRank = packOp.getSourceRank();
- auto destShape = packOp.getDestType().getShape();
- if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
- return destShape[index] != 1;
- })) {
+ if (llvm::any_of(packOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+ int64_t srcRank = packOp.getSourceRank();
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
loc, readType, input, readOffsets, readSizes, readStrides);
// 2. Transpose the tile to match the inner tile order.
-
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
- inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+ inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
- if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
- return srcShape[index] != 1;
- })) {
+ if (llvm::any_of(unpackOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
unpackOp,
"require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+/// Return the sizes of the tiled outer dims of the packed (dest) tensor:
+/// one entry per element of `inner_dims_pos`, in the same order.
+///
+/// `inner_dims_pos` indexes the dims of the unpacked (source) tensor, whereas
+/// the outer dims of the packed tensor are reordered by `outer_dims_perm`
+/// (packed outer dim `i` holds source dim `outer_dims_perm[i]`). The
+/// permutation is therefore undone before indexing the packed outer dims.
+/// When `outer_dims_perm` is empty this is plain `dest[inner_dims_pos[k]]`.
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+  // The leading `srcRank` dims of the dest are the (possibly permuted) outer
+  // dims; the trailing dims are the inner tiles.
+  ArrayRef<int64_t> packedOuterDims =
+      getDestType().getShape().take_front(getSourceType().getRank());
+  ArrayRef<int64_t> outerDimsPerm = getOuterDimsPerm();
+  ArrayRef<int64_t> innerDimsPos = getInnerDimsPos();
+
+  SmallVector<int64_t> res;
+  res.reserve(innerDimsPos.size());
+  for (int64_t srcDim : innerDimsPos) {
+    // Find the packed outer dim that holds source dim `srcDim`.
+    int64_t pos = outerDimsPerm.empty()
+                      ? srcDim
+                      : llvm::find(outerDimsPerm, srcDim) -
+                            outerDimsPerm.begin();
+    res.push_back(packedOuterDims[pos]);
+  }
+  return res;
+}
+
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+/// Return the sizes of the tiled outer dims of the packed (source) tensor:
+/// one entry per element of `inner_dims_pos`, in the same order.
+///
+/// `inner_dims_pos` indexes the dims of the unpacked (dest) tensor, whereas
+/// the outer dims of the packed source are reordered by `outer_dims_perm`
+/// (packed outer dim `i` holds unpacked dim `outer_dims_perm[i]`). The
+/// permutation is therefore undone before indexing the packed outer dims.
+/// When `outer_dims_perm` is empty this is plain `source[inner_dims_pos[k]]`.
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+  // For UnPackOp the *source* is the packed tensor: its leading `destRank`
+  // dims are the (possibly permuted) outer dims.
+  ArrayRef<int64_t> packedOuterDims =
+      getSourceType().getShape().take_front(getDestType().getRank());
+  ArrayRef<int64_t> outerDimsPerm = getOuterDimsPerm();
+  ArrayRef<int64_t> innerDimsPos = getInnerDimsPos();
+
+  SmallVector<int64_t> res;
+  res.reserve(innerDimsPos.size());
+  for (int64_t unpackedDim : innerDimsPos) {
+    // Find the packed outer dim that holds unpacked dim `unpackedDim`.
+    int64_t pos = outerDimsPerm.empty()
+                      ? unpackedDim
+                      : llvm::find(outerDimsPerm, unpackedDim) -
+                            outerDimsPerm.begin();
+    res.push_back(packedOuterDims[pos]);
+  }
+  return res;
+}
+
LogicalResult UnPackOp::verify() {
return commonVerifierPackAndUnPackOp(*this);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/109642
More information about the Mlir-commits
mailing list