[Mlir-commits] [mlir] [mlir][vector][nfc] Add tests + update docs for narrow-type emulation (PR #115460)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Nov 8 05:47:06 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/115460

>From 8a9abf6714427af3d82b13b28961eb6f65c18b4a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Nov 2024 10:44:02 +0000
Subject: [PATCH 1/2] [mlir][vector][nfc] Add tests + update docs for
 narrow-type emulation

The documentation for narrow-type emulation is a bit inaccurate. In
particular, we don't really support/generate masks like this:

  %mask = [0, 1, 1, 1, 1, 1, 0, 0]

I updated the comment for `ConvertVectorMaskedStore` accordingly. I also
added a few clarifications (e.g. that the comment is discussing i4 -> i8
emulation).

Separately, I've noticed an inconsistency in testing for
narrow-type emulation. In particular, there are a few cases that are
tested for "loading" but are missing for "storing". I've added
  * comments in the test file so that it's easy to see what's tested,
  * missing tests for `vector.maskedstore`.

Finally, I've added a top level comment in VectorEmulateNarrowType.cpp
so that the overall intent and design are clearer.
---
 .../Transforms/VectorEmulateNarrowType.cpp    |  41 ++++--
 .../Vector/vector-emulate-narrow-type.mlir    | 126 ++++++++++++++++++
 2 files changed, 154 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..0f88ff21e847e4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+// Currently, only power-of-two integer types are supported. These are
+// converted to wider integers that are at least 8 bits wide.
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -315,21 +323,28 @@ struct ConvertVectorMaskedStore final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
 
     // Load the whole data and use arith.select to handle the corner cases.
-    // E.g., given these input values:
+    // E.g., given these input i4 values:
+    //
+    //   %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store :
+    //
+    //   %mask = [1, 1, 1, 1, 1, 1, 1, 0]                     (8 * i1)
+    //   %0[%c0, %c0] =
+    //      [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]          (8 * i4)
+    //   %val_to_store =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]          (8 * i4)
     //
-    //   %mask = [0, 1, 1, 1, 1, 1, 0, 0]
-    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
-    //   %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
+    // we'll have the following i4 output:
     //
-    // we'll have
+    //    expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]
     //
-    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
+    // Emulating the above using i8 will give:
     //
-    //    %new_mask = [1, 1, 1, 0]
-    //    %maskedload = [0x12, 0x34, 0x56, 0x00]
-    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
-    //    %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
-    //    %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
+    //    %compressed_mask = [1, 1, 1, 1]                     (4 * i1)
+    //    %maskedload = [0x12, 0x34, 0x56, 0x78]              (4 * i8)
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+    //    %select_using_shifted_mask =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]          (8 * i4)
+    //    %packed_data = [0x9A, 0xBC, 0xDE, 0xF8]             (4 * i8)
     //
     // Using the new mask to store %packed_data results in expected output.
     FailureOr<Operation *> newMask =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index cba299b2a1d956..c98b4dd50a7028 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,6 +1,10 @@
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
 
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
 func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
 func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
     %c0 = arith.constant 0 : i4
     %0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
 func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %mask = vector.create_mask %arg3 : vector<4xi1>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.maskedload
+///----------------------------------------------------------------------------------------
+
 func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
     %0 = memref.alloc() : memref<8x8x16xi4>
     %c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
 func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
     %0 = memref.alloc() : memref<4x8xi8>
     vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
 func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,6 +493,61 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
 
 // -----
 
+func.func @vector_maskedstore_i4(
+  %idx1: index,
+  %idx2: index,
+  %num_elements_to_store: index,
+  %value: vector<8xi4>) {
+
+    %0 = memref.alloc() : memref<3x8xi4>
+    %cst = arith.constant dense<0> : vector<3x8xi4>
+    %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
+    vector.maskedstore %0[%idx1, %idx2], %mask, %value :
+      memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+    return
+}
+// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+
+// CHECK-LABEL:   func.func @vector_maskedstore_i4(
+// CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK:           %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK:           %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK:           %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK:           %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK:           %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+// CHECK:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK:           %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+
+// CHECK32-LABEL:   func.func @vector_maskedstore_i4(
+// CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK32:           %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32:           %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK32:           %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK32:           %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK32:           %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+// CHECK32:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+// CHECK32:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32:           %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK32:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
 func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.constant_mask [4] : vector<8xi1>
@@ -500,3 +579,50 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
 // CHECK32:        %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
 // CHECK32:        %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
 // CHECK32:        vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+// -----
+
+func.func @vector_cst_maskedstore_i4(
+  %idx_1: index,
+  %idx_2: index,
+  %val_to_store: vector<8xi4>) {
+
+    %0 = memref.alloc() : memref<3x8xi4>
+    %cst = arith.constant dense<0> : vector<3x8xi4>
+    %mask = vector.constant_mask [4] : vector<8xi1>
+    vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
+      memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+    return
+}
+
+// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK-LABEL:   func.func @vector_cst_maskedstore_i4(
+// CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK:           %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK:           %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK:           %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK:           %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+// CHECK:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK:           %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
+// CHECK:           %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32-LABEL:   func.func @vector_cst_maskedstore_i4(
+// CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK32:           %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32:           %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32:           %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK32:           %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK32:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+// CHECK32:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32:           %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
+// CHECK32:           %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>

>From eef52e2784ac8bf7d7b011d8db84f284f8f14576 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Nov 2024 13:45:58 +0000
Subject: [PATCH 2/2] fixup! [mlir][vector][nfc] Add tests + update docs for
 narrow-type emulation

* Fix failing test
* Tweak/fix the comment
* Rename: @vector_cst_maskedload_i8 -> @vector_maskedload_i8_constant_mask (same for other similar tests)
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 24 ++++++-----
 .../Vector/vector-emulate-narrow-type.mlir    | 42 +++++++++----------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0f88ff21e847e4..b9f5c71fa4805f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -323,11 +323,14 @@ struct ConvertVectorMaskedStore final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
 
     // Load the whole data and use arith.select to handle the corner cases.
-    // E.g., given these input i4 values:
     //
-    //   %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store :
+    // As an example, for this masked store:
     //
-    //   %mask = [1, 1, 1, 1, 1, 1, 1, 0]                     (8 * i1)
+    //   vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
+    //
+    // and given these input i4 values:
+    //
+    //   %mask = [1, 1, 1, 1, 1, 0, 0, 0]                     (8 * i1)
     //   %0[%c0, %c0] =
     //      [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]          (8 * i4)
     //   %val_to_store =
@@ -335,18 +338,19 @@ struct ConvertVectorMaskedStore final
     //
     // we'll have the following i4 output:
     //
-    //    expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]
+    //    expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
     //
     // Emulating the above using i8 will give:
     //
-    //    %compressed_mask = [1, 1, 1, 1]                     (4 * i1)
-    //    %maskedload = [0x12, 0x34, 0x56, 0x78]              (4 * i8)
-    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+    //    %compressed_mask = [1, 1, 1, 0]                     (4 * i1)
+    //    %maskedload = [0x12, 0x34, 0x56, 0x00]              (4 * i8)
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
     //    %select_using_shifted_mask =
-    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]          (8 * i4)
-    //    %packed_data = [0x9A, 0xBC, 0xDE, 0xF8]             (4 * i8)
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0]          (8 * i4)
+    //    %packed_data = [0x9A, 0xBC, 0xD6, 0x00]             (4 * i8)
     //
-    // Using the new mask to store %packed_data results in expected output.
+    // Using the compressed mask to store %packed_data results in expected
+    // output.
     FailureOr<Operation *> newMask =
         getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
     if (failed(newMask))
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index c98b4dd50a7028..5e139b04d7ee6f 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -202,7 +202,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 
 // -----
 
-func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %mask = vector.constant_mask [2] : vector<4xi1>
     %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
@@ -210,7 +210,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
     return %1 : vector<4xi8>
 }
 // Expect no conversions, i8 is supported.
-//      CHECK: func @vector_cst_maskedload_i8(
+//      CHECK: func @vector_maskedload_i8_constant_mask(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
 // CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
@@ -220,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
 // CHECK-NEXT:   return
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-//      CHECK32: func @vector_cst_maskedload_i8(
+//      CHECK32: func @vector_maskedload_i8_constant_mask(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -236,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
 
 // -----
 
-func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
     %0 = memref.alloc() : memref<3x8xi4>
     %cst = arith.constant dense<0> : vector<3x8xi4>
     %mask = vector.constant_mask [4] : vector<8xi1>
@@ -246,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
     return %2 : vector<3x8xi4>
 }
 //  CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-//      CHECK: func @vector_cst_maskedload_i4(
+//      CHECK: func @vector_maskedload_i4_constant_mask(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
@@ -260,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 //      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-//      CHECK32: func @vector_cst_maskedload_i4(
+//      CHECK32: func @vector_maskedload_i4_constant_mask(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -500,7 +500,6 @@ func.func @vector_maskedstore_i4(
   %value: vector<8xi4>) {
 
     %0 = memref.alloc() : memref<3x8xi4>
-    %cst = arith.constant dense<0> : vector<3x8xi4>
     %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
     vector.maskedstore %0[%idx1, %idx2], %mask, %value :
       memref<3x8xi4>, vector<8xi1>, vector<8xi4>
@@ -548,14 +547,14 @@ func.func @vector_maskedstore_i4(
 
 // -----
 
-func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+func.func @vector_maskedstore_i8_constant_mask(%arg0: index, %arg1: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.constant_mask [4] : vector<8xi1>
   vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
   return
 }
 // Expect no conversions, i8 is supported.
-//      CHECK: func @vector_cst_maskedstore_i8(
+//      CHECK: func @vector_maskedstore_i8_constant_mask(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[VAL:[a-zA-Z0-9]+]]
@@ -565,7 +564,7 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
 // CHECK-NEXT:   return
 
 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
-// CHECK32:     func @vector_cst_maskedstore_i8(
+// CHECK32:     func @vector_maskedstore_i8_constant_mask(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK32-SAME:     %[[VAL:[a-zA-Z0-9]+]]
@@ -582,13 +581,12 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
 
 // -----
 
-func.func @vector_cst_maskedstore_i4(
+func.func @vector_maskedstore_i4_constant_mask(
   %idx_1: index,
   %idx_2: index,
   %val_to_store: vector<8xi4>) {
 
     %0 = memref.alloc() : memref<3x8xi4>
-    %cst = arith.constant dense<0> : vector<3x8xi4>
     %mask = vector.constant_mask [4] : vector<8xi1>
     vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
       memref<3x8xi4>, vector<8xi1>, vector<8xi4>
@@ -596,7 +594,7 @@ func.func @vector_cst_maskedstore_i4(
 }
 
 // CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK-LABEL:   func.func @vector_cst_maskedstore_i4(
+// CHECK-LABEL:   func.func @vector_maskedstore_i4_constant_mask(
 // CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
@@ -606,13 +604,13 @@ func.func @vector_cst_maskedstore_i4(
 // CHECK:           %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
 // CHECK:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
 // CHECK:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
-// CHECK:           %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
-// CHECK:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
-// CHECK:           %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
-// CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+// CHECK:           %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
 
 // CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32-LABEL:   func.func @vector_cst_maskedstore_i4(
+// CHECK32-LABEL:   func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
@@ -622,7 +620,7 @@ func.func @vector_cst_maskedstore_i4(
 // CHECK32:           %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
 // CHECK32:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
 // CHECK32:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-// CHECK32:           %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
-// CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
-// CHECK32:           %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
-// CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+// CHECK32:           %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK32:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>



More information about the Mlir-commits mailing list