[Mlir-commits] [mlir] [mlir][tensor] Add new helper hooks for RelayoutOp (PR #109642)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 23 02:42:36 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Implements two helper hooks for PackOp and UnPackOp, `getAllOuterDims`
and `getTiledOuterDims`, and adds them to RelayoutOp (that both PackOp
and UnPackOp inherit from).

This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".


---
Full diff: https://github.com/llvm/llvm-project/pull/109642.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+18-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+10-13) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+22) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
 }
 
 //===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
 //===----------------------------------------------------------------------===//
 
 class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// a sentinel `kDynamic` is introduced at that position in
     /// the returned vector.
     SmallVector<int64_t> getStaticTiles();
+
+    /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+    /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+    /// that this will include both tiled and non-tiled dimensions.
+    ArrayRef<int64_t> getAllOuterDims() {
+      ShapedType inputType = getSourceType();
+      int64_t inputRank = inputType.getRank();
+      return getDestType().getShape().take_front(inputRank);
+    }
+
+    /// Similar to `getAllOuterDims`, but only retrieve the outer dims  that
+    /// have been tiled.
+    SmallVector<int64_t> getTiledOuterDims();
   }];
 
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
 def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     AttrSizedOperandSegments]> {
   let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     return input;
   }
 
+  assert(llvm::all_of(packOp.getAllOuterDims(),
+                      [](int64_t val) { return val == 1; }) &&
+         "some outer dims are != 1");
+
   Location loc = packOp.getLoc();
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
-  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
-                      [](int64_t val) { return val == 1; }));
 
   SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  int64_t srcRank = packOp.getSourceRank();
-  auto destShape = packOp.getDestType().getShape();
-  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
-        return destShape[index] != 1;
-      })) {
+  if (llvm::any_of(packOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       packOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+  int64_t srcRank = packOp.getSourceRank();
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       loc, readType, input, readOffsets, readSizes, readStrides);
 
   // 2. Transpose the tile to match the inner tile order.
-
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
-      inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+      inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
-  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
-        return srcShape[index] != 1;
-      })) {
+  if (llvm::any_of(unpackOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         unpackOp,
         "require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getDestType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
                                  ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getSourceType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 LogicalResult UnPackOp::verify() {
   return commonVerifierPackAndUnPackOp(*this);
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/109642


More information about the Mlir-commits mailing list