[Mlir-commits] [mlir] [MLIR] `vector.constant_mask` to support unaligned cases (PR #116520)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 16 18:26:27 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

<details>
<summary>Changes</summary>

In the case of unaligned indexing and `vector.constant_mask`, notice that `numFrontPadElems` is always strictly smaller than `numSrcElemsPerDest`, so the start index of the compressed mask (`numFrontPadElems / numSrcElemsPerDest`) is always 0. This means that, even with a non-zero `numFrontPadElems`, the compressed `constant_mask` op will not have any leading zero elements in the innermost dimension; only the relevant mask size and values may change, due to the extra step of shifting and aligning elements.

This patch enables multi-dimensional support by relying on the above-mentioned property and eliminating the constraints.

---
Full diff: https://github.com/llvm/llvm-project/pull/116520.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+13-28) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+47-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc8bab325184b8..26a5f566f34948 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -125,34 +125,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                 return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                              newMaskOperands);
               })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                ArrayRef<int64_t> maskDimSizes =
+                    constantMaskOp.getMaskDimSizes();
+                size_t numMaskOperands = maskDimSizes.size();
+                int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+                int64_t maskIndex = llvm::divideCeil(
+                    numFrontPadElems + origIndex, numSrcElemsPerDest);
+                SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+                newMaskDimSizes.push_back(maskIndex);
+                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                               newMaskDimSizes);
+              })
           .Case<arith::ConstantOp>([&](auto constantOp)
                                        -> std::optional<Operation *> {
             // TODO: Support multiple dimensions.
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index b1a0d4f924f3cf..73ce7ac9be2437 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -57,7 +57,7 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 // CHECK-LABEL: func @vector_cst_maskedload_i2(
 // CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -74,6 +74,48 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 
 // -----
 
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %cst = arith.constant dense<0> : vector<3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// compressed mask, used for emulated masked load
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// Create a padded and shifted passthru vector
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
+// CHECK-SAME: {offsets = [2], strides = [1]} 
+
+// CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
+// CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
+
+// pad and shift the original mask to match the size and location of the loaded value.
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME: {offsets = [2], strides = [1]} 
+// CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
@@ -203,7 +245,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
 
 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -268,18 +310,17 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
 
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
+
+// Emulated masked load from alloc:
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
 
-// Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// Select from emulated loaded vector and passthru vector: (TODO: fold this part if possible)
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>

``````````

</details>


https://github.com/llvm/llvm-project/pull/116520


More information about the Mlir-commits mailing list