[Mlir-commits] [mlir] [mlir][vector][nfc] Improve comments in `getCompressedMaskOp` (PR #115663)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Nov 13 08:40:23 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/115663

From d2251a1861dd5ae908b8720d9e29a6ffe9b2738a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 10 Nov 2024 15:51:15 +0000
Subject: [PATCH 1/3] [mlir][vector][nfc] Improve comments in
 `getCompressedMaskOp`

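This PR updates and expands the high-level comment for
`getCompressedMaskOp` and renames its input parameters to more descriptive
names (`origElements` -> `numSrcElems`, `scale` -> `numSrcElemsPerDest`,
`intraDataOffset` -> `numFrontPadElems`).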

The current variable names are somewhat unclear (e.g., `scale`) or
derived from `memref` terminology (e.g., `intraDataOffset` from
`LinearizedMemRefInfo`). The updated names in this PR aim to better
align with the context and usage in the vector domain.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 45 ++++++++++++-------
 1 file changed, 29 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index eb4ce24548e603..6cc1b5a7ab2438 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,27 +45,39 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-/// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
-/// equals to 1 (intraDataOffset strictly smaller than scale), the following
-/// mask:
+/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
+/// when the number of source elements spans two `i32` elements, this method
+/// will compress `vector<8xi1>` into `vector<2xi1>`.
+///
+/// The compressed/output mask value is set iff any mask in the corresponding
+/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
+/// `numSrcElemsPerDest` equals to 2, and `numFrontPadElems` equals to 1, the
+/// following mask:
 ///
 ///   %mask = [1, 1, 0, 0, 0, 0]
 ///
-/// will first be padded in the front with number of `intraDataOffset` zeros,
+/// will first be padded in the front with number of `numFrontPadElems` zeros,
 /// and pad zeros in the back to make the number of elements a multiple of
-/// `scale` (just to make it easier to compute). The new mask will be:
+/// `numSrcElemesPerDest` (just to make it easier to compute). The new mask will
+/// be:
 ///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
 ///
 ///   %mask = [1, 1, 0, 0]
+///
+/// `numFrontPadElems` is assumed to be strictly smaller than
+/// `numSrcElemsPerDest`.
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
-                                                  int origElements, int scale,
-                                                  int intraDataOffset = 0) {
-  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
-  auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);
+                                                  int numSrcElems,
+                                                  int numSrcElemsPerDest,
+                                                  int numFrontPadElems = 0) {
+
+  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+
+  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+                     numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -93,8 +105,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
     bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + scale - 1;
-    s0 = s0.floorDiv(scale);
+    s0 = s0 + numSrcElemsPerDest - 1;
+    s0 = s0.floorDiv(numSrcElemsPerDest);
     OpFoldResult origIndex =
         getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
     OpFoldResult maskIndex =
@@ -108,18 +120,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = intraDataOffset / scale;
-    int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
+    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+    int64_t maskIndex =
+        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
 
     // TODO: we only want the mask between [startIndex, maskIndex] to be true,
     // the rest are false.
-    if (intraDataOffset != 0 && maskDimSizes.size() > 1)
+    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
       return failure();
 
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
 
-    if (intraDataOffset == 0) {
+    if (numFrontPadElems == 0) {
       newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                         newMaskDimSizes);
     } else {

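The doc comment introduced in this patch describes the compression as front padding, back padding to a multiple of `numSrcElemsPerDest`, and then an "any lane set" reduction over each group. Below is a minimal C++ sketch of those semantics on a plain 1-D `std::vector<bool>`, assuming a single (trailing) mask dimension. `compressMask` is a hypothetical helper for illustration only; it is not an MLIR API and does not build any `vector` dialect ops.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of MLIR): models the semantics described in
// the doc comment of `getCompressedMaskOp` on a 1-D boolean mask.
std::vector<bool> compressMask(const std::vector<bool> &mask,
                               int numSrcElemsPerDest,
                               int numFrontPadElems = 0) {
  assert(numSrcElemsPerDest > 0);
  assert(numFrontPadElems >= 0 && numFrontPadElems < numSrcElemsPerDest);
  const int numSrcElems = static_cast<int>(mask.size());

  // Ceiling division (same result as llvm::divideCeil for non-negative ints).
  const int numElements =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  // Pad the front with `numFrontPadElems` zeros and the back with zeros so
  // that the total length is a multiple of `numSrcElemsPerDest`.
  std::vector<bool> padded(
      static_cast<std::size_t>(numElements) * numSrcElemsPerDest, false);
  for (int i = 0; i < numSrcElems; ++i)
    padded[numFrontPadElems + i] = mask[i];

  // An output lane is set iff any input lane in its group is set.
  std::vector<bool> compressed(numElements, false);
  for (int d = 0; d < numElements; ++d) {
    for (int j = 0; j < numSrcElemsPerDest; ++j) {
      if (padded[d * numSrcElemsPerDest + j]) {
        compressed[d] = true;
        break;
      }
    }
  }
  return compressed;
}
```

With the doc-comment example (`numSrcElemsPerDest = 2`, `numFrontPadElems = 1`), the six-element mask `[1, 1, 0, 0, 0, 0]` is padded to `[0, 1, 1, 0, 0, 0, 0, 0]` and compresses to `[1, 1, 0, 0]`.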
From 1fc1fb738a03d3ae19a4f17ca061fb414322a107 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 12 Nov 2024 15:56:08 +0000
Subject: [PATCH 2/3] fixup! [mlir][vector][nfc] Improve comments in
 `getCompressedMaskOp`

Incorporate PR suggestions
---
 .../Vector/Transforms/VectorEmulateNarrowType.cpp      | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6cc1b5a7ab2438..ae54367f60b469 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,9 +45,10 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-/// Returns a compressed mask. For example, when emulating `i8` with `i32` and
-/// when the number of source elements spans two `i32` elements, this method
-/// will compress `vector<8xi1>` into `vector<2xi1>`.
+/// Returns a compressed mask for the emulated vector. For example, when
+/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
+/// elements span two dest elements), this method compresses `vector<8xi1>`
+/// into `vector<2xi1>`.
 ///
 /// The compressed/output mask value is set iff any mask in the corresponding
 /// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
@@ -66,8 +67,7 @@ using namespace mlir;
 ///
 ///   %mask = [1, 1, 0, 0]
 ///
-/// `numFrontPadElems` is assumed to be strictly smaller than
-/// `numSrcElemsPerDest`.
+/// NOTE: `numFrontPadElems` must be strictly smaller than `numSrcElemsPerDest`.
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int numSrcElems,

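The revised example in this fixup (an eight-element `i8` vector emulated with `i32`) can be checked with a little arithmetic. Each `i32` holds `32 / 8 = 4` source elements, so with no front padding the mask shrinks from 8 to `ceil(8 / 4) = 2` elements. The sketch below uses hypothetical helper names (not MLIR APIs) and assumes the destination bit width is a multiple of the source bit width. It also shows that one element of front padding would make the same eight elements span three `i32` elements.

```cpp
#include <cassert>

// Hypothetical helper: number of narrow (source) elements packed into one
// wide (destination) element. Assumes destBits is a multiple of srcBits.
constexpr int numSrcElemsPerDestFor(int srcBits, int destBits) {
  return destBits / srcBits;
}

// Hypothetical helper: length of the compressed mask, i.e. the number of
// destination elements touched by `numSrcElems` source elements that start
// `numFrontPadElems` elements into the first destination element.
constexpr int numCompressedMaskElems(int numSrcElems, int numSrcElemsPerDest,
                                     int numFrontPadElems = 0) {
  return (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
         numSrcElemsPerDest;
}

// i8 emulated with i32: four source elements per destination element.
static_assert(numSrcElemsPerDestFor(8, 32) == 4, "4 x i8 per i32");
// vector<8xi1> compresses to vector<2xi1> when there is no front padding.
static_assert(numCompressedMaskElems(8, 4) == 2, "8 -> 2");
// With one element of front padding, the same 8 elements span 3 x i32.
static_assert(numCompressedMaskElems(8, 4, 1) == 3, "8 -> 3 with padding");
```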
From afc03145f6c7354d9f42c8364ad6efcb34352236 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 13 Nov 2024 16:40:04 +0000
Subject: [PATCH 3/3] fixup! fixup! [mlir][vector][nfc] Improve comments in
 `getCompressedMaskOp`

Final tweaks
---
 .../Vector/Transforms/VectorEmulateNarrowType.cpp     | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ae54367f60b469..e5f2a847994aee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -57,17 +57,18 @@ using namespace mlir;
 ///
 ///   %mask = [1, 1, 0, 0, 0, 0]
 ///
-/// will first be padded in the front with number of `numFrontPadElems` zeros,
-/// and pad zeros in the back to make the number of elements a multiple of
-/// `numSrcElemesPerDest` (just to make it easier to compute). The new mask will
-/// be:
+/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
+/// will be added in the back to make the number of elements a multiple of
+/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
+///
 ///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
 ///
 ///   %mask = [1, 1, 0, 0]
 ///
-/// NOTE: `numFrontPadElems` must be strictly smaller than `numSrcElemsPerDest`.
+/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
+/// `numSrcElemsPerDest`.
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int numSrcElems,

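For the `vector.constant_mask` branch, the patch computes `maskIndex = divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest)`. Given the precondition that a non-negative `numFrontPadElems` is strictly smaller than `numSrcElemsPerDest`, `startIndex = numFrontPadElems / numSrcElemsPerDest` is always 0. The sketch below uses hypothetical helper names (not MLIR APIs). It brute-force checks, for small 1-D sizes, that a prefix mask with `origIndex` active lanes compresses to a prefix with exactly `maskIndex` active lanes. It only covers prefixes with at least one active lane.

```cpp
#include <cassert>
#include <vector>

// Hypothetical reference model: equivalent to padding and then OR-reducing
// each group, since every set source lane marks the group it lands in.
std::vector<bool> compressMaskRef(const std::vector<bool> &mask,
                                  int numSrcElemsPerDest,
                                  int numFrontPadElems) {
  const int numSrcElems = static_cast<int>(mask.size());
  const int numElements =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;
  std::vector<bool> compressed(numElements, false);
  for (int i = 0; i < numSrcElems; ++i)
    if (mask[i])
      compressed[(numFrontPadElems + i) / numSrcElemsPerDest] = true;
  return compressed;
}

// Closed form for the trailing dimension of a constant mask, mirroring
// `maskIndex` in the patch.
int compressedPrefixLength(int origIndex, int numSrcElemsPerDest,
                           int numFrontPadElems) {
  return (numFrontPadElems + origIndex + numSrcElemsPerDest - 1) /
         numSrcElemsPerDest;
}

// Exhaustively compares the closed form against the reference model for
// 1-D prefix masks with at least one active lane.
bool prefixFormulaMatchesReference(int maxSrcElems, int maxSrcElemsPerDest) {
  for (int perDest = 1; perDest <= maxSrcElemsPerDest; ++perDest) {
    for (int pad = 0; pad < perDest; ++pad) {
      for (int n = 1; n <= maxSrcElems; ++n) {
        for (int orig = 1; orig <= n; ++orig) {
          std::vector<bool> mask(n, false);
          for (int i = 0; i < orig; ++i)
            mask[i] = true;
          const std::vector<bool> ref = compressMaskRef(mask, perDest, pad);
          const int len = compressedPrefixLength(orig, perDest, pad);
          for (int d = 0; d < static_cast<int>(ref.size()); ++d)
            if (ref[d] != (d < len))
              return false;
        }
      }
    }
  }
  return true;
}
```

For the doc-comment example (`origIndex = 2`, `numSrcElemsPerDest = 2`, `numFrontPadElems = 1`), the closed form gives a prefix of length 2, matching `[1, 1, 0, 0]`.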