[Mlir-commits] [mlir] [mlir][tensor] Add new helper hooks for RelayoutOp (PR #109642)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Sep 23 10:39:58 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/109642

>From 469274f99c2128d2fd606b7ab9642c79714c56eb Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 21 Sep 2024 16:14:25 +0100
Subject: [PATCH 1/3] [mlir][tensor] Add new helper hooks to RelayoutOp

Implements two helper hooks for PackOp and UnPackOp, `getAllOuterDims`
and `getTiledOuterDims`, and adds them to RelayoutOp (that both PackOp
and UnPackOp inherit from).

This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       | 19 ++++++++++++++-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 23 ++++++++-----------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 22 ++++++++++++++++++
 3 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
 }
 
 //===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
 //===----------------------------------------------------------------------===//
 
 class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// a sentinel `kDynamic` is introduced at that position in
     /// the returned vector.
     SmallVector<int64_t> getStaticTiles();
+
+    /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+    /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+    /// that this will include both tiled and non-tiled dimensions.
+    ArrayRef<int64_t> getAllOuterDims() {
+      ShapedType inputType = getSourceType();
+      int64_t inputRank = inputType.getRank();
+      return getDestType().getShape().take_front(inputRank);
+    }
+
+    /// Similar to `getAllOuterDims`, but only retrieve the outer dims  that
+    /// have been tiled.
+    SmallVector<int64_t> getTiledOuterDims();
   }];
 
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
 def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     AttrSizedOperandSegments]> {
   let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     return input;
   }
 
+  assert(llvm::all_of(packOp.getAllOuterDims(),
+                      [](int64_t val) { return val == 1; }) &&
+         "some outer dims are != 1");
+
   Location loc = packOp.getLoc();
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
-  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
-                      [](int64_t val) { return val == 1; }));
 
   SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  int64_t srcRank = packOp.getSourceRank();
-  auto destShape = packOp.getDestType().getShape();
-  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
-        return destShape[index] != 1;
-      })) {
+  if (llvm::any_of(packOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       packOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+  int64_t srcRank = packOp.getSourceRank();
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       loc, readType, input, readOffsets, readSizes, readStrides);
 
   // 2. Transpose the tile to match the inner tile order.
-
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
-      inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+      inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
-  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
-        return srcShape[index] != 1;
-      })) {
+  if (llvm::any_of(unpackOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         unpackOp,
         "require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getDestType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
                                  ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getSourceType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 LogicalResult UnPackOp::verify() {
   return commonVerifierPackAndUnPackOp(*this);
 }

>From d6bc07fa1f7c67f4873df0986aedc28fd3f26f1a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 23 Sep 2024 14:58:20 +0100
Subject: [PATCH 2/3] fixup! [mlir][tensor] Add new helper hooks to RelayoutOp

Remove empty space
---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9fee75c6a2ca3d..7b57b503ea56f0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1861,7 +1861,7 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
       return getDestType().getShape().take_front(inputRank);
     }
 
-    /// Similar to `getAllOuterDims`, but only retrieve the outer dims  that
+    /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
     /// have been tiled.
     SmallVector<int64_t> getTiledOuterDims();
   }];

>From 534e096b789c78410ad1e15f777e06d4e6a59d4c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 23 Sep 2024 18:21:08 +0100
Subject: [PATCH 3/3] fixup! fixup! [mlir][tensor] Add new helper hooks to
 RelayoutOp

Add comments, specialize getAllOuterDims
---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 13 ++++++-------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp         | 12 ++++++++++++
 2 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7b57b503ea56f0..3170115883e2be 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1854,15 +1854,14 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
 
     /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
     /// dims excluding the trailing dims corresponding to `innerTiles`. Note
-    /// that this will include both tiled and non-tiled dimensions.
-    ArrayRef<int64_t> getAllOuterDims() {
-      ShapedType inputType = getSourceType();
-      int64_t inputRank = inputType.getRank();
-      return getDestType().getShape().take_front(inputRank);
-    }
+    /// that this will include both tiled and non-tiled dimensions. The order
+    /// of the output dimensions is consistent with the shape of the packed
+    /// tensor.
+    ArrayRef<int64_t> getAllOuterDims();
 
     /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
-    /// have been tiled.
+    /// have been tiled. Also, the order of the output dimensions is consistent
+    /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
   }];
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bc7deb1614d18d..d0ddde96b0b231 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,12 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+ArrayRef<int64_t> PackOp::getAllOuterDims() {
+  ShapedType inputType = getSourceType();
+  int64_t inputRank = inputType.getRank();
+  return getDestType().getShape().take_front(inputRank);
+}
+
 SmallVector<int64_t> PackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
   auto destShape = getDestType().getShape();
@@ -4422,6 +4428,12 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
+  ShapedType destType = getDestType();
+  int64_t destRank = destType.getRank();
+  return getSourceType().getShape().take_front(destRank);
+}
+
 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
   auto destShape = getSourceType().getShape();



More information about the Mlir-commits mailing list