[Mlir-commits] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 29 07:50:32 PST 2024


https://github.com/ziereis created https://github.com/llvm/llvm-project/pull/121298

Adds rewrites for i2 to i8 signed and unsigned extension, similar to the ones that already exist for i4 to i8 conversion.

I use this for i6 quantized models, and this gives me roughly a 2x speedup for an i6 4096x4096 dequantization-matmul on an AMD 5950x.

I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well.


>From 25333f82f3fd4517f34aea65e57121610c8c6634 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:35:54 +0100
Subject: [PATCH 1/3] init

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 137 +++++++++++++++-
 .../Vector/vector-rewrite-narrow-types.mlir   | 149 +++++++++++++++++-
 2 files changed, 276 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..519e6e68bc1b0a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1084,8 +1084,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
   unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
 
-  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
-  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+  // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
@@ -1233,6 +1233,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // Position 0 (bits 0-1)
+  constexpr int8_t shiftConst6 = 6;
+  auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
+  auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
+  Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
+  Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
+
+  // Position 1 (bits 2-3)
+  constexpr int8_t shiftConst4 = 4;
+  auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
+  auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
+  Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
+  Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
+
+  // Position 1 (bits 4-5)
+  constexpr int8_t shiftConst2 = 2;
+  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
+  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+  Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
+  Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
+
+  // Position 3 (bits 6-7)
+  Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+
+  // interleave all 4 elements by first interleaving even elements and then odd
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
+  // 02    = [0,2,0,2]
+  // 13    = [1,3,1,3]
+  // 0213  = [0,1,2,3]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extract each i2 element using shifts and masks
+  constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+  auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+  auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+  // Element 0 (bits 0-1)
+  Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+  // Element 1 (bits 2-3)
+  constexpr int8_t shift1 = 2;
+  auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+  auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+  Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+  Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
+
+  // Element 2 (bits 4-5)
+  constexpr int8_t shift2 = 4;
+  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
+  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+  Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
+  Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
+
+  // Element 3 (bits 6-7)
+  constexpr int8_t shift3 = 6;
+  auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
+  auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
+  Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
+  Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
+
+  // interleave all 4 elements by first interleaving even elements and then odd
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
+  // 02    = [0,2,0,2]
+  // 13    = [1,3,1,3]
+  // 0213  = [0,1,2,3]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
 /// ops that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
@@ -1438,11 +1549,21 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
     // Perform the rewrite.
     Value subByteExt;
     if (isSigned) {
-      subByteExt =
-          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
+        subByteExt =
+            rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      else {
+        subByteExt =
+            rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      }
     } else {
-      subByteExt =
-          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+        subByteExt =
+            rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      } else {
+        subByteExt =
+            rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      }
     }
 
     // Finalize the rewrite.
@@ -1489,6 +1610,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
                                              truncOp)))
       return failure();
 
+    // Truncation to an i2 destination is not supported yet; bail out.
+    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+      return failure();
+
     // Create a new iX -> i8 truncation op.
     Location loc = truncOp.getLoc();
     auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 210025e30d7db5..0b469066f290c2 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -220,8 +220,8 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
+func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
 // CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
@@ -234,6 +234,72 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
   return %0 : vector<8x32xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
+func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 
 // CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
 func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
@@ -292,6 +358,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
   return %0 : vector<3x8x32xi4>
 }
 
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+  // CHECK-NOT: vector.bitcast
+  // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+  return %0 : vector<8xi2>
+}
+
 // CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
 func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
 // CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -319,8 +392,8 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extui_2d(
-func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_2d(
+func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
 // CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
@@ -333,6 +406,74 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
   return %0 : vector<8x32xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_2d(
+func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_sitofp(
 func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {

>From 29bad6e225de28a4d050c243abfce0f8de5a715d Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:47:18 +0100
Subject: [PATCH 2/3] fix typos in comments

---
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 519e6e68bc1b0a..11cce3540ac436 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1242,7 +1242,7 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   assert(srcVecType.getElementType().isSignlessInteger(2) &&
          "Expected i2 type");
 
-  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
   SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
   constexpr int64_t i2Toi8BitwidthFactor = 4;
   i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
@@ -1295,7 +1295,7 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   assert(srcVecType.getElementType().isSignlessInteger(2) &&
          "Expected i2 type");
 
-  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
   SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
   constexpr int64_t i2Toi8BitwidthFactor = 4;
   i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;

>From ec04f6028f735fb339e741701218d147018931ac Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:48:25 +0100
Subject: [PATCH 3/3] fix typos in comments

---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 11cce3540ac436..323da627de7bc2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1249,28 +1249,28 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
   Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
 
-  // Position 0 (bits 0-1)
+  // Element 0 (bits 0-1)
   constexpr int8_t shiftConst6 = 6;
   auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
   auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
   Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
   Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
 
-  // Position 1 (bits 2-3)
+  // Element 1 (bits 2-3)
   constexpr int8_t shiftConst4 = 4;
   auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
   auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
   Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
   Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
 
-  // Position 1 (bits 4-5)
+  // Element 2 (bits 4-5)
   constexpr int8_t shiftConst2 = 2;
   auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
   auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
   Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
   Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
 
-  // Position 3 (bits 6-7)
+  // Element 3 (bits 6-7)
   Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
 
   // interleave all 4 elements by first interleaving even elements and then odd



More information about the Mlir-commits mailing list