[Mlir-commits] [mlir] [mlir][vector] Move tests for `rewriteAlignedSubByteInt{Ext|Trunc}` (nfc) (PR #126416)
Andrzej WarzyĆski
llvmlistbot at llvm.org
Sun Feb 9 05:51:50 PST 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/126416
Moves tests for `rewriteAlignedSubByteIntExt` and
`rewriteAlignedSubByteIntTrunc` into a dedicated files. Also adds +
fixes some comments.
This is merely for better organisation and so that it's easier to
identify the patterns and edge cases being tested.
>From 9e95f12a19fa2ca185cb5a266b92eaf03d07caa4 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 9 Feb 2025 13:45:44 +0000
Subject: [PATCH] [mlir][vector] Move tests for
`rewriteAlignedSubByteInt{Ext|Trunc}` (nfc)
Moves tests for `rewriteAlignedSubByteIntExt` and
`rewriteAlignedSubByteIntTrunc` into a dedicated files. Also adds +
fixes some comments.
This is merely for better organisation and so that it's easier to
identify the patterns and edge cases being tested.
---
.../Transforms/VectorEmulateNarrowType.cpp | 10 +-
.../Vector/vector-rewrite-narrow-types.mlir | 377 ----------------
...vector-rewrite-subbyte-ext-and-trunci.mlir | 415 ++++++++++++++++++
3 files changed, 420 insertions(+), 382 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf1ecd7d4559caf..057788d85a0f74a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate
-// narrow types that are not supported by the target hardware, e.g. i4, using
-// wider types, e.g. i8.
+// narrow types that are not supported by the target hardware, e.g. i4
+// ("emulated type"), using wider types, e.g. i8 ("container type").
//
/// Currently, only power-of-two integer types are supported. These are
/// converted to wider integers that are either 8 bits wide or wider.
@@ -2022,19 +2022,19 @@ void vector::populateVectorNarrowTypeRewritePatterns(
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
- // The emulated type is always i8.
+ // The container type is always i8.
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
benefit.getBenefit() + 1);
- // The emulated type is always i8.
+ // The container type is always i8.
patterns
.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
patterns.getContext(), benefit.getBenefit() + 1);
}
-// The emulated type is always i8.
+// The container type is always i8.
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8d28f248e392d20..a4af307b15da4aa 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,382 +193,6 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
return %1 : vector<8xi17>
}
-
-// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4).
-// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8(
-func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> {
- // CHECK-NOT: arith.bitcast
- // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
- %0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
- return %0 : vector<1xi8>
-}
-
-// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2).
-// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8(
-func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> {
- // CHECK-NOT: arith.bitcast
- // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
- %0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
- return %0 : vector<2xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
-func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
- %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
- return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
-func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
-// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
- %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
- return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
-func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
- %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
-func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
-// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
-// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
- %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d(
-func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
- %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
- return %0 : vector<8x32xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d(
-func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
-// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
-// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
-// Extract bits 0-1
-// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
-// Extract bits 2-3
-// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
-// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
-// Extract bits 4-5
-// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
-// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
-// Extract bits 6-7
-// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
-// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
-// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
- %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
- return %0 : vector<8x32xi32>
-}
-
-
-// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
-func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
-// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
-// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
-// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
-// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
- %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
- return %0 : vector<8xi4>
-}
-
-// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4(
-func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
-// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
-// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
-// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
-// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
-// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
- %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
- return %0 : vector<8xi4>
-}
-
-// CHECK-LABEL: func.func @aligned_trunci_2d(
-func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.andi
-// CHECK-NOT: vector.shli
-// CHECK-NOT: vector.ori
-// CHECK: arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8>
-// CHECK-NOT: arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4>
-// CHECK: vector.deinterleave
- %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
- return %0 : vector<8x32xi4>
-}
-
-// CHECK-LABEL: func.func @aligned_trunci_nd(
-// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
-func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
- // CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8>
- // CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8>
- // CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8>
- // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8>
- // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
- // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
- // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
- // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
- %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
- return %0 : vector<3x8x32xi4>
-}
-
-func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
- // CHECK-NOT: arith.bitcast
- // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
- %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
- return %0 : vector<8xi2>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
-func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
- %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
- return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
-func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
-// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
- %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
- return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
-func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
- %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
-func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
-// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
-// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
- %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d(
-func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
- %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
- return %0 : vector<8x32xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d(
-func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
-// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
-// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
-// Extract bits 0-1
-// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// Extract bits 2-3
-// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
-// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// Extract bits 4-5
-// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
-// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// Extract bits 6-7
-// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
-// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
-// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
- %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
- return %0 : vector<8x32xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_sitofp(
-func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
- %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
- return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func.func @aligned_sitofp_2d(
-func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
- %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
- return %0 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: func.func @aligned_uitofp(
-func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
- %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
- return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func.func @aligned_uitofp_2d(
-func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
- %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
- return %0 : vector<8x32xf32>
-}
-
// CHECK-LABEL: func.func @i4_transpose(
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
@@ -589,7 +213,6 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
return %0 : vector<16x8xi7>
}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir b/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir
new file mode 100644
index 000000000000000..aa75e020052566c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir
@@ -0,0 +1,415 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+///----------------------------------------------------------------------------------------
+/// arith.extsi
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4).
+// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8(
+func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
+ %0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
+ return %0 : vector<1xi8>
+}
+
+// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2).
+// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8(
+func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
+ %0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
+ return %0 : vector<2xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
+func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+ %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
+func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d(
+func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d(
+func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// Extract bits 0-1
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 2-3
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 4-5
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.trunci
+///
+/// [Pattern: RewriteAlignedSubByteIntTrunc]
+///----------------------------------------------------------------------------------------
+// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
+func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4(
+func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.andi
+// CHECK-NOT: vector.shli
+// CHECK-NOT: vector.ori
+// CHECK: arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8>
+// CHECK-NOT: arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4>
+// CHECK: vector.deinterleave
+ %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+ return %0 : vector<8x32xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_nd(
+// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
+func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
+ // CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8>
+ // CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8>
+ // CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8>
+ // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8>
+ // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
+ // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
+ // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
+ // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
+ %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
+ return %0 : vector<3x8x32xi4>
+}
+
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+ return %0 : vector<8xi2>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.extui
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
+func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+ %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
+func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d(
+func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d(
+func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// Extract bits 0-1
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 2-3
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 4-5
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.sitofp
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+ %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+ %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+ return %0 : vector<8x32xf32>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.uitofp
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func.func @aligned_uitofp(
+func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+ %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_uitofp_2d(
+func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+ %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
+ return %0 : vector<8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.rewrite_narrow_types
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list