[Mlir-commits] [mlir] [MLIR] extend `getCompressedMaskOp` support in `VectorEmulateNarrowType` (PR #116122)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 15:47:08 PST 2024


https://github.com/lialan created https://github.com/llvm/llvm-project/pull/116122

Previously, when `numFrontPadElems` is non-zero, `getCompressedMaskOp` produces an incorrect compressed mask if the mask generator op is a `vector.create_mask`.

This patch fixes the issue by adding `numFrontPadElems` to the `vector.create_mask` bound before compressing it, so the new bound is `ceildiv(m + numFrontPadElems, numSrcElemsPerDest)`.

>From b129fc06a029acc79ce1c968e2660b23b013ffe2 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 13 Nov 2024 18:16:13 -0500
Subject: [PATCH] [MLIR] extend `getCompressedMaskOp` support in
 `VectorEmulateNarrowType`

Previously, when `numFrontPadElems` is non-zero, `getCompressedMaskOp` produces
an incorrect result if the mask generator op is `vector.create_mask`. This patch
fixes that by including `numFrontPadElems` in the compressed mask bound.

Signed-off-by: Alan Li <me at alanli.org>
---
 .../Transforms/VectorEmulateNarrowType.cpp    |  8 ++-
 .../vector-emulate-narrow-type-unaligned.mlir | 49 +++++++++++++++++++
 .../Vector/vector-emulate-narrow-type.mlir    | 12 ++---
 3 files changed, 61 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a847994aee..0b5b8e0559cd2b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -104,10 +104,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   if (createMaskOp) {
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
+    // The `vector.create_mask` op creates a mask arrangement without any zeros
+    // at the front. Also, because `numFrontPadElems` is strictly smaller than
+    // `numSrcElemsPerDest`, the compressed mask generated by shifting the
+    // original mask by `numFrontPadElems` will not have any zeros at the front
+    // as well.
     AffineExpr s0;
     bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
+    s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
     OpFoldResult origIndex =
         getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
     OpFoldResult maskIndex =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..327364ce820da7 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -74,6 +74,55 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 
 // -----
 
+// This tests the correctness of generating a compressed mask from `vector.create_mask` with a dynamic bound.
+// Specifically, the program performs a masked load of a vector<5xi2> from `memref<3x5xi2>[1, 0]` with an unknown mask bound `%m`.
+// After emulation it performs a masked load of 2 bytes from the linearized index `memref<4xi8>[1]`; the single front padding
+// element shifts the mask, so the new compressed mask bound is `ceildiv(%m + 1, 4)`.
+func.func @check_unaligned_create_mask_dynamic_i2(%m : index, %passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %mask = vector.create_mask %m : vector<5xi1>
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) ceildiv 4)>
+// CHECK: func @check_unaligned_create_mask_dynamic_i2(
+// CHECK-SAME:     %[[M:[a-zA-Z0-9]+]]: index, %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<5xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[NEW_MASK_BOUND:.+]] = affine.apply #[[MAP]]()[%[[M]]]
+// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[NEW_MASK_BOUND]] : vector<2xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: vector.maskedload %[[ALLOC]][%[[C1]]], %[[NEW_MASK]]
+
+// -----
+
+// This tests the correctness of generating a compressed mask from `vector.create_mask` with a static bound.
+// Same as the previous test, but the mask bound is the constant 3: with 3 front padding elements the compressed
+// mask bound folds to `ceildiv(3 + 3, 4) = 2`, and the `vector<7xi2>` slice spans 3 bytes.
+func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vector<7xi2> {
+  %0 = memref.alloc() : memref<3x7xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3 : vector<7xi1>
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x7xi2>, vector<7xi1>, vector<7xi2> into vector<7xi2>
+  return %1 : vector<7xi2>
+}
+
+// CHECK: func @check_unaligned_create_mask_static_i2(
+// CHECK-SAME:     %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[COMP_MASK:.+]] = vector.create_mask %[[C2]] : vector<3xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMP_MASK]]
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e6..c68909061d8f3c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -141,7 +141,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK-NEXT:   return
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
 //      CHECK32: func @vector_maskedload_i8(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
@@ -169,7 +169,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
     return %2 : vector<3x8xi4>
 }
 //  CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-//  CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+//  CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 //      CHECK: func @vector_maskedload_i4(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -185,7 +185,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 //      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 //      CHECK32: func @vector_maskedload_i4(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -473,7 +473,7 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
 // CHECK-NEXT:   return
 
 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
 // CHECK32:     func @vector_maskedstore_i8(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
@@ -506,7 +506,7 @@ func.func @vector_maskedstore_i4(
     return
 }
 // CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 
 // CHECK-LABEL:   func.func @vector_maskedstore_i4(
 // CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
@@ -526,7 +526,7 @@ func.func @vector_maskedstore_i4(
 // CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
 
 // CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 
 // CHECK32-LABEL:   func.func @vector_maskedstore_i4(
 // CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,



More information about the Mlir-commits mailing list