[Mlir-commits] [mlir] [mlir][Vector] Replace vector.shuffle with vector.interleave in vector narrow type emulation (PR #82565)

Diego Caballero llvmlistbot at llvm.org
Wed Feb 21 17:14:42 PST 2024


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/82565

WIP

>From de8a47ae91437a53b483ba72b04c9660df06fe37 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 21 Feb 2024 23:03:21 +0000
Subject: [PATCH 1/4] [mlir][Vector] Replace `vector.shuffle` with
 `vector.interleave` in vector narrow type emulation

This PR replaces the generation of `vector.shuffle` with
`vector.interleave` in the i4 conversions in vector narrow type
emulation. The multi dimensional semantics of `vector.interleave` allow
us to enable these conversion emulations also for multi dimensional
vectors.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 17 ++--
 .../Vector/vector-rewrite-narrow-types.mlir   | 82 +++++++++++++------
 2 files changed, 67 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb66708407b4..9ebe36cd3861e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType preconditionType,
                                                   Operation *op) {
-  if (!preconditionType || preconditionType.getRank() != 1 ||
-      preconditionType.isScalable())
-    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+  if (!preconditionType || preconditionType.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
     return rewriter.notifyMatchFailure(op, "types are not vector");
 
+  if (!preconditionType || preconditionType.getRank() != 1)
+    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
     interleaveMaskValues.push_back(i + (vecDimSize / 2));
   }
 
-  return rewriter.create<vector::ShuffleOp>(
-      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
 namespace {
@@ -1008,8 +1009,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-///             : vector<4xi8>, vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1018,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-///             : vector<4xi8>, vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
 template <typename ConversionOpType>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a81664b81..94e78ce40a3c19 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
 
 // CHECK-LABEL: func.func @aligned_extsi(
 func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: vector.shuffle
-  // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
   %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_extsi_base_case(
 func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: vector.shuffle
-  // CHECK-NOT: arith.extsi
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
   %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
   return %0 : vector<8xi8>
 }
 
 // CHECK-LABEL: func.func @aligned_sitofp(
 func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: shuffle
-  // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
   %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
-  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
-  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK:           %[[EXT:.*]] = vector.interleave
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
   return %0 : vector<16x8xi4>
 }
 
 // CHECK-LABEL: func.func @i7_transpose(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]
 func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
-  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
-  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK:           %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
   %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
   return %0 : vector<16x8xi7>
 }

>From 87f772ff7352a47018849c39ac05175309a190e5 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 00:10:27 +0000
Subject: [PATCH 2/4] Feedback

---
 .../Vector/Transforms/VectorEmulateNarrowType.cpp        | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 9ebe36cd3861e0..41a778b8496ef0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -873,14 +873,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
   Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
 
-  // 3. Interleave low and high i8 elements using a shuffle.
-  SmallVector<int64_t> interleaveMaskValues;
-  interleaveMaskValues.reserve(vecDimSize);
-  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
-    interleaveMaskValues.push_back(i);
-    interleaveMaskValues.push_back(i + (vecDimSize / 2));
-  }
-
+  // 3. Interleave low and high i8 elements.
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 

>From 4a2fa749f3cf746baeffa81b3a15554b6bd88448 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 00:55:54 +0000
Subject: [PATCH 3/4] Remove unused var

---
 mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 41a778b8496ef0..fc11ae63e718a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -857,7 +857,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
          "Expected i4 type");
 
   // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
-  int64_t vecDimSize = srcVecType.getShape().back();
   SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
   constexpr int64_t i4Toi8BitwidthFactor = 2;
   i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;

>From f2021d70a7f5da5d29f5513c5c6b6df6090261f8 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 01:13:06 +0000
Subject: [PATCH 4/4] [mlir][Vector] Add support for trunci to narrow type
 emulation

WIP
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 125 +++++++++++++++++-
 .../Vector/vector-rewrite-narrow-types.mlir   |  15 +++
 2 files changed, 134 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fc11ae63e718a5..394041bd2b2b20 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
-  if (resultBitwidth % 8 != 0)
+  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+  if (bitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
@@ -876,6 +876,54 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+                                Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(8) &&
+         "Expected i8 type");
+
+  // 1. Zero out the upper side of each i8 element.
+  constexpr int8_t i8BitMask = 0x0F;
+  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(srcVecType, i8BitMask));
+  Value zeroOutSrc = rewriter.create<arith::AndIOp>(loc, srcValue, zeroOutMask);
+
+  // 2. De-interleave low and high i8 elements.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> deinterleaveLowMaskValues;
+  SmallVector<int64_t> deinterleaveHighMaskValues;
+  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+  for (int i = 0, end = vecDimSize; i < end; i += 2) {
+    deinterleaveLowMaskValues.push_back(i);
+    deinterleaveHighMaskValues.push_back(i + 1);
+  }
+
+  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, zeroOutSrc, zeroOutSrc,
+      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, zeroOutSrc, zeroOutSrc,
+      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+  // 3. Move high i4 values to upper side of the byte.
+  constexpr int8_t bitsToShift = 4;
+  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+  auto shlHighOp = rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+  // 4. Merge high and low i4 values.
+  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, shlHighOp, lowShuffleOp);
+
+  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -1019,7 +1067,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
-    // Set up the BitCastRewriter and verify the preconditions.
+    // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1027,8 +1075,9 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
 
-    // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+    // Check general alignment preconditions. We invert the src/dst type order
+    // to reuse the extension preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
                                              conversionOp)))
       return failure();
 
@@ -1043,6 +1092,69 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.trunci %in : vector<8xi32> to vector<8xi4>
+///      is rewritten as
+///        %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+///        %1 = arith.andi %0, 15 : vector<8xi8>
+///        %2 = vector.shuffle %1, %1 [0, 2, 4, 6]
+///             : vector<8xi8>, vector<8xi8>
+///        %3 = vector.shuffle %1, %1 [1, 3, 5, 7]
+///             : vector<8xi8>, vector<8xi8>
+///        %4 = arith.shli %3, 4 : vector<4xi8>
+///        %5 = arith.ori %4, %2 : vector<4xi8>
+///        %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+/// Constants are shown inline for brevity; they are materialized as
+/// `arith.constant` splat vectors.
+///
+/// Only 1-D vectors are matched for now; multi-dimensional support is
+/// pending a `vector.deinterleave` op.
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = truncOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+
+    // Only single dim vectors are supported until we have
+    // `vector.deinterleave`.
+    if (srcVecType.getRank() != 1)
+      return failure();
+
+    if (failed(
+            commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+      return failure();
+
+    // Check general alignment preconditions. We invert the src/dst type order
+    // to reuse the existing precondition logic.
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+                                             truncOp)))
+      return failure();
+
+    // Create a new iX -> i8 truncation op.
+    Location loc = truncOp.getLoc();
+    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    Value i8TruncVal = rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+    // Rewrite the i8 -> i4 truncation part.
+    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+    // Finalize the rewrite.
+    rewriter.replaceOp(truncOp, subByteTrunc);
+    return success();
+  }
+};
+
+
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
 /// For example:
@@ -1115,7 +1227,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
   patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+               RewriteAlignedSubByteIntTrunc>(
       patterns.getContext(), benefit.getBenefit() + 1);
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 94e78ce40a3c19..497ca7e876d208 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,21 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
   return %0 : vector<8x32xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_trunci_base_case(
+func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<8xi8>
+// CHECK:           %[[VAL_3:.*]] = arith.andi %[[VAL_0]], %[[VAL_2]] : vector<8xi8>
+// CHECK:           %[[VAL_4:.*]] = vector.shuffle %[[VAL_3]], %[[VAL_3]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.shuffle %[[VAL_3]], %[[VAL_3]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[VAL_6:.*]] = arith.shli %[[VAL_5]], %[[VAL_1]] : vector<4xi8>
+// CHECK:           %[[VAL_7:.*]] = arith.ori %[[VAL_6]], %[[VAL_4]] : vector<4xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.bitcast %[[VAL_7]] : vector<4xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {



More information about the Mlir-commits mailing list