[Mlir-commits] [mlir] [MLIR] Refactor mask compression logic when emulating `vector.maskedload` ops (PR #116520)

llvmlistbot at llvm.org
Mon Nov 25 18:16:29 PST 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/116520

From 98cee5cd5ba639247fb165c28a43e8102d2c0677 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 15 Nov 2024 21:10:12 -0500
Subject: [PATCH 1/2] [MLIR] Refactor mask compression logic when emulating
 `vector.maskedload` ops

This patch simplifies and extends the logic used to compress masks emitted by `vector.constant_mask`, adding support for extracting 1-D vectors from multi-dimensional vector loads. The compression now rewrites only the trailing mask dimension, which makes it applicable to multi-dimensional constant masks and improves the overall handling of masked load operations.
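
For reviewers skimming the change, here is a minimal standalone sketch (plain C++, not the actual MLIR pattern; the helper names are made up for illustration) of the trailing-dimension arithmetic the new `ConstantMaskOp` case performs, using the numbers from the multi-dimensional test added below (i2 elements packed four per i8 byte, so numSrcElemsPerDest = 4, with 2 front padding elements):

#include <cstdint>
#include <iostream>
#include <vector>

// Ceiling division, mirroring llvm::divideCeil for non-negative operands.
static int64_t divideCeil(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

// Compress only the trailing dimension of a constant-mask shape; leading
// dimensions are kept unchanged.
static std::vector<int64_t> compressMaskDims(std::vector<int64_t> maskDimSizes,
                                             int64_t numFrontPadElems,
                                             int64_t numSrcElemsPerDest) {
  int64_t &maskIndex = maskDimSizes.back();
  maskIndex = divideCeil(numFrontPadElems + maskIndex, numSrcElemsPerDest);
  return maskDimSizes;
}

int main() {
  // [2, 2] over vector<3x5xi1>, 2 front pad elements, 4 i2 values per byte:
  // the trailing size 2 becomes ceil((2 + 2) / 4) = 1, i.e. [2, 1].
  for (int64_t d :
       compressMaskDims({2, 2}, /*numFrontPadElems=*/2, /*numSrcElemsPerDest=*/4))
    std::cout << d << ' ';
  std::cout << '\n'; // prints "2 1"
  return 0;
}

The same formula also reproduces the existing 1-D case in @vector_constant_mask_maskedload_i2, where a [3] mask with 2 front padding elements compresses to [2] over vector<2xi1>.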
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 43 ++++-------------
 .../vector-emulate-narrow-type-unaligned.mlir | 48 ++++++++++++++-----
 2 files changed, 45 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 87c30a733c363e..51b41669513c47 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -128,34 +128,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                 return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                              newMaskOperands);
               })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                SmallVector<int64_t> maskDimSizes(
+                    constantMaskOp.getMaskDimSizes());
+                int64_t &maskIndex = maskDimSizes.back();
+                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+                                             numSrcElemsPerDest);
+                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                               maskDimSizes);
+              })
           .Case<arith::ConstantOp>([&](auto constantOp)
                                        -> std::optional<Operation *> {
             // TODO: Support multiple dimensions.
@@ -605,11 +587,6 @@ struct ConvertVectorMaskedLoad final
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
-    if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
-
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 721c8a8d5d2034..263548245ce838 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -42,22 +42,20 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
 
 // -----
 
-func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
-  %cst = arith.constant dense<0> : vector<3x5xi2>
   %mask = vector.constant_mask [3] : vector<5xi1>
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
     memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
-  %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
-  return %2 : vector<3x5xi2>
+  return %1 : vector<5xi2>
 }
-
 // CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
-// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
+// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -123,6 +121,31 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec
 
 // -----
 
+// This test is similar to @vector_constant_mask_maskedload_i2, but the mask is multi-dimensional.
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// Compressing the mask used for emulated masked load.
+// The innermost dimension is compressed to 2 elements from 5.
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
@@ -252,7 +275,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
 
 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -301,7 +324,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 
 // -----
 
-func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
   %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
   %c0 = arith.constant 0 : index
@@ -311,24 +334,23 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
   return %1 : vector<5xi2>
 }
 
-// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
 // CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
 
+// Emulated masked load from alloc:
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
 
 // Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// TODO: fold insert_strided_slice into source if possible.
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>

From 7dace4e2e260401b86a96e80332d454ddddd7444 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 26 Nov 2024 10:16:00 +0800
Subject: [PATCH 2/2] Updates according to comments

---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 5 +++++
 .../Vector/vector-emulate-narrow-type-unaligned.mlir      | 8 +++-----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 51b41669513c47..181c394edc1d20 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -130,6 +130,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
               })
           .Case<vector::ConstantMaskOp>(
               [&](auto constantMaskOp) -> std::optional<Operation *> {
+                // Take the shape of the mask and compress its trailing dimension:
                 SmallVector<int64_t> maskDimSizes(
                     constantMaskOp.getMaskDimSizes());
                 int64_t &maskIndex = maskDimSizes.back();
@@ -586,6 +587,10 @@ struct ConvertVectorMaskedLoad final
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // See #115653
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 263548245ce838..4332e80feed421 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -134,15 +134,13 @@ func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>)
 }
 
 // CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
-// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
 // CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
-// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+// CHECK: vector.extract %[[ORIG_MASK]][1]
 
 // Compressing the mask used for emulated masked load.
 // The innermost dimension is compressed to 2 elements from 5.
-// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
-// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+// CHECK: %[[NEW_COMPRESSED_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: vector.extract %[[NEW_COMPRESSED_MASK]][1]
 
 // -----
 


