[Mlir-commits] [mlir] [MLIR] `vector.constant_mask` to support unaligned cases (PR #116520)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 20 18:06:02 PST 2024
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/116520
>From f9cc3562abefd15597f900e9900f1d36dd81d80b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 15 Nov 2024 21:10:12 -0500
Subject: [PATCH] [MLIR] `vector.constant_mask` to support unaligned cases
In the case of unaligned indexing and `vector.constant_mask`, note
that `numFrontPadElems` is always strictly smaller than `numSrcElemsPerDest`.
Consequently, even with a non-zero `numFrontPadElems`, the compressed `constant_mask`
op never has leading false (zero) elements in the innermost dimension. Only the
size of that dimension changes: the extra shifting and alignment step rounds it up
to `divideCeil(numFrontPadElems + size, numSrcElemsPerDest)`.
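This patch enables multi-dimensional support by relying on the property above
and removing the constraint that rejected multi-dimensional masks with non-zero
front padding. The compressed mask is now always emitted as a `vector.constant_mask`.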
---
.../Transforms/VectorEmulateNarrowType.cpp | 38 ++++---------
.../vector-emulate-narrow-type-unaligned.mlir | 54 ++++++++++++++++---
2 files changed, 57 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 87c30a733c363e..ef142759c96fc0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -128,34 +128,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
})
- .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
- -> std::optional<Operation *> {
- ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
- size_t numMaskOperands = maskDimSizes.size();
- int64_t origIndex = maskDimSizes[numMaskOperands - 1];
- int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
- int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
- numSrcElemsPerDest);
-
- // TODO: we only want the mask between [startIndex, maskIndex]
- // to be true, the rest are false.
- if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
- return std::nullopt;
-
- SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
-
- if (numFrontPadElems == 0)
- return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
- newMaskDimSizes);
-
- SmallVector<bool> newMaskValues;
- for (int64_t i = 0; i < numDestElems; ++i)
- newMaskValues.push_back(i >= startIndex && i < maskIndex);
- auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
- return rewriter.create<arith::ConstantOp>(loc, newMaskType,
- newMask);
- })
+ .Case<vector::ConstantMaskOp>(
+ [&](auto constantMaskOp) -> std::optional<Operation *> {
+ SmallVector<int64_t> maskDimSizes(
+ constantMaskOp.getMaskDimSizes());
+ int64_t &maskIndex = maskDimSizes.back();
+ maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+ numSrcElemsPerDest);
+ return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+ maskDimSizes);
+ })
.Case<arith::ConstantOp>([&](auto constantOp)
-> std::optional<Operation *> {
// TODO: Support multiple dimensions.
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 721c8a8d5d2034..ec1f3fbc4d688d 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -57,7 +57,7 @@ func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector
// CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -123,6 +123,47 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec
// -----
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+ %0 = memref.alloc() : memref<4x3x5xi2>
+ %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+ %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+ memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+ return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// compressed mask, used for emulated masked load
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// Create a padded and shifted passthru vector
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+
+// CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
+// CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
+
+// pad and shift the original mask to match the size and location of the loaded value.
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+// CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
+
+// -----
+
func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%cst = arith.constant dense<0> : vector<3x3xi2>
@@ -252,7 +293,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
// Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -301,7 +342,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// -----
-func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
%0 = memref.alloc() : memref<3x5xi2>
%mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
%c0 = arith.constant 0 : index
@@ -311,24 +352,23 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
return %1 : vector<5xi2>
}
-// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+// Emulated masked load from alloc:
// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
// CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
// CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
// CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
// Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// TODO: fold insert_strided_slice into source if possible.
// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
// CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
More information about the Mlir-commits
mailing list