[Mlir-commits] [mlir] [MLIR] Fix VectorEmulateNarrowType constant op mask bug (PR #116064)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 14 07:42:20 PST 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/116064

From af7290e4bdf0e461494c40e7bd33a7323ff993e0 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 12 Nov 2024 22:43:58 -0500
Subject: [PATCH] [MLIR] Fix VectorEmulateNarrowType constant op mask bug

This commit adds support for mask constants generated by the `arith.constant`
op in the `VectorEmulateNarrowType` pattern. Previously, the pattern failed to
match because `getCompressedMaskOp` did not recognize masks defined by
`arith.constant` ops.

The changes include:

1. Updating `getCompressedMaskOp` to recognize and handle `arith.constant` ops as
   mask value sources.

2. Handling cases where the mask is not aligned with the emulated load width:
   the compressed mask is adjusted to account for the front offset (see the
   second example below).
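
For example, mirroring the `vector_maskedload_i4_arith_constant` test added
below (two i4 elements packed into each emulated i8 byte), the mask bits are
OR-ed together in groups of two:

  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
  // compressed mask, one bit per emulated byte:
  arith.constant dense<[true, true, true, false]> : vector<4xi1>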

Limitations:
- Only 1-dimensional `arith.constant` mask values are supported.
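
In the supported 1-D case with a front offset (point 2 above), the original
mask values are first padded with leading `false` entries for the offset and
trailing `false` entries up to a multiple of `numSrcElemsPerDest`, and the
padded mask is then compressed. Mirroring the new unaligned i2 test added
below (four i2 elements per byte, front offset of one element):

  %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
  // padded:     [false, false, true, true, true, false, false, false]
  // compressed: arith.constant dense<true> : vector<2xi1>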

Resolves: #115742

Signed-off-by: Alan Li <me at alanli.org>
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 161 ++++++++++++------
 .../vector-emulate-narrow-type-unaligned.mlir |  38 +++++
 .../Vector/vector-emulate-narrow-type.mlir    |  25 +++
 3 files changed, 170 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a847994aee..91c41926d6f759 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -30,9 +30,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -75,7 +77,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
@@ -83,75 +86,125 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
   // new mask index) only happens on the last dimension of the vectors.
-  Operation *newMask = nullptr;
-  SmallVector<int64_t> shape(
+  SmallVector<int64_t> maskShape(
       cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
-  shape.back() = numElements;
-  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-  if (createMaskOp) {
-    OperandRange maskOperands = createMaskOp.getOperands();
-    size_t numMaskOperands = maskOperands.size();
-    AffineExpr s0;
-    bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
-    OpFoldResult origIndex =
-        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-    OpFoldResult maskIndex =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
-    newMaskOperands.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                    newMaskOperands);
-  } else if (constantMaskOp) {
-    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-    size_t numMaskOperands = maskDimSizes.size();
-    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-    int64_t maskIndex =
-        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
-    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
-    // the rest are false.
-    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-      return failure();
-
-    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndex);
-
-    if (numFrontPadElems == 0) {
-      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                        newMaskDimSizes);
-    } else {
-      SmallVector<bool> newMaskValues;
-      for (int64_t i = 0; i < numElements; ++i)
-        newMaskValues.push_back(i >= startIndex && i < maskIndex);
-      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
-      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
-    }
-  }
+  maskShape.back() = numElements;
+  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+  std::optional<Operation *> newMask =
+      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+          .Case<vector::CreateMaskOp>(
+              [&](auto createMaskOp) -> std::optional<Operation *> {
+                OperandRange maskOperands = createMaskOp.getOperands();
+                size_t numMaskOperands = maskOperands.size();
+                AffineExpr s0;
+                bindSymbols(rewriter.getContext(), s0);
+                s0 = s0 + numSrcElemsPerDest - 1;
+                s0 = s0.floorDiv(numSrcElemsPerDest);
+                OpFoldResult origIndex =
+                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+                    rewriter, loc, s0, origIndex);
+                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+                newMaskOperands.push_back(
+                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                             newMaskOperands);
+              })
+          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+                                            -> std::optional<Operation *> {
+            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+            size_t numMaskOperands = maskDimSizes.size();
+            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
+                                                 numSrcElemsPerDest);
+
+            // TODO: we only want the mask elements in [startIndex, maskIndex)
+            // to be true; the rest should be false.
+            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+              return std::nullopt;
+
+            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+            newMaskDimSizes.push_back(maskIndex);
+
+            if (numFrontPadElems == 0)
+              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                             newMaskDimSizes);
+
+            SmallVector<bool> newMaskValues;
+            for (int64_t i = 0; i < numElements; ++i)
+              newMaskValues.push_back(i >= startIndex && i < maskIndex);
+            auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+                                                      denseAttr);
+          })
+          .Case<arith::ConstantOp>([&](auto constantOp)
+                                       -> std::optional<Operation *> {
+            // TODO: Support multiple dimensions.
+            if (maskShape.size() != 1)
+              return std::nullopt;
+            // Rearrange the original mask values to cover the whole potential
+            // loading region. For example, in the case of using byte-size for
+            // emulation, given the following mask:
+            //
+            // %mask = [false, true, false, true, false, false]
+            //
+            // With a front offset of 1, the mask will be padded with 0s in
+            // the front and back so that:
+            // 1. It is aligned with the effective loading bits.
+            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
+            // total coverage size is a multiple of bytes). The new mask will
+            // look like this before compressing:
+            //
+            // %new_mask = [false, false, true, false, true, false, false,
+            // false]
+            auto denseAttr =
+                dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
+            if (!denseAttr)
+              return std::nullopt;
+            SmallVector<bool> maskValues(numFrontPadElems, false);
+            maskValues.append(denseAttr.template value_begin<bool>(),
+                              denseAttr.template value_end<bool>());
+            maskValues.resize(numElements * numSrcElemsPerDest, false);
+
+            // Compressing by combining every `numSrcElemsPerDest` elements:
+            SmallVector<bool> compressedMaskValues;
+            for (size_t i = 0; i < maskValues.size(); i += numSrcElemsPerDest) {
+              bool combinedValue = false;
+              for (int j = 0; j < numSrcElemsPerDest; ++j) {
+                combinedValue |= maskValues[i + j];
+              }
+              compressedMaskValues.push_back(combinedValue);
+            }
+            return rewriter.create<arith::ConstantOp>(
+                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+          });
+
+  if (!newMask)
+    return failure();
 
   while (!extractOps.empty()) {
     newMask = rewriter.create<vector::ExtractOp>(
-        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
     extractOps.pop_back();
   }
 
-  return newMask;
+  return *newMask;
 }
 
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..c61343a328d791 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,41 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+
+// CHECK: %[[CST0:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[CST1:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[PTH]], %[[CST1]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+
+// Emulated masked load from alloc:
+// CHECK: %[[BCAST:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[CST0]], %[[BCAST]]
+// CHECK: %[[BCAST2:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
+
+// Select from emulated loaded vector and passthru vector:
+// TODO: fold this part if possible.
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[MASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BCAST2]], %[[INSERT]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: return %[[EXTRACT]] : vector<5xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e6..a9cf7c7220fd08 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -624,3 +624,28 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK32:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
 // CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<3x8xi4>
+  %cst = arith.constant dense<0> : vector<8xi4>
+  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+  %c0 = arith.constant 0 : index
+  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
+    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+  return %1 : vector<8xi4>
+}
+
+// CHECK: func @vector_maskedload_i4_arith_constant(
+// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+// CHECK: %[[COMP_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[PTHU_UPCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]]], %[[COMP_MASK]], %[[PTHU_UPCAST]]
+// CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[LOAD_DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[LOAD_DOWNCAST]], %[[PASSTHRU]] : vector<8xi1>, vector<8xi4>
+// CHECK: return %[[SELECT]] : vector<8xi4>


