[Mlir-commits] [mlir] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 10 11:01:17 PST 2025


https://github.com/ziereis updated https://github.com/llvm/llvm-project/pull/121298

>From 908a6eba949f179cdf6412de97814bb46476cc72 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:35:54 +0100
Subject: [PATCH 01/13] init

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 137 +++++++++++++++-
 .../Vector/vector-rewrite-narrow-types.mlir   | 149 +++++++++++++++++-
 2 files changed, 276 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d04f302200519e..c44744105ba2f3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1090,8 +1090,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
   unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
 
-  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
-  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+  // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
@@ -1240,6 +1240,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // Position 0 (bits 0-1)
+  constexpr int8_t shiftConst6 = 6;
+  auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
+  auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
+  Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
+  Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
+
+  // Position 1 (bits 2-3)
+  constexpr int8_t shiftConst4 = 4;
+  auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
+  auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
+  Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
+  Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
+
+  // Position 1 (bits 4-5)
+  constexpr int8_t shiftConst2 = 2;
+  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
+  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+  Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
+  Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
+
+  // Position 3 (bits 6-7)
+  Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+
+  // interleave all 4 elements by first interleaving even elements and then odd
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
+  // 02    = [0,2,0,2]
+  // 13    = [1,3,1,3]
+  // 0213  = [0,1,2,3]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extract each i2 element using shifts and masks
+  constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+  auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+  auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+  // Element 0 (bits 0-1)
+  Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+  // Element 1 (bits 2-3)
+  constexpr int8_t shift1 = 2;
+  auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+  auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+  Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+  Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
+
+  // Element 2 (bits 4-5)
+  constexpr int8_t shift2 = 4;
+  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
+  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+  Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
+  Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
+
+  // Element 3 (bits 6-7)
+  constexpr int8_t shift3 = 6;
+  auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
+  auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
+  Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
+  Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
+
+  // interleave all 4 elements by first interleaving even elements and then odd
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
+  // 02    = [0,2,0,2]
+  // 13    = [1,3,1,3]
+  // 0213  = [0,1,2,3]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
 /// ops that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
@@ -1445,11 +1556,21 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
     // Perform the rewrite.
     Value subByteExt;
     if (isSigned) {
-      subByteExt =
-          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
+        subByteExt =
+            rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      else {
+        subByteExt =
+            rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      }
     } else {
-      subByteExt =
-          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+        subByteExt =
+            rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      } else {
+        subByteExt =
+            rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      }
     }
 
     // Finalize the rewrite.
@@ -1496,6 +1617,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
                                              truncOp)))
       return failure();
 
+    // not supported currently.
+    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+      return failure();
+
     // Create a new iX -> i8 truncation op.
     Location loc = truncOp.getLoc();
     auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 210025e30d7db5..0b469066f290c2 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -220,8 +220,8 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
+func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
 // CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
@@ -234,6 +234,72 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
   return %0 : vector<8x32xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
+func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 
 // CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
 func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
@@ -292,6 +358,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
   return %0 : vector<3x8x32xi4>
 }
 
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+  // CHECK-NOT: arith.bitcast
+  // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+  return %0 : vector<8xi2>
+}
+
 // CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
 func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
 // CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -319,8 +392,8 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extui_2d(
-func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_2d(
+func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
 // CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
@@ -333,6 +406,74 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
   return %0 : vector<8x32xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_2d(
+func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_sitofp(
 func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {

>From e3c5c56595bee191e2a9561e4d9c1f3d43212b88 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:47:18 +0100
Subject: [PATCH 02/13] fix typos in comments

---
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index c44744105ba2f3..4db2e70be01c91 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1249,7 +1249,7 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   assert(srcVecType.getElementType().isSignlessInteger(2) &&
          "Expected i2 type");
 
-  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
   SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
   constexpr int64_t i2Toi8BitwidthFactor = 4;
   i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
@@ -1302,7 +1302,7 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   assert(srcVecType.getElementType().isSignlessInteger(2) &&
          "Expected i2 type");
 
-  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
   SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
   constexpr int64_t i2Toi8BitwidthFactor = 4;
   i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;

>From 99b678550d00eb8569c45c3aead4cc3b95b5f678 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:48:25 +0100
Subject: [PATCH 03/13] fix typos in comments

---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 4db2e70be01c91..44093692823a96 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1256,28 +1256,28 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
   Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
 
-  // Position 0 (bits 0-1)
+  // Element 0 (bits 0-1)
   constexpr int8_t shiftConst6 = 6;
   auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
   auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
   Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
   Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
 
-  // Position 1 (bits 2-3)
+  // Element 1 (bits 2-3)
   constexpr int8_t shiftConst4 = 4;
   auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
   auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
   Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
   Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
 
-  // Position 1 (bits 4-5)
+  // Element 2 (bits 4-5)
   constexpr int8_t shiftConst2 = 2;
   auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
   auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
   Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
   Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
 
-  // Position 3 (bits 6-7)
+  // Element 3 (bits 6-7)
   Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
 
   // interleave all 4 elements by first interleaving even elements and then odd

>From 303416a09d7224e6be9ffb7c22345b5698453627 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Mon, 6 Jan 2025 12:44:34 +0100
Subject: [PATCH 04/13] refactoring: - extracts repeated code into functions -
 reorder tests - improve naming

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 236 +++++++++---------
 .../Vector/vector-rewrite-narrow-types.mlir   | 128 +++++-----
 2 files changed, 178 insertions(+), 186 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 44093692823a96..5a8413982c3f60 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1179,70 +1179,62 @@ Value BitCastRewriter::genericRewriteStep(
   return runningResult;
 }
 
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
-                                    Value srcValue) {
+Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+                               Value srcValue) {
   VectorType srcVecType = cast<VectorType>(srcValue.getType());
-  assert(srcVecType.getElementType().isSignlessInteger(4) &&
-         "Expected i4 type");
-
-  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+  assert(srcBitwidth % 8 != 0 && "Invalid source bitwidth");
+  int64_t bitwidthFactor = 8 / srcBitwidth;
   SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i4Toi8BitwidthFactor = 2;
-  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  i8VecShape.back() = i8VecShape.back() / bitwidthFactor;
   auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+  return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+}
 
-  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
-  // byte are place in one vector and the high i4 elements in another vector.
-  constexpr int8_t bitsToShift = 4;
-  auto shiftValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
-  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
-  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
-  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index.
+Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
+                                   Value src, int bitIdx, int numBits) {
+  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+         "Invalid bitIdx range");
+  auto srcType = cast<VectorType>(src.getType());
+  Value shl = src;
+  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
+  if (bitsToShiftLeft != 0) {
+    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
+    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
+  }
 
-  // 3. Interleave low and high i8 elements.
-  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+  int8_t bitsToShiftRight = 8 - numBits;
+  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
+  return shr;
 }
 
-/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
-                                      Value srcValue) {
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
   VectorType srcVecType = cast<VectorType>(srcValue.getType());
   assert(srcVecType.getElementType().isSignlessInteger(4) &&
          "Expected i4 type");
 
   // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i4Toi8BitwidthFactor = 2;
-  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
 
-  // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
-  //  byte are placed in one vector and the high i4 elements in another vector.
-  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
-  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
-  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
-                                             lowBitsMaskValues);
-  constexpr int8_t highBitsToShift = 4;
-  auto highShiftValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
-  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
+  // byte are place in one vector and the high i4 elements in another vector.
+  Value low = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 4);
+  Value high = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 4);
 
   // 3. Interleave low and high i8 elements.
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
 /// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
 static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
                                     Value srcValue) {
   VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1250,41 +1242,20 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
          "Expected i2 type");
 
   // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i2Toi8BitwidthFactor = 4;
-  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
 
+  // 2. Extract each i2 element using shifts
   // Element 0 (bits 0-1)
-  constexpr int8_t shiftConst6 = 6;
-  auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
-  auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
-  Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
-  Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
-
+  Value elem0 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 2);
   // Element 1 (bits 2-3)
-  constexpr int8_t shiftConst4 = 4;
-  auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
-  auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
-  Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
-  Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
-
+  Value elem1 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 2, 2);
   // Element 2 (bits 4-5)
-  constexpr int8_t shiftConst2 = 2;
-  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
-  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
-  Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
-  Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
-
+  Value elem2 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 2);
   // Element 3 (bits 6-7)
-  Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+  Value elem3 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 6, 2);
 
-  // interleave all 4 elements by first interleaving even elements and then odd
-  // elem0 = [0,0,0,0]
-  // elem1 = [1,1,1,1]
-  // elem2 = [2,2,2,2]
-  // elem3 = [3,3,3,3]
+  // 3. Interleave all 4 elements by first interleaving even elements and then
+  // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
   // 02    = [0,2,0,2]
   // 13    = [1,3,1,3]
   // 0213  = [0,1,2,3]
@@ -1293,9 +1264,51 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
 }
 
+/// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index.
+Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
+                                     Value src, int bitIdx, int numBits) {
+  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+         "Invalid bitIdx range");
+  auto srcType = cast<VectorType>(src.getType());
+  int8_t bitsToShiftRight = bitIdx;
+  Value shr = src;
+  if (bitsToShiftRight != 0) {
+    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
+  }
+  if (bitIdx + numBits == 8) {
+    return shr;
+  }
+  uint8_t lowBitsMask = (1 << numBits) - 1;
+  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(srcType, lowBitsMask));
+  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
+}
+
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+
+  // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
+  //  byte are placed in one vector and the high i4 elements in another vector.
+  Value low = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 4);
+  Value high = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 4);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
 static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
                                       Value srcValue) {
   VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1303,46 +1316,20 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
          "Expected i2 type");
 
   // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i2Toi8BitwidthFactor = 4;
-  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
 
   // 2. Extract each i2 element using shifts and masks
-  constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
-  auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
-  auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
-
   // Element 0 (bits 0-1)
-  Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
-
+  Value elem0 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 2);
   // Element 1 (bits 2-3)
-  constexpr int8_t shift1 = 2;
-  auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
-  auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
-  Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
-  Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
-
+  Value elem1 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 2, 2);
   // Element 2 (bits 4-5)
-  constexpr int8_t shift2 = 4;
-  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
-  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
-  Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
-  Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
-
+  Value elem2 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 2);
   // Element 3 (bits 6-7)
-  constexpr int8_t shift3 = 6;
-  auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
-  auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
-  Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
-  Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
-
-  // interleave all 4 elements by first interleaving even elements and then odd
-  // elem0 = [0,0,0,0]
-  // elem1 = [1,1,1,1]
-  // elem2 = [2,2,2,2]
-  // elem3 = [3,3,3,3]
+  Value elem3 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 6, 2);
+
+  // 3. Interleave all 4 elements by first interleaving even elements and then
+  // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
   // 02    = [0,2,0,2]
   // 13    = [1,3,1,3]
   // 0213  = [0,1,2,3]
@@ -1352,8 +1339,7 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
 }
 
 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
-/// ops that take advantage of high-level information to avoid leaving LLVM to
-/// scramble with peephole optimizations.
+/// ops to avoid leaving LLVM to scramble with peephole optimizations.
 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                 Value srcValue) {
   VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1556,20 +1542,30 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
     // Perform the rewrite.
     Value subByteExt;
     if (isSigned) {
-      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
+      switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
+      case 2:
         subByteExt =
             rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
-      else {
+        break;
+      case 4:
         subByteExt =
             rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+        break;
+      default:
+        return failure();
       }
     } else {
-      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+      switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
+      case 2:
         subByteExt =
             rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
-      } else {
+        break;
+      case 4:
         subByteExt =
             rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+        break;
+      default:
+        return failure();
       }
     }
 
@@ -1611,16 +1607,16 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
     if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
       return failure();
 
+    // TODO: Add support for truncating to i2.
+    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+      return failure();
+
     // Check general alignment preconditions. We invert the src/dst type order
     // to reuse the existing precondition logic.
     if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
                                              truncOp)))
       return failure();
 
-    // not supported currently.
-    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
-      return failure();
-
     // Create a new iX -> i8 truncation op.
     Location loc = truncOp.getLoc();
     auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 0b469066f290c2..3d37fe4efa40c4 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -206,34 +206,6 @@ func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
   return %0 : vector<8xi8>
 }
 
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
-func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
-  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
-func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
-  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
-  return %0 : vector<8x32xi32>
-}
-
 // CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
 func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
 // CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
@@ -255,6 +227,19 @@ func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
   return %0 : vector<8xi8>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
+func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
 
 // CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
 func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
@@ -278,8 +263,22 @@ func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
-func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d(
+func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d(
+func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
 // CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
 // CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
 // CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
@@ -378,34 +377,6 @@ func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
   return %0 : vector<8xi8>
 }
 
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
-func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
-  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_2d(
-func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
-  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
-  return %0 : vector<8x32xi32>
-}
-
 // CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
 func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
 // CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
@@ -419,8 +390,7 @@ func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
 // CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
 // CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
 // CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
 // CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
 // CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
 // CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
@@ -428,6 +398,20 @@ func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
   return %0 : vector<8xi8>
 }
 
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
+func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
 func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
 // CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
@@ -441,8 +425,7 @@ func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
 // CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
 // CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
 // CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
 // CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
 // CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
 // CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
@@ -451,8 +434,22 @@ func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extui_i2_2d(
-func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d(
+func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d(
+func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
 // CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
 // CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
 // CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
@@ -464,8 +461,7 @@ func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
 // CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
 // CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
 // CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
 // CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
 // CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
 // CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>

>From cfe31bb07b747540107b57da0f88df8b8b340903 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Tue, 7 Jan 2025 19:24:25 +0100
Subject: [PATCH 05/13] refactoring

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 189 ++++++++----------
 1 file changed, 78 insertions(+), 111 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 5a8413982c3f60..5d7fe95e395343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1090,10 +1090,14 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
   unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
 
-  // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
-  if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
-      (dstElemBitwidth % srcElemBitwidth) != 0)
-    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+  if (dstElemBitwidth < 8)
+    return rewriter.notifyMatchFailure(
+        op, "the bitwidth of dstType must be greater than or equal to 8");
+  if (dstElemBitwidth % srcElemBitwidth != 0)
+    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
+  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+    return rewriter.notifyMatchFailure(
+        op, "only src bitwidth of 2 or 4 is supported at this moment");
 
   const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
   if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
@@ -1179,22 +1183,35 @@ Value BitCastRewriter::genericRewriteStep(
   return runningResult;
 }
 
-Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
-                               Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+/// Takes an aligned sub-byte vector as input and bitcasts it to a vector of i8.
+///
+/// Example:
+/// vector<16x16xi2> -> vector<16x4xi8>
+/// vector<16x16xi4> -> vector<16x8xi8>
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  auto srcVecType = cast<VectorType>(srcValue.getType());
   int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
-  assert(srcBitwidth % 8 != 0 && "Invalid source bitwidth");
+  assert(8 % srcBitwidth == 0 && "Invalid source bitwidth");
   int64_t bitwidthFactor = 8 / srcBitwidth;
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  i8VecShape.back() = i8VecShape.back() / bitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  SmallVector<int64_t> vecShape(srcVecType.getShape());
+  // adjust last dimension of the vector so the total size remains the same.
+  vecShape.back() = vecShape.back() / bitwidthFactor;
+  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
   return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
 }
 
 /// Extracts a signed N-bit sequence from each element of an 8-bit vector,
 /// starting at the specified bit index.
-Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
-                                   Value src, int bitIdx, int numBits) {
+///
+/// Example:
+/// extract numBits=2 starting at bitIdx=2
+/// src    =               [0101|11|10]
+/// shl    = src << 4    -> [11100000]
+/// result = shl >> 6    -> [11111111]
+static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
+                                          Location loc, Value src, int bitIdx,
+                                          int numBits) {
   assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
          "Invalid bitIdx range");
   auto srcType = cast<VectorType>(src.getType());
@@ -1213,61 +1230,18 @@ Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
   return shr;
 }
 
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
-                                    Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
-  assert(srcVecType.getElementType().isSignlessInteger(4) &&
-         "Expected i4 type");
-
-  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
-  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
-
-  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
-  // byte are place in one vector and the high i4 elements in another vector.
-  Value low = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 4);
-  Value high = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 4);
-
-  // 3. Interleave low and high i8 elements.
-  return rewriter.create<vector::InterleaveOp>(loc, low, high);
-}
-
-/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
-                                    Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
-  assert(srcVecType.getElementType().isSignlessInteger(2) &&
-         "Expected i2 type");
-
-  // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
-  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
-
-  // 2. Extract each i2 element using shifts
-  // Element 0 (bits 0-1)
-  Value elem0 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 2);
-  // Element 1 (bits 2-3)
-  Value elem1 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 2, 2);
-  // Element 2 (bits 4-5)
-  Value elem2 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 2);
-  // Element 3 (bits 6-7)
-  Value elem3 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 6, 2);
-
-  // 3. Interleave all 4 elements by first interleaving even elements and then
-  // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
-  // 02    = [0,2,0,2]
-  // 13    = [1,3,1,3]
-  // 0213  = [0,1,2,3]
-  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
-  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
-  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
-}
-
 /// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
 /// starting at the specified bit index.
-Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
-                                     Value src, int bitIdx, int numBits) {
+///
+/// Example:
+/// extract numBits=2 starting at bitIdx=2
+/// src                 = [0101|10|10]
+/// mask                = [00000011]
+/// shr    = src >> 2   = [00010110]
+/// result = shr & mask = [00000010]
+static Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter,
+                                            Location loc, Value src, int bitIdx,
+                                            int numBits) {
   assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
          "Invalid bitIdx range");
   auto srcType = cast<VectorType>(src.getType());
@@ -1287,30 +1261,33 @@ Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
   return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
 }
 
-/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+using ExtractNBitsFn =
+    std::function<Value(PatternRewriter &, Location, Value, int, int)>;
+
+/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
-                                      Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
+                              Value srcValue, const ExtractNBitsFn &extFn) {
+  auto srcVecType = cast<VectorType>(srcValue.getType());
   assert(srcVecType.getElementType().isSignlessInteger(4) &&
          "Expected i4 type");
 
   // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
   Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
 
-  // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
-  //  byte are placed in one vector and the high i4 elements in another vector.
-  Value low = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 4);
-  Value high = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 4);
+  // 2. Extend i4 elements to i8 elements. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
+  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
 
   // 3. Interleave low and high i8 elements.
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
-/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
-                                      Value srcValue) {
+static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
+                              Value srcValue, const ExtractNBitsFn &extFn) {
   VectorType srcVecType = cast<VectorType>(srcValue.getType());
   assert(srcVecType.getElementType().isSignlessInteger(2) &&
          "Expected i2 type");
@@ -1318,18 +1295,22 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
   Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
 
-  // 2. Extract each i2 element using shifts and masks
+  // 2. Extract each i2 element
   // Element 0 (bits 0-1)
-  Value elem0 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 2);
+  Value elem0 = extFn(rewriter, loc, i8Vector, 0, 2);
   // Element 1 (bits 2-3)
-  Value elem1 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 2, 2);
+  Value elem1 = extFn(rewriter, loc, i8Vector, 2, 2);
   // Element 2 (bits 4-5)
-  Value elem2 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 2);
+  Value elem2 = extFn(rewriter, loc, i8Vector, 4, 2);
   // Element 3 (bits 6-7)
-  Value elem3 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 6, 2);
-
-  // 3. Interleave all 4 elements by first interleaving even elements and then
-  // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
+  Value elem3 = extFn(rewriter, loc, i8Vector, 6, 2);
+
+  // 3. Interleave all 4 elements by first interleaving
+  // even elements and then odd
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
   // 02    = [0,2,0,2]
   // 13    = [1,3,1,3]
   // 0213  = [0,1,2,3]
@@ -1540,33 +1521,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Perform the rewrite.
+    Location loc = conversionOp.getLoc();
+    const auto &extFn = isSigned ? extractNBitsFromVectorSigned
+                                 : extractNBitsFromVectorUnsinged;
     Value subByteExt;
-    if (isSigned) {
-      switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
-      case 2:
-        subByteExt =
-            rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
-        break;
-      case 4:
-        subByteExt =
-            rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
-        break;
-      default:
-        return failure();
-      }
-    } else {
-      switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
-      case 2:
-        subByteExt =
-            rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
-        break;
-      case 4:
-        subByteExt =
-            rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
-        break;
-      default:
-        return failure();
-      }
+    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
+    case 2:
+      subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
+      break;
+    case 4:
+      subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
+      break;
+    default:
+      return failure();
     }
 
     // Finalize the rewrite.

>From 0313d89168a5102bcb48e5e363800a835d5c1306 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Thu, 9 Jan 2025 15:48:22 +0100
Subject: [PATCH 06/13] update documentation

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 45 +++++++++++++------
 1 file changed, 31 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 5d7fe95e395343..14d7a5ea2faca2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1183,32 +1183,41 @@ Value BitCastRewriter::genericRewriteStep(
   return runningResult;
 }
 
-/// takes a aligned subByte vector as Input and bitcasts it to a vector of i8.
+/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
+/// Where aligned means it satisfies the alignedConversionPreconditions.
 ///
 /// Example:
 /// vector<16x16xi2> -> vector<16x2xi8>
 /// vector<16x16xi4> -> vector<16x4xi8>
 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
-                                      Value srcValue) {
-  auto srcVecType = cast<VectorType>(srcValue.getType());
+                                      Value subByteVec) {
+  auto srcVecType = cast<VectorType>(subByteVec.getType());
   int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
   assert(8 % srcBitwidth == 0 && "Invalid source bitwidth");
   int64_t bitwidthFactor = 8 / srcBitwidth;
   SmallVector<int64_t> vecShape(srcVecType.getShape());
-  // adjust last dimension of the vector so the total size remains the same.
+  // Adjust last dimension of the vector, so the total size remains the same.
   vecShape.back() = vecShape.back() / bitwidthFactor;
   auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
-  return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
 }
 
 /// Extracts a signed N-bit sequence from each element of an 8-bit vector,
 /// starting at the specified bit index.
+/// The `bitIdx` starts at 0 from the LSB and moves to the left.
 ///
-/// Example:
+/// Example for a single element:
 /// extract numBits=2 starting at bitIdx=2
-/// src    =               [0101|11|10]
-/// shl    = src << 4    -> [11100000]
-/// result = shl >> 6    -> [11111111]
+/// src     = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
+/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
+/// target  = [.   .   .   .   ^   ^   .   .]
+///
+/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
+/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
+///
+/// src     =                         [01 01 11 10]
+/// shl     = arith.shl(src, 4)    -> [11 10 00 00]
+/// result  = arith.shrsi(shl, 6)  -> [11 11 11 11]
 static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
                                           Location loc, Value src, int bitIdx,
                                           int numBits) {
@@ -1232,13 +1241,21 @@ static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
 
 /// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
 /// starting at the specified bit index.
+/// The `bitIdx` starts at 0 from the LSB and moves to the left.
 ///
-/// Example:
+/// Example for a single element:
 /// extract numBits=2 starting at bitIdx=2
-/// src                 = [0101|10|10]
-/// mask                = [00000011]
-/// shr    = src >> 6   = [00010110]
-/// result = shr & mask = [00000010]
+/// src     = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
+/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
+/// target  = [.   .   .   .   ^   ^   .   .]
+///
+/// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
+/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
+///
+/// src                            = [01 01 10 10]
+/// mask                           = [00 00 00 11]
+/// shr    = arith.shrui(src, 2)   = [00 01 01 10]
+/// result = arith.andi(shr, mask) = [00 00 00 10]
 static Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter,
                                             Location loc, Value src, int bitIdx,
                                             int numBits) {

>From 3b1005d9d0669792b506444394a1be7dc56acaaa Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Thu, 9 Jan 2025 15:54:49 +0100
Subject: [PATCH 07/13] improve naming and comments

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 34 +++++++++----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 14d7a5ea2faca2..e53696beea3a7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1313,26 +1313,26 @@ static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
   Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
 
   // 2. Extract each i2 element
-  // Element 0 (bits 0-1)
-  Value elem0 = extFn(rewriter, loc, i8Vector, 0, 2);
-  // Element 1 (bits 2-3)
-  Value elem1 = extFn(rewriter, loc, i8Vector, 2, 2);
-  // Element 2 (bits 4-5)
-  Value elem2 = extFn(rewriter, loc, i8Vector, 4, 2);
-  // Element 3 (bits 6-7)
-  Value elem3 = extFn(rewriter, loc, i8Vector, 6, 2);
+  // Position 0 (bits 0-1)
+  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
+  // Position 1 (bits 2-3)
+  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
+  // Position 2 (bits 4-5)
+  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
+  // Position 3 (bits 6-7)
+  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
 
   // 3. Interleave all 4 elements by first interleaving
   // even elements and then odd
-  // elem0 = [0,0,0,0]
-  // elem1 = [1,1,1,1]
-  // elem2 = [2,2,2,2]
-  // elem3 = [3,3,3,3]
-  // 02    = [0,2,0,2]
-  // 13    = [1,3,1,3]
-  // 0213  = [0,1,2,3]
-  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
-  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  // vec0  = [0,0,0,0],...
+  // vec1  = [1,1,1,1],...
+  // vec2  = [2,2,2,2],...
+  // vec3  = [3,3,3,3],...
+  // 02    = [0,2,0,2,...],...
+  // 13    = [1,3,1,3,...],...
+  // 0213  = [0,1,2,3,...],...
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
   return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
 }
 

>From b975051172e44afa331d90a6e0f1aa7ab39fed20 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Thu, 9 Jan 2025 16:14:43 +0100
Subject: [PATCH 08/13] improve naming and fix typo

---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp   | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e53696beea3a7a..62c95eea3e7eee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1193,7 +1193,7 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                       Value subByteVec) {
   auto srcVecType = cast<VectorType>(subByteVec.getType());
   int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
-  assert(8 % srcBitwidth == 0 && "Invalid source bitwidth");
+  assert(8 % srcBitwidth == 0 && "Unsupported sub-byte type (not a divisor of i8)");
   int64_t bitwidthFactor = 8 / srcBitwidth;
   SmallVector<int64_t> vecShape(srcVecType.getShape());
   // Adjust last dimension of the vector, so the total size remains the same.
@@ -1207,7 +1207,7 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
 ///
 /// Example for a single element:
-/// extract numBits=2 starting at bitIdx=2
+/// Extract numBits=2 starting at bitIdx=2
 /// src     = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
 /// target  = [.   .   .   .   ^   ^   .   .]
@@ -1244,7 +1244,7 @@ static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
 ///
 /// Example for a single element:
-/// extract numBits=2 starting at bitIdx=2
+/// Extract numBits=2 starting at bitIdx=2
 /// src     = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
 /// target  = [.   .   .   .   ^   ^   .   .]

>From 1196049fd96c0d66eceb685ed6991cb25840bdba Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Thu, 9 Jan 2025 18:11:05 +0100
Subject: [PATCH 09/13] tests for precondition and note for optimization

---
 .../Transforms/VectorEmulateNarrowType.cpp      |  9 +++++++--
 .../Vector/vector-rewrite-narrow-types.mlir     | 17 +++++++++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 62c95eea3e7eee..618374faa6ee37 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1099,7 +1099,7 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(
         op, "only src bitwidth of 2 or 4 is supported at this moment");
 
-  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
+  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
   if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
     return rewriter.notifyMatchFailure(
         op, "Not an even number of i4 elements in trailing dim");
@@ -1193,7 +1193,8 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                       Value subByteVec) {
   auto srcVecType = cast<VectorType>(subByteVec.getType());
   int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
-  assert(8 % srcBitwidth == 0 && "Unsupported sub-byte type (not a divisor of i8)");
+  assert(8 % srcBitwidth == 0 &&
+         "Unsupported sub-byte type (not a divisor of i8)");
   int64_t bitwidthFactor = 8 / srcBitwidth;
   SmallVector<int64_t> vecShape(srcVecType.getShape());
   // Adjust last dimension of the vector, so the total size remains the same.
@@ -1256,6 +1257,10 @@ static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
 /// mask                           = [00 00 00 11]
 /// shr    = arith.shrui(src, 2)   = [00 01 01 10]
 /// result = arith.andi(shr, mask) = [00 00 00 10]
+/// NOTE: Similarly to extractNBitsFromVectorSigned, this could be achieved by
+/// using arith::ShLIOp + arith::ShRUIOp instead of the masking. However, by
+/// using arith::ShRUIOp + arith::AndIOp, we are eliminating shift left when the
+/// index is 0.
 static Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter,
                                             Location loc, Value src, int bitIdx,
                                             int numBits) {
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 3d37fe4efa40c4..e84a6db7052a73 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,23 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+
+// CHECK-LABEL: func.func @aligned_i4_trailing_dim_not_multiple(
+func.func @aligned_i4_trailing_dim_not_multiple(%a: vector<1xi4>) -> vector<1xi8> {
+  // CHECK-NOT: arith.bitcast
+  // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
+  %0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
+  return %0 : vector<1xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_i2_trailing_dim_not_multiple(
+func.func @aligned_i2_trailing_dim_not_multiple(%a: vector<2xi2>) -> vector<2xi8> {
+  // CHECK-NOT: arith.bitcast
+  // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
+  %0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
+  return %0 : vector<2xi8>
+}
+
 // CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
 func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {

>From 106f8b7e6095dcaf512acd8e24598dcc50fa23cf Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Thu, 9 Jan 2025 20:23:44 +0100
Subject: [PATCH 10/13] better name for extract function

---
 .../Transforms/VectorEmulateNarrowType.cpp       | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 618374faa6ee37..f9975402e7c68e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1219,9 +1219,9 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
 /// src     =                         [01 01 11 10]
 /// shl     = arith.shl(src, 4)    -> [11 10 00 00]
 /// result  = arith.shrsi(shl, 6)  -> [11 11 11 11]
-static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
-                                          Location loc, Value src, int bitIdx,
-                                          int numBits) {
+static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
+                                                  Location loc, Value src,
+                                                  int bitIdx, int numBits) {
   assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
          "Invalid bitIdx range");
   auto srcType = cast<VectorType>(src.getType());
@@ -1261,9 +1261,9 @@ static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
 /// using arith::ShLIOp + arith::ShRUIOp instead of the masking. However, by
 /// using arith::ShRUIOp + arith::AndIOp, we are eliminating shift left when the
 /// index is 0.
-static Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter,
-                                            Location loc, Value src, int bitIdx,
-                                            int numBits) {
+static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
+                                              Location loc, Value src,
+                                              int bitIdx, int numBits) {
   assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
          "Invalid bitIdx range");
   auto srcType = cast<VectorType>(src.getType());
@@ -1544,8 +1544,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 
     // Perform the rewrite.
     Location loc = conversionOp.getLoc();
-    const auto &extFn = isSigned ? extractNBitsFromVectorSigned
-                                 : extractNBitsFromVectorUnsinged;
+    const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
+                                 : extractNBitsPerByteAndExtendToI8;
     Value subByteExt;
     switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
     case 2:

>From 7e25b9a31679b3dee464bb59c827860cc0fe0479 Mon Sep 17 00:00:00 2001
From: ziereis <44057120+ziereis at users.noreply.github.com>
Date: Fri, 10 Jan 2025 20:00:23 +0100
Subject: [PATCH 11/13] Update
 mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f9975402e7c68e..92c4c678e40638 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1187,8 +1187,8 @@ Value BitCastRewriter::genericRewriteStep(
 /// Where aligned means it satisfies the alignedConversionPreconditions.
 ///
 /// Example:
-/// vector<16x16xi2> -> vector<16x2xi8>
-/// vector<16x16xi4> -> vector<16x4xi8>
+/// vector<16x16xi2> -> vector<16x4xi8>
+/// vector<16x16xi4> -> vector<16x8xi8>
 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                       Value subByteVec) {
   auto srcVecType = cast<VectorType>(subByteVec.getType());

>From 70ae38a02035e78c9e753f594e05e5687b80e86c Mon Sep 17 00:00:00 2001
From: ziereis <44057120+ziereis at users.noreply.github.com>
Date: Fri, 10 Jan 2025 20:00:49 +0100
Subject: [PATCH 12/13] Update
 mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
 mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 92c4c678e40638..204995655f2241 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1203,7 +1203,7 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
 }
 
-/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
+/// Extracts a signed N-bit sequence from each element of a vector of bytes,
 /// starting at the specified bit index.
 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
 ///

>From 86e11c4da3c23a7601ddbaf7a3537e1b50b6ee3a Mon Sep 17 00:00:00 2001
From: ziereis <44057120+ziereis at users.noreply.github.com>
Date: Fri, 10 Jan 2025 20:01:03 +0100
Subject: [PATCH 13/13] Update
 mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 204995655f2241..c7e2b2ce60dd74 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1333,8 +1333,8 @@ static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
   // vec1  = [1,1,1,1],...
   // vec2  = [2,2,2,2],...
   // vec3  = [3,3,3,3],...
-  // 02    = [0,2,0,2,...],...
-  // 13    = [1,3,1,3,...],...
+  // 02    = [0,2,0,2,0,2,0,2],...
+  // 13    = [1,3,1,3,1,3,1,3],...
   // 0213  = [0,1,2,3,...],...
   Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
   Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);



More information about the Mlir-commits mailing list