[Mlir-commits] [mlir] andrzej/emulate narrow type update 2 (PR #115612)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Nov 9 08:46:20 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
- **[mlir][vector][nfc] Add tests + update docs for narrow-type emulation**
- **fixup! [mlir][vector][nfc] Add tests + update docs for narrow-type emulation**
- **[mlir][vector] Restrict narrow-type-emulation patterns**
---
Patch is 27.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115612.diff
4 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+53-14)
- (added) mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir (+112)
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+133-9)
- (modified) mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp (+10-1)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..91da9bc9c7f8a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+/// Currently, only power-of-two integer types are supported. These are
+/// converted to wider integers that are either 8 bits wide or wider.
+//
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -217,6 +225,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getValueToStore().getType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -283,6 +295,10 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getValueToStore().getType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -315,23 +331,34 @@ struct ConvertVectorMaskedStore final
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
// Load the whole data and use arith.select to handle the corner cases.
- // E.g., given these input values:
//
- // %mask = [0, 1, 1, 1, 1, 1, 0, 0]
- // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
- // %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
+ // As an example, for this masked store:
+ //
+ // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
//
- // we'll have
+ // and given these input i4 values:
//
- // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
+ // %mask = [1, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
+ // %0[%c0, %c0] =
+ // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+ // %val_to_store =
+ // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
//
- // %new_mask = [1, 1, 1, 0]
- // %maskedload = [0x12, 0x34, 0x56, 0x00]
- // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
- // %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
- // %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
+ // we'll have the following i4 output:
//
- // Using the new mask to store %packed_data results in expected output.
+ // expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
+ //
+ // Emulating the above using i8 will give:
+ //
+ // %compressed_mask = [1, 1, 1, 0] (4 * i1)
+ // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
+ // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
+ // %select_using_shifted_mask =
+ // [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
+ // %packed_data = [0x9A, 0xBC, 0xD6, 0x00] (4 * i8)
+ //
+ // Using the compressed mask to store %packed_data results in expected
+ // output.
FailureOr<Operation *> newMask =
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
if (failed(newMask))
@@ -372,6 +399,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getVectorType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
@@ -473,6 +504,10 @@ struct ConvertVectorMaskedLoad final
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getVectorType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
@@ -624,6 +659,10 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getVectorType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
diff --git a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
new file mode 100644
index 00000000000000..30ce13e8169c47
--- /dev/null
+++ b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,112 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s
+
+// These tests mimic tests from vector-emulate-narrow-type.mlir, but load/store
+// 2-D instead of 1-D vectors. That's currently not supported.
+
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
+func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> {
+ %0 = memref.alloc() : memref<3x4xi8>
+ %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8>
+ return %1 : vector<2x4xi8>
+}
+
+// No support for loading 2D vectors - expect no conversions
+// CHECK-LABEL: func @vector_load_2d_i8_negative
+// CHECK: memref.alloc() : memref<3x4xi8>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
+func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> {
+ %c0 = arith.constant 0 : i4
+ %0 = memref.alloc() : memref<3x8xi4>
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} :
+ memref<3x8xi4>, vector<2x8xi4>
+ return %1 : vector<2x8xi4>
+}
+// CHECK-LABEL: func @vector_transfer_read_2d_i4_negative
+// CHECK: memref.alloc() : memref<3x8xi4>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
+ %0 = memref.alloc() : memref<3x4xi8>
+ %mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1>
+ %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+ memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
+ return %1 : vector<2x4xi8>
+}
+
+// CHECK-LABEL: func @vector_maskedload_2d_i8_negative
+// CHECK: memref.alloc() : memref<3x4xi8>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.maskedload
+///----------------------------------------------------------------------------------------
+
+func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> {
+ %0 = memref.alloc() : memref<8x8x16xi4>
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c8 = arith.constant 8 : index
+ %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+ %cst_2 = arith.constant dense<0> : vector<16xi4>
+ %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
+ %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+ %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+ %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+ %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+ return %63 : vector<8x8x16xi4>
+}
+
+// CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative
+// CHECK: memref.alloc() : memref<8x8x16xi4>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+ %0 = memref.alloc() : memref<4x8xi8>
+ vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
+ return
+}
+
+// CHECK-LABEL: func @vector_store_2d_i8_negative
+// CHECK: memref.alloc() : memref<4x8xi8>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+ %0 = memref.alloc() : memref<3x8xi8>
+ %mask = vector.create_mask %arg2 : vector<8xi1>
+ vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+ return
+}
+
+// CHECK-LABEL: func @vector_maskedstore_2d_i8_negative
+// CHECK: memref.alloc() : memref<3x8xi8>
+// CHECK-NOT: i32
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index cba299b2a1d956..5e139b04d7ee6f 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,6 +1,10 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
// -----
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
%c0 = arith.constant 0 : i4
%0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
// -----
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%mask = vector.create_mask %arg3 : vector<4xi1>
@@ -190,7 +202,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
// -----
-func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%mask = vector.constant_mask [2] : vector<4xi1>
%1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
@@ -198,7 +210,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
return %1 : vector<4xi8>
}
// Expect no conversions, i8 is supported.
-// CHECK: func @vector_cst_maskedload_i8(
+// CHECK: func @vector_maskedload_i8_constant_mask(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
@@ -208,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
// CHECK-NEXT: return
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-// CHECK32: func @vector_cst_maskedload_i8(
+// CHECK32: func @vector_maskedload_i8_constant_mask(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -224,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
// -----
-func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
%0 = memref.alloc() : memref<3x8xi4>
%cst = arith.constant dense<0> : vector<3x8xi4>
%mask = vector.constant_mask [4] : vector<8xi1>
@@ -234,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
return %2 : vector<3x8xi4>
}
// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK: func @vector_cst_maskedload_i4(
+// CHECK: func @vector_maskedload_i4_constant_mask(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
@@ -248,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32: func @vector_cst_maskedload_i4(
+// CHECK32: func @vector_maskedload_i4_constant_mask(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
// -----
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.maskedload
+///----------------------------------------------------------------------------------------
+
func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
%0 = memref.alloc() : memref<8x8x16xi4>
%c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
// -----
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
%0 = memref.alloc() : memref<4x8xi8>
vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// -----
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
%0 = memref.alloc() : memref<3x8xi8>
%mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,14 +493,68 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
// -----
-func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+func.func @vector_maskedstore_i4(
+ %idx1: index,
+ %idx2: index,
+ %num_elements_to_store: index,
+ %value: vector<8xi4>) {
+
+ %0 = memref.alloc() : memref<3x8xi4>
+ %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
+ vector.maskedstore %0[%idx1, %idx2], %mask, %value :
+ memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+ return
+}
+// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+
+// CHECK-LABEL: func.func @vector_maskedstore_i4(
+// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+
+// CHECK32...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/115612
More information about the Mlir-commits
mailing list