[Mlir-commits] [mlir] [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N) (PR #123529)

Andrzej Warzyński llvmlistbot at llvm.org
Sat Feb 15 12:38:51 PST 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/123529

>From 180145d00efb143b52cc771379056ae14f504b16 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 9 Feb 2025 13:45:44 +0000
Subject: [PATCH 1/2] [mlir][vector] Move tests for
 `rewriteAlignedSubByteInt{Ext|Trunc}` (nfc)

Moves tests for `rewriteAlignedSubByteIntExt` and
`rewriteAlignedSubByteIntTrunc` into a dedicated file. Also adds and
fixes some comments.

This is merely for better organisation and so that it's easier to
identify the patterns and edge cases being tested.
---
 .../Transforms/VectorEmulateNarrowType.cpp    |  10 +-
 .../Vector/vector-rewrite-narrow-types.mlir   | 377 ----------------
 ...vector-rewrite-subbyte-ext-and-trunci.mlir | 415 ++++++++++++++++++
 3 files changed, 420 insertions(+), 382 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf1ecd7d4559c..057788d85a0f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements target-independent rewrites and utilities to emulate
-// narrow types that are not supported by the target hardware, e.g. i4, using
-// wider types, e.g. i8.
+// narrow types that are not supported by the target hardware, e.g. i4
+// ("emulated type"), using wider types, e.g. i8 ("container type").
 //
 /// Currently, only power-of-two integer types are supported. These are
 /// converted to wider integers that are either 8 bits wide or wider.
@@ -2022,19 +2022,19 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  // The emulated type is always i8.
+  // The container type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
-  // The emulated type is always i8.
+  // The container type is always i8.
   patterns
       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
            RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
           patterns.getContext(), benefit.getBenefit() + 1);
 }
 
-// The emulated type is always i8.
+// The container type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8d28f248e392d..a4af307b15da4 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,382 +193,6 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
-
-// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4).
-// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8(
-func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> {
-  // CHECK-NOT: arith.bitcast
-  // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
-  %0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
-  return %0 : vector<1xi8>
-}
-
-// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2).
-// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8(
-func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> {
-  // CHECK-NOT: arith.bitcast
-  // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
-  %0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
-  return %0 : vector<2xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
-func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-  %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
-  return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
-func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
-// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
-// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
-  %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
-  return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
-func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
-  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
-func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
-// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
-// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
-// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
-  %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d(
-func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
-  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
-  return %0 : vector<8x32xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d(
-func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
-// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
-// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
-// Extract bits 0-1
-// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
-// Extract bits 2-3
-// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
-// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
-// Extract bits 4-5
-// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
-// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
-// Extract bits 6-7
-// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
-// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
-// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
-  %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
-  return %0 : vector<8x32xi32>
-}
-
-
-// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
-func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
-// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
-// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
-// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
-// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
-  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
-  return %0 : vector<8xi4>
-}
-
-// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4(
-func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
-// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
-// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
-// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
-// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
-// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
-  %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
-  return %0 : vector<8xi4>
-}
-
-// CHECK-LABEL: func.func @aligned_trunci_2d(
-func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
-// CHECK-NOT:       vector.shuffle
-// CHECK-NOT:       vector.andi
-// CHECK-NOT:       vector.shli
-// CHECK-NOT:       vector.ori
-// CHECK:           arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8>
-// CHECK-NOT:       arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4>
-// CHECK:           vector.deinterleave
-  %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
-  return %0 : vector<8x32xi4>
-}
-
-// CHECK-LABEL: func.func @aligned_trunci_nd(
-// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
-func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
-  // CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8>
-  // CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8>
-  // CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8>
-  // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8>
-  // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
-  // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
-  // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
-  // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
-  %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
-  return %0 : vector<3x8x32xi4>
-}
-
-func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
-  // CHECK-NOT: arith.bitcast
-  // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
-  %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
-  return %0 : vector<8xi2>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
-func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
-// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
-  return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
-func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
-// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
-// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
-  %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
-  return %0 : vector<8xi8>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
-func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
-  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
-func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
-// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
-// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
-// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
-// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
-// Extract bits 0-1
-// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 2-3
-// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
-// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 4-5
-// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
-// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// Extract bits 6-7
-// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
-// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
-  %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d(
-func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
-  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
-  return %0 : vector<8x32xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d(
-func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
-// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
-// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
-// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
-// Extract bits 0-1
-// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// Extract bits 2-3
-// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
-// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// Extract bits 4-5
-// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
-// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// Extract bits 6-7
-// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
-// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
-// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
-  %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
-  return %0 : vector<8x32xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_sitofp(
-func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
-  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
-  return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func.func @aligned_sitofp_2d(
-func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
-  %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
-  return %0 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: func.func @aligned_uitofp(
-func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
-  %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
-  return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func.func @aligned_uitofp_2d(
-func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
-  %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
-  return %0 : vector<8x32xf32>
-}
-
 // CHECK-LABEL: func.func @i4_transpose(
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
@@ -589,7 +213,6 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
   return %0 : vector<16x8xi7>
 }
 
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir b/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir
new file mode 100644
index 0000000000000..aa75e02005256
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir
@@ -0,0 +1,415 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+///----------------------------------------------------------------------------------------
+/// arith.extsi
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4).
+// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8(
+func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> {
+  // CHECK-NOT: vector.bitcast
+  // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
+  %0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
+  return %0 : vector<1xi8>
+}
+
+// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2).
+// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8(
+func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> {
+  // CHECK-NOT: vector.bitcast
+  // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
+  %0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
+  return %0 : vector<2xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
+func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
+func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d(
+func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d(
+func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// Extract bits 0-1
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 2-3
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 4-5
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 6-7
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.trunci
+///
+/// [Pattern: RewriteAlignedSubByteIntTrunc]
+///----------------------------------------------------------------------------------------
+// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
+func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
+// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4(
+func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
+// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
+// CHECK-NOT:       vector.shuffle
+// CHECK-NOT:       vector.andi
+// CHECK-NOT:       vector.shli
+// CHECK-NOT:       vector.ori
+// CHECK:           arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8>
+// CHECK-NOT:       arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4>
+// CHECK:           vector.deinterleave
+  %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+  return %0 : vector<8x32xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_nd(
+// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
+func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
+  // CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8>
+  // CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8>
+  // CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8>
+  // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8>
+  // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
+  // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
+  // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
+  // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
+  %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
+  return %0 : vector<3x8x32xi4>
+}
+
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+  // CHECK-NOT: vector.bitcast
+  // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+  return %0 : vector<8xi2>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.extui
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
+func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
+func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d(
+func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d(
+func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// Extract bits 0-1
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 2-3
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 4-5
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 6-7
+// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.sitofp
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
+///----------------------------------------------------------------------------------------
+/// arith.uitofp
+///
+/// [Pattern: RewriteAlignedSubByteIntExt]
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func.func @aligned_uitofp(
+func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+  %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_uitofp_2d(
+func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+        : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.rewrite_narrow_types
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 411ee3ca91c1c8192521d618ffd1afecfd89ed0d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 17 Jan 2025 13:54:34 +0000
Subject: [PATCH 2/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)

This is PR 4 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no
major functional changes are made/added.

1. Update `alignedConversionPrecondition` (1):

This method didn't require the vector type for the "destination"
argument. The underlying element type is sufficient. The corresponding
argument has been renamed as `multiByteScalarTy` - this is meant as the
multi-byte emulated type (`i8`, `i16`, `i32`, etc).

2. Update `alignedConversionPrecondition` (2):

In #121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:
```cpp
  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns for which this hook was/is used:

  * `RewriteAlignedSubByteIntExt`,
  * `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always
`i8`.

In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.

Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).

3. Update alignedConversionPrecondition (3):

The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This new method is also re-used within the code
and hopefully will allow us more code re-use moving forward (to avoid
re-implementing the same condition).

4. Update alignedConversionPrecondition (4):

NEXT STEPS:

We need to clarify the meaning of "source" and "destination" types.
Currently the usage is ambiguous.

For example, for this `arith.extsi` Op, `vector<8xi2>` and
`vector<8xi32>` are the "source" and "destination" types, respectively:

```mlir
  %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
}
```

However, patterns like `RewriteAlignedSubByteIntExt` introduce
`vector.bitcast` Ops like this:

```mlir
  %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```

I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as
the destination types and that should be clarified.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 134 +++++++++++++-----
 1 file changed, 102 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 057788d85a0f7..ae9a06ea3f169 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1050,6 +1050,38 @@ struct ConvertVectorMaskedLoad final
   }
 };
 
+/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+///   vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+///   * vector<4xi4> -> i16 - yes (N = 1)
+///   * vector<4xi4> -> i8 - yes (N = 2)
+///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
+///   * vector<3xi2> -> i16 - no (N would have to be 0.375)
+static bool isSubByteVecFittable(VectorType subByteVecTy,
+                                 Type multiByteScalarTy) {
+  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
+
+  int elemsPerMultiByte = multiByteBits / subByteBits;
+
+  // TODO: Only the trailing dim is checked - too restrictive for rank > 1.
+  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
@@ -1086,7 +1118,8 @@ struct ConvertVectorTransferRead final
     auto origElements = op.getVectorType().getNumElements();
 
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isFullyAligned =
+        isSubByteVecFittable(op.getVectorType(), containerElemTy);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1387,41 +1420,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-///   1. The `dstType` element type is a multiple of the
-///   `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-///   is not supported). Let this multiple be `N`.
-///   2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-///   multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-///   not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+///   1. The bit-width of `containerTy` is a multiple of the
+///      bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+///      this multiple is 4.
+///   2. The multiple from 1. above divides evenly the number of the (trailing)
+///      elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+///   `subByteVecTy = vector<4xi4>`, and
+///   `containerTy = i16`
+///
+/// 4 (= 16 / 4) divides evenly 4, hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+///   `subByteVecTy = vector<3xi4>`, and
+///   `containerTy = i16`
+///
+/// 4 (= 16/4) _does not_ divide evenly 3, hence the conditions are _not met_.
+///
+/// EXAMPLE 3:
+///   `subByteVecTy = vector<3xi3>`, and
+///   `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
 ///
 /// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType subByteVecType,
-                                                   VectorType dstType,
+                                                   VectorType subByteVecTy,
+                                                   Type containerTy,
                                                    Operation *op) {
-  if (!subByteVecType || !dstType)
-    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
-  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!subByteVecTy)
+    return rewriter.notifyMatchFailure(op, "not a vector!");
 
-  if (dstElemBitwidth < 8)
-    return rewriter.notifyMatchFailure(
-        op, "the bitwidth of dstType must be greater than or equal to 8");
-  if (dstElemBitwidth % srcElemBitwidth != 0)
-    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
-  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!containerTy.isIntOrFloat())
+    return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+  unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+  // Enforced by the common pre-conditions.
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+  // TODO: Remove this condition - the assert above (and
+  // commonConversionPrecondition) takes care of that.
+  if (multiByteBits < 8)
+    return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+  // TODO: Add support for other widths (when/if needed)
+  if (subByteBits != 2 && subByteBits != 4)
     return rewriter.notifyMatchFailure(
-        op, "only src bitwidth of 2 or 4 is supported at this moment");
+        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
+
+  // Condition 1.
+  if (multiByteBits % subByteBits != 0)
+    return rewriter.notifyMatchFailure(op, "unaligned element types");
 
-  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
-  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
+  // Condition 2.
+  if (!isSubByteVecFittable(subByteVecTy, containerTy))
     return rewriter.notifyMatchFailure(
-        op, "the trailing dimension of the input vector of sub-bytes must be a "
-            "multiple of 8 / <sub-byte-width>");
+        op, "not possible to fit this sub-byte vector type into a vector of "
+            "the given multi-byte type");
 
   return success();
 }
@@ -1858,8 +1926,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
-                                             conversionOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType,
+                                             containerType, conversionOp)))
       return failure();
 
     // Perform the rewrite.
@@ -1923,8 +1992,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
     // Check general alignment preconditions. We invert the src/dst type order
     // to reuse the existing precondition logic.
-    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
-                                             truncOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType,
+                                             containerType, truncOp)))
       return failure();
 
     // Create a new iX -> i8 truncation op.



More information about the Mlir-commits mailing list