[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #66648)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Sep 18 06:49:14 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/66648

From 8bbacf92659ce4552bbf487a09c5d3229364545f Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH] [mlir][Vector] Add a rewrite pattern for better low-precision
 ext(bitcast) expansion

This revision adds a rewrite of vector `ext(bitcast)` sequences to a more efficient sequence of
vector operations comprising `shuffle` and bitwise ops.

Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The rewrite enumerates each bit in the result vector and determines its provenance
in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`
and `ori` ops combined with shifts.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element
bitwidth is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type
does not span a whole number of bytes, further complexities arise that are not improved by this pattern:
the heavy lifting still needs to be done by LLVM.
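
For illustration, a minimal sketch of the kind of sequence targeted (shapes borrowed from the `@fext` test added below):

```mlir
%0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
%1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
```

After the rewrite, the same semantics are expressed with `vector.shuffle`, `arith.andi`, `arith.shrui`, `arith.shli` and `arith.ori` over byte-aligned element types.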
---
 .../Vector/Transforms/VectorRewritePatterns.h |   7 +
 .../Transforms/VectorEmulateNarrowType.cpp    | 428 ++++++++++++------
 .../Vector/vector-rewrite-narrow-types.mlir   | 295 ++++++------
 .../Vector/CPU/test-rewrite-narrow-types.mlir |  46 ++
 4 files changed, 491 insertions(+), 285 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 8652fc7f5e5c640..eb561ba3b23557a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 class TruncIOp;
 } // namespace arith
@@ -309,6 +310,12 @@ FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                         arith::TruncIOp truncOp,
                                         vector::BroadcastOp maybeBroadcastOp);
 
+/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
+/// vector operations comprising `shuffle` and bitwise ops.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
 /// Appends patterns for rewriting vector operations over narrow types with
 /// ops over wider types.
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 9d659bf694a2445..9b85236b7d09b16 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,11 +18,15 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
+#include <numeric>
 
 using namespace mlir;
 
@@ -224,6 +228,106 @@ struct BitCastBitsEnumerator {
   SmallVector<SourceElementRangeList> sourceElementRanges;
 };
 
+/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+/// BitCastBitsEnumerator encodes for each element of the target vector the
+/// provenance of the bits in the source vector. We can "transpose" this
+/// information to build a sequence of shuffles and bitwise ops that will
+/// produce the desired result.
+///
+/// Consider the following motivating example:
+/// ```
+///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+/// ```
+///
+/// BitCastBitsEnumerator contains the following information:
+/// ```
+///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
+///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
+///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
+///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
+///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
+///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
+///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
+///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
+///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
+///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
+///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
+///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
+///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
+///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
+///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
+///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
+///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
+///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
+///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
+///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
+/// ```
+///
+/// In the above, each row represents one target vector element and each
+/// column represents one bit contribution from a source vector element.
+/// The algorithm creates one vector.shuffle per column (in this case 3, i.e.
+/// the max number of columns in BitCastBitsEnumerator) and populates the bits
+/// as follows:
+/// ```
+///     src bits 0 ...
+/// 1st shuffle |xxxxx   |xx      |...
+/// 2nd shuffle |     xxx|  xxxxx |...
+/// 3rd shuffle |        |       x|...
+/// ```
+///
+/// The algorithm proceeds as follows:
+///   1. for each vector.shuffle, collect the source vectors that participate in
+///     this shuffle. One source vector per target element of the resulting
+///     vector.shuffle. If there is no source element contributing bits for the
+///     current vector.shuffle, take 0 (i.e. row 0 in the above example has only
+///     2 columns).
+///   2. represent the bitrange in the source vector as a mask. If there is no
+///     source element contributing bits for the current vector.shuffle, take 0.
+///   3. shift right by the proper amount to align the source bitrange at
+///     position 0. This is exactly the low end of the bitrange. For instance,
+///     the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
+///     shift right by 3 to get the bits contributed by the source element #1
+///     into position 0.
+///   4. shift left by the proper amount to align to the desired position in
+///     the result element vector. For instance, the contribution of the second
+///     source element for the first row needs to be shifted by `5` to form the
+///     first i8 result element.
+///
+/// Eventually, we end up building the sequence
+/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
+/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
+/// bits extracted from the source vector (i.e. the `shuffle -> and` part); a
+/// concrete example is sketched below the struct.
+struct BitCastRewriter {
+  /// Helper metadata struct to hold the static quantities for the rewrite.
+  struct Metadata {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+  };
+
+  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
+
+  /// Verify that the preconditions for the rewrite are met.
+  LogicalResult precondition(PatternRewriter &rewriter,
+                             VectorType targetVectorType, Operation *op);
+
+  /// Precompute the metadata for the rewrite.
+  SmallVector<BitCastRewriter::Metadata>
+  precomputeMetadata(IntegerType shuffledElementType);
+
+  /// Rewrite one step of the sequence:
+  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
+  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
+                    Value runningResult,
+                    const BitCastRewriter::Metadata &metadata);
+
+private:
+  /// Underlying enumerator that encodes the provenance of the bits in each
+  /// element of the result vector.
+  BitCastBitsEnumerator enumerator;
+};
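
As a smaller hand-derived illustration (an assumption worked out by hand, not tool output) for the `vector.bitcast %a : vector<5xi8> to vector<8xi5>` case exercised by the `@fext` tests below, the enumerator would contain, in the notation above:

```
{0: b@[0..5) lshl: 0}
{0: b@[5..8) lshl: 0}{1: b@[0..2) lshl: 3}
{1: b@[2..7) lshl: 0}
{1: b@[7..8) lshl: 0}{2: b@[0..4) lshl: 1}
{2: b@[4..8) lshl: 0}{3: b@[0..1) lshl: 4}
{3: b@[1..6) lshl: 0}
{3: b@[6..8) lshl: 0}{4: b@[0..3) lshl: 2}
{4: b@[3..8) lshl: 0}
```

The maximum row length is 2, so two shuffles suffice; the corresponding IR is sketched after `rewriteStep` below.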
+
 } // namespace
 
 static raw_ostream &operator<<(raw_ostream &os,
@@ -256,7 +360,7 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
   LDBG("targetVectorType: " << targetVectorType);
 
   int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
-  (void) mostMinorSourceDim;
+  (void)mostMinorSourceDim;
   assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
          "source and target bitwidths must match");
 
@@ -275,79 +379,104 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
   }
 }
 
+BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
+                                 VectorType targetVectorType)
+    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
+  LDBG("\n" << enumerator.sourceElementRanges);
+}
+
+LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
+                                            VectorType targetVectorType,
+                                            Operation *op) {
+  if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+
+  // TODO: consider relaxing this restriction in the future if we find ways
+  // to really work with subbyte elements across the MLIR/LLVM boundary.
+  int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+  if (resultBitwidth % 8 != 0)
+    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
+
+  return success();
+}
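
For instance (mirroring the `@f4` test below), the following bitcast is left for LLVM to handle because i6 does not span a whole number of bytes:

```mlir
// 6 % 8 != 0: the pattern bails out and the bitcast is kept as-is.
%1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6>
```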
+
+SmallVector<BitCastRewriter::Metadata>
+BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+  SmallVector<BitCastRewriter::Metadata> result;
+  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
+       shuffleIdx < e; ++shuffleIdx) {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+
+    // Create the attribute quantities for the shuffle / mask / shift ops.
+    for (auto &l : enumerator.sourceElementRanges) {
+      int64_t sourceElement =
+          (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceElementIdx : 0;
+      shuffles.push_back(sourceElement);
+
+      int64_t bitLo =
+          (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitBegin : 0;
+      int64_t bitHi =
+          (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitEnd : 0;
+      IntegerAttr mask = IntegerAttr::get(
+          shuffledElementType,
+          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
+                                  bitLo, bitHi));
+      masks.push_back(mask);
+
+      int64_t shiftRight = bitLo;
+      shiftRightAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftRight));
+
+      int64_t shiftLeft = l.computeLeftShiftAmount(shuffleIdx);
+      shiftLeftAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftLeft));
+    }
+
+    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
+  }
+  return result;
+}
+
+Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
+                                   Value initialValue, Value runningResult,
+                                   const BitCastRewriter::Metadata &metadata) {
+  // Create vector.shuffle from the metadata.
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, initialValue, initialValue, metadata.shuffles);
+
+  // Intersect with the mask.
+  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+  auto constOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+  // Align right on 0.
+  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
+  Value shiftedRight =
+      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+  // Shift bits left into their final position.
+  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
+  Value shiftedLeft =
+      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+  runningResult =
+      runningResult
+          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
+          : shiftedLeft;
+
+  return runningResult;
+}
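
As a concrete single-step sketch (hand-derived from the `vector<5xi8>` to `vector<8xi5>` enumeration above, not verbatim compiler output), the first step with i8 shuffled elements would be roughly:

```mlir
// Gather the source byte holding the low bits of each of the 8 result elements.
%s0 = vector.shuffle %a, %a [0, 0, 1, 1, 2, 3, 3, 4] : vector<5xi8>, vector<5xi8>
// Masks 0x1f, 0xe0, 0x7c, 0x80, 0xf0, 0x3e, 0xc0, 0xf8, written as signed i8.
%mask0 = arith.constant dense<[31, -32, 124, -128, -16, 62, -64, -8]> : vector<8xi8>
%and0 = arith.andi %s0, %mask0 : vector<8xi8>
// Align each extracted bitrange at position 0. In this first step the
// left-shift amounts are all zero and there is no running result to or with.
%shr0 = arith.constant dense<[0, 5, 2, 7, 4, 1, 6, 3]> : vector<8xi8>
%lo = arith.shrui %and0, %shr0 : vector<8xi8>
```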
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
 /// peephole optimizations.
-
-// BitCastBitsEnumerator encodes for each element of the target vector the
-// provenance of the bits in the source vector. We can "transpose" this
-// information to build a sequence of shuffles and bitwise ops that will
-// produce the desired result.
-//
-// Let's take the following motivating example to explain the algorithm:
-// ```
-//   %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
-//   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
-// ```
-//
-// BitCastBitsEnumerator contains the following information:
-// ```
-//   { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 }
-//   { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 }
-//   { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 }
-//   { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 }
-//   { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 }
-//   { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 }
-//   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 }
-//   { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 }
-//   { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 }
-//   { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
-//   { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
-//   { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
-//   { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
-//   { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6}
-//   { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 }
-//   { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 }
-//   { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 }
-//   { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
-//   { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
-//   { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 }
-// ```
-//
-// In the above, each row represents one target vector element and each
-// column represents one bit contribution from a source vector element.
-// The algorithm creates vector.shuffle operations (in this case there are 3
-// shuffles (i.e. the max number of columns in BitCastBitsEnumerator). The
-// algorithm populates the bits as follows:
-// ```
-//     src bits 0 ...
-// 1st shuffle |xxxxx   |xx      |...
-// 2nd shuffle |     xxx|  xxxxx |...
-// 3rd shuffle |        |       x|...
-// ```
-//
-// The algorithm proceeds as follows:
-//   1. for each vector.shuffle, collect the source vectors that participate in
-//     this shuffle. One source vector per target element of the resulting
-//     vector.shuffle. If there is no source element contributing bits for the
-//     current vector.shuffle, take 0 (i.e. row 0 in the above example has only
-//     2 columns).
-//   2. represent the bitrange in the source vector as a mask. If there is no
-//     source element contributing bits for the current vector.shuffle, take 0.
-//   3. shift right by the proper amount to align the source bitrange at
-//     position 0. This is exactly the low end of the bitrange. For instance,
-//     the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
-//     shift right by 3 to get the bits contributed by the source element #1
-//     into position 0.
-//   4. shift left by the proper amount to to align to the desired position in
-//     the result element vector.  For instance, the contribution of the second
-//     source element for the first row needs to be shifted by `5` to form the
-//     first i8 result element.
-// Eventually, we end up building  the sequence
-// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update the
-// result vector (i.e. the `shiftright -> shiftleft -> or` part) with the bits
-// extracted from the source vector (i.e. the `shuffle -> and` part).
 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -359,93 +488,92 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     if (!truncOp)
       return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
 
+    // Set up the BitCastRewriter and verify the precondition.
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
-    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
-      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
-    // TODO: consider relaxing this restriction in the future if we find ways
-    // to really work with subbyte elements across the MLIR/LLVM boundary.
-    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
-    if (resultBitwidth % 8 != 0)
-      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+      return failure();
 
-    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
-    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
-    LDBG("\n" << be.sourceElementRanges);
-
-    Value initialValue = truncOp.getIn();
-    auto initalVectorType = initialValue.getType().cast<VectorType>();
-    auto initalElementType = initalVectorType.getElementType();
-    auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
-
-    Value res;
-    for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
-         ++shuffleIdx) {
-      SmallVector<int64_t> shuffles;
-      SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
-
-      // Create the attribute quantities for the shuffle / mask / shift ops.
-      for (auto &srcEltRangeList : be.sourceElementRanges) {
-        bool idxContributesBits =
-            (shuffleIdx < (int64_t)srcEltRangeList.size());
-        int64_t sourceElementIdx =
-            idxContributesBits ? srcEltRangeList[shuffleIdx].sourceElementIdx
-                               : 0;
-        shuffles.push_back(sourceElementIdx);
-
-        int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitBegin
-                            : 0;
-        int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitEnd
-                            : 0;
-        IntegerAttr mask = IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth),
-            llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi));
-        masks.push_back(mask);
-
-        int64_t shiftRight = bitLo;
-        shiftRightAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftRight));
-
-        int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
-        shiftLeftAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
-      }
-
-      // Create vector.shuffle #shuffleIdx.
-      auto shuffleOp = rewriter.create<vector::ShuffleOp>(
-          bitCastOp.getLoc(), initialValue, initialValue, shuffles);
-      // And with the mask.
-      VectorType vt = VectorType::Builder(initalVectorType)
-                          .setDim(initalVectorType.getRank() - 1, masks.size());
-      auto constOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
-      Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
-                                                      shuffleOp, constOp);
-      // Align right on 0.
-      auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
-      Value shiftedRight = rewriter.create<arith::ShRUIOp>(
-          bitCastOp.getLoc(), andValue, shiftRightConstantOp);
-
-      auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
-      Value shiftedLeft = rewriter.create<arith::ShLIOp>(
-          bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
-
-      res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
-                                                shiftedLeft)
-                : shiftedLeft;
+    // Perform the rewrite.
+    Value truncValue = truncOp.getIn();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
+    Value runningResult;
+    for (const BitCastRewriter::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                      runningResult, metadata);
     }
 
-    bool narrowing = resultBitwidth <= initalElementBitWidth;
+    // Finalize the rewrite.
+    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
+                     shuffledElementType.getIntOrFloatBitWidth();
     if (narrowing) {
       rewriter.replaceOpWithNewOp<arith::TruncIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     } else {
       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     }
+
+    return success();
+  }
+};
+} // namespace
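
For orientation (value names hypothetical; the shapes are those of the `@f1` and `@f3` tests below), the finalization converts the running result, held in the shuffled element type, to the bitcast's result element type:

```mlir
// Narrowing case: i64 shuffled elements, i8 result elements.
%r1 = arith.trunci %acc1 : vector<20xi64> to vector<20xi8>
// Widening case: i16 shuffled elements, i32 result elements.
%r3 = arith.extui %acc3 : vector<2xi16> to vector<2xi32>
```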
+
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
+/// take advantage of high-level information to avoid leaving LLVM to scramble
+/// with peephole optimizations.
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+  using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(ExtOpType extOp,
+                                PatternRewriter &rewriter) const override {
+    // The source must be a bitcast op.
+    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+    if (!bitCastOp)
+      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+    // Set up the BitCastRewriter and verify the precondition.
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+    VectorType targetVectorType = bitCastOp.getResultVectorType();
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(
+            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value runningResult;
+    Value sourceValue = bitCastOp.getSource();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
+    for (const BitCastRewriter::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
+                                      sourceValue, runningResult, metadata);
+    }
+
+    // Finalize the rewrite.
+    bool narrowing =
+        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
+        shuffledElementType.getIntOrFloatBitWidth();
+    if (narrowing) {
+      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
+          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+    } else {
+      rewriter.replaceOpWithNewOp<ExtOpType>(
+          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+    }
+
     return success();
   }
 };
@@ -466,5 +594,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<RewriteBitCastOfTruncI>(patterns.getContext(), benefit);
+  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
+               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
+                                                    benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index ba6efde40f36c2b..e5e1cc6d37b041b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -4,146 +4,169 @@
 /// ====================================================
 /// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 --function-sections -filetype=obj > /tmp/a.out; objdump -d --disassemble=f1 --no-addresses --no-show-raw-insn -M att /tmp/a.out | ./build/bin/llvm-mca -mcpu=skylake-avx512
 
-// CHECK-LABEL: func.func @f1(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<32xi64>) -> vector<20xi8>
-func.func @f1(%a: vector<32xi64>) -> vector<20xi8> {
-  /// Rewriting this standalone pattern is about 2x faster on skylake-ax512 according to llvm-mca.
-  /// Benefit further increases when mixed with other compute ops.
-  ///
-  /// The provenance of the 20x8 bits of the result are the following bits in the
-  /// source vector:
-  // { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 }
-  // { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 }
-  // { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 }
-  // { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 }
-  // { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 }
-  // { 8: b@[0..5) lshl: 0 } { 9: b@[0..3) lshl: 5 }
-  // { 9: b@[3..5) lshl: 0 } { 10: b@[0..5) lshl: 2 } { 11: b@[0..1) lshl: 7 }
-  // { 11: b@[1..5) lshl: 0 } { 12: b@[0..4) lshl: 4 }                      
-  // { 12: b@[4..5) lshl: 0 } { 13: b@[0..5) lshl: 1 } { 14: b@[0..2) lshl: 6 }
-  // { 14: b@[2..5) lshl: 0 } { 15: b@[0..5) lshl: 3 }                      
-  // { 16: b@[0..5) lshl: 0 } { 17: b@[0..3) lshl: 5 }                      
-  // { 17: b@[3..5) lshl: 0 } { 18: b@[0..5) lshl: 2 } { 19: b@[0..1) lshl: 7 }
-  // { 19: b@[1..5) lshl: 0 } { 20: b@[0..4) lshl: 4 }                      
-  // { 20: b@[4..5) lshl: 0 } { 21: b@[0..5) lshl: 1 } { 22: b@[0..2) lshl: 6 }
-  // { 22: b@[2..5) lshl: 0 } { 23: b@[0..5) lshl: 3 }                      
-  // { 24: b@[0..5) lshl: 0 } { 25: b@[0..3) lshl: 5 }                      
-  // { 25: b@[3..5) lshl: 0 } { 26: b@[0..5) lshl: 2 } { 27: b@[0..1) lshl: 7 }
-  // { 27: b@[1..5) lshl: 0 } { 28: b@[0..4) lshl: 4 }                      
-  // { 28: b@[4..5) lshl: 0 } { 29: b@[0..5) lshl: 1 } { 30: b@[0..2) lshl: 6 }
-  // { 30: b@[2..5) lshl: 0 } { 31: b@[0..5) lshl: 3 }  
-  /// This results in 3 shuffles + 1 shr + 2 shl + 3 and + 2 or.
-  /// The third vector is empty for positions 0, 2, 4, 5, 7, 9, 10, 12, 14, 15,
-  /// 17 and 19 (i.e. there are only 2 entries in that row).
-  /// 
-  ///                             0: b@[0..5), 1: b@[3..5), etc
-  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28]> : vector<20xi64>
-  ///                             1: b@[0..3), 2: b@[0..5), etc
-  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31]> :  vector<20xi64>
-  ///                             empty, 3: b@[0..1), empty etc
-  // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0]> : vector<20xi64>
-  // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2]> : vector<20xi64>
-  // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3]> : vector<20xi64>
-  // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8]> : vector<20xi64>
-  //
-  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 3, 4, 6, 8, 9, 11, 12, 14, 16, 17, 19, 20, 22, 24, 25, 27, 28, 30] : vector<32xi64>, vector<32xi64>
-  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<20xi64>
-  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<20xi64>
-  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 2, 4, 5, 7, 9, 10, 12, 13, 15, 17, 18, 20, 21, 23, 25, 26, 28, 29, 31] : vector<32xi64>, vector<32xi64>
-  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<20xi64>
-  // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<20xi64>
-  // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<20xi64>
-  // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [0, 3, 0, 6, 0, 0, 11, 0, 14, 0, 0, 19, 0, 22, 0, 0, 27, 0, 30, 0] : vector<32xi64>, vector<32xi64>
-  // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK2]] : vector<20xi64>
-  // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<20xi64>
-  // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<20xi64>
-  // CHECK: %[[TR:.*]] = arith.trunci %[[O2]] : vector<20xi64> to vector<20xi8>
-  // CHECK-NOT: bitcast
-  %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
-  %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
-  return %1 : vector<20xi8>
-}
+// // CHECK-LABEL: func.func @f1(
+// //  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<32xi64>) -> vector<20xi8>
+// func.func @f1(%a: vector<32xi64>) -> vector<20xi8> {
//   /// Rewriting this standalone pattern is about 2x faster on skylake-avx512 according to llvm-mca.
+//   /// Benefit further increases when mixed with other compute ops.
+//   ///
+//   /// The provenance of the 20x8 bits of the result are the following bits in the
+//   /// source vector:
+//   // { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 }
+//   // { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 }
+//   // { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 }
+//   // { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 }
+//   // { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 }
+//   // { 8: b@[0..5) lshl: 0 } { 9: b@[0..3) lshl: 5 }
+//   // { 9: b@[3..5) lshl: 0 } { 10: b@[0..5) lshl: 2 } { 11: b@[0..1) lshl: 7 }
+//   // { 11: b@[1..5) lshl: 0 } { 12: b@[0..4) lshl: 4 }                      
+//   // { 12: b@[4..5) lshl: 0 } { 13: b@[0..5) lshl: 1 } { 14: b@[0..2) lshl: 6 }
+//   // { 14: b@[2..5) lshl: 0 } { 15: b@[0..5) lshl: 3 }                      
+//   // { 16: b@[0..5) lshl: 0 } { 17: b@[0..3) lshl: 5 }                      
+//   // { 17: b@[3..5) lshl: 0 } { 18: b@[0..5) lshl: 2 } { 19: b@[0..1) lshl: 7 }
+//   // { 19: b@[1..5) lshl: 0 } { 20: b@[0..4) lshl: 4 }                      
+//   // { 20: b@[4..5) lshl: 0 } { 21: b@[0..5) lshl: 1 } { 22: b@[0..2) lshl: 6 }
+//   // { 22: b@[2..5) lshl: 0 } { 23: b@[0..5) lshl: 3 }                      
+//   // { 24: b@[0..5) lshl: 0 } { 25: b@[0..3) lshl: 5 }                      
+//   // { 25: b@[3..5) lshl: 0 } { 26: b@[0..5) lshl: 2 } { 27: b@[0..1) lshl: 7 }
+//   // { 27: b@[1..5) lshl: 0 } { 28: b@[0..4) lshl: 4 }                      
+//   // { 28: b@[4..5) lshl: 0 } { 29: b@[0..5) lshl: 1 } { 30: b@[0..2) lshl: 6 }
+//   // { 30: b@[2..5) lshl: 0 } { 31: b@[0..5) lshl: 3 }  
+//   /// This results in 3 shuffles + 1 shr + 2 shl + 3 and + 2 or.
+//   /// The third vector is empty for positions 0, 2, 4, 5, 7, 9, 10, 12, 14, 15,
+//   /// 17 and 19 (i.e. there are only 2 entries in that row).
+//   /// 
+//   ///                             0: b@[0..5), 1: b@[3..5), etc
+//   // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28]> : vector<20xi64>
+//   ///                             1: b@[0..3), 2: b@[0..5), etc
+//   // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31]> :  vector<20xi64>
+//   ///                             empty, 3: b@[0..1), empty etc
+//   // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0]> : vector<20xi64>
+//   // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2]> : vector<20xi64>
+//   // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3]> : vector<20xi64>
+//   // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8]> : vector<20xi64>
+//   //
+//   // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 3, 4, 6, 8, 9, 11, 12, 14, 16, 17, 19, 20, 22, 24, 25, 27, 28, 30] : vector<32xi64>, vector<32xi64>
+//   // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<20xi64>
+//   // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<20xi64>
+//   // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 2, 4, 5, 7, 9, 10, 12, 13, 15, 17, 18, 20, 21, 23, 25, 26, 28, 29, 31] : vector<32xi64>, vector<32xi64>
+//   // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<20xi64>
+//   // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<20xi64>
+//   // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<20xi64>
+//   // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [0, 3, 0, 6, 0, 0, 11, 0, 14, 0, 0, 19, 0, 22, 0, 0, 27, 0, 30, 0] : vector<32xi64>, vector<32xi64>
+//   // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK2]] : vector<20xi64>
+//   // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<20xi64>
+//   // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<20xi64>
+//   // CHECK: %[[TR:.*]] = arith.trunci %[[O2]] : vector<20xi64> to vector<20xi8>
+//   // CHECK-NOT: bitcast
+//   %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
+//   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+//   return %1 : vector<20xi8>
+// }
 
-// CHECK-LABEL: func.func @f2(
-//  CHECK-SAME:   %[[A:[0-9a-z]*]]: vector<16xi16>) -> vector<3xi16>
-func.func @f2(%a: vector<16xi16>) -> vector<3xi16> {
-  /// Rewriting this standalone pattern is about 1.8x faster on skylake-ax512 according to llvm-mca.
-  /// Benefit further increases when mixed with other compute ops.
-  ///
-  // { 0: b@[0..3) lshl: 0 } { 1: b@[0..3) lshl: 3 } { 2: b@[0..3) lshl: 6 } { 3: b@[0..3) lshl: 9 } { 4: b@[0..3) lshl: 12 } { 5: b@[0..1) lshl: 15 } 
-  // { 5: b@[1..3) lshl: 0 } { 6: b@[0..3) lshl: 2 } { 7: b@[0..3) lshl: 5 } { 8: b@[0..3) lshl: 8 } { 9: b@[0..3) lshl: 11 } { 10: b@[0..2) lshl: 14 } 
-  // { 10: b@[2..3) lshl: 0 } { 11: b@[0..3) lshl: 1 } { 12: b@[0..3) lshl: 4 } { 13: b@[0..3) lshl: 7 } { 14: b@[0..3) lshl: 10 } { 15: b@[0..3) lshl: 13 }
-  ///                                             0: b@[0..3), 5: b@[1..3), 10: b@[2..3)
-  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[7, 6, 4]> : vector<3xi16>
-  ///                                             1: b@[0..3), 6: b@[0..3), 11: b@[0..3)
-  ///                                             ...
-  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<7> : vector<3xi16>
-  ///                                             5: b@[0..1), 10: b@[0..2), 15: b@[0..3)
-  // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[1, 3, 7]> : vector<3xi16>
-  // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi16>
-  // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[3, 2, 1]> : vector<3xi16>
-  // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[6, 5, 4]> : vector<3xi16>
-  // CHECK-DAG: %[[SHL3_CST:.*]] = arith.constant dense<[9, 8, 7]> : vector<3xi16>
-  // CHECK-DAG: %[[SHL4_CST:.*]] = arith.constant dense<[12, 11, 10]> : vector<3xi16>
-  // CHECK-DAG: %[[SHL5_CST:.*]] = arith.constant dense<[15, 14, 13]> : vector<3xi16>
+// // CHECK-LABEL: func.func @f2(
+// //  CHECK-SAME:   %[[A:[0-9a-z]*]]: vector<16xi16>) -> vector<3xi16>
+// func.func @f2(%a: vector<16xi16>) -> vector<3xi16> {
//   /// Rewriting this standalone pattern is about 1.8x faster on skylake-avx512 according to llvm-mca.
+//   /// Benefit further increases when mixed with other compute ops.
+//   ///
+//   // { 0: b@[0..3) lshl: 0 } { 1: b@[0..3) lshl: 3 } { 2: b@[0..3) lshl: 6 } { 3: b@[0..3) lshl: 9 } { 4: b@[0..3) lshl: 12 } { 5: b@[0..1) lshl: 15 } 
+//   // { 5: b@[1..3) lshl: 0 } { 6: b@[0..3) lshl: 2 } { 7: b@[0..3) lshl: 5 } { 8: b@[0..3) lshl: 8 } { 9: b@[0..3) lshl: 11 } { 10: b@[0..2) lshl: 14 } 
+//   // { 10: b@[2..3) lshl: 0 } { 11: b@[0..3) lshl: 1 } { 12: b@[0..3) lshl: 4 } { 13: b@[0..3) lshl: 7 } { 14: b@[0..3) lshl: 10 } { 15: b@[0..3) lshl: 13 }
+//   ///                                             0: b@[0..3), 5: b@[1..3), 10: b@[2..3)
+//   // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[7, 6, 4]> : vector<3xi16>
+//   ///                                             1: b@[0..3), 6: b@[0..3), 11: b@[0..3)
+//   ///                                             ...
+//   // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<7> : vector<3xi16>
+//   ///                                             5: b@[0..1), 10: b@[0..2), 15: b@[0..3)
+//   // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[1, 3, 7]> : vector<3xi16>
+//   // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi16>
+//   // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[3, 2, 1]> : vector<3xi16>
+//   // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[6, 5, 4]> : vector<3xi16>
+//   // CHECK-DAG: %[[SHL3_CST:.*]] = arith.constant dense<[9, 8, 7]> : vector<3xi16>
+//   // CHECK-DAG: %[[SHL4_CST:.*]] = arith.constant dense<[12, 11, 10]> : vector<3xi16>
+//   // CHECK-DAG: %[[SHL5_CST:.*]] = arith.constant dense<[15, 14, 13]> : vector<3xi16>
 
-  //
-  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 5, 10] : vector<16xi16>, vector<16xi16>
-  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<3xi16>
-  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<3xi16>
-  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 6, 11] : vector<16xi16>, vector<16xi16>
-  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<3xi16>
-  // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<3xi16>
-  // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<3xi16>
-  // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [2, 7, 12] : vector<16xi16>, vector<16xi16>
-  // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK1]] : vector<3xi16>
-  // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<3xi16>
-  // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<3xi16>
-  // CHECK: %[[V3:.*]] = vector.shuffle %[[A]], %[[A]] [3, 8, 13] : vector<16xi16>, vector<16xi16>
-  // CHECK: %[[A3:.*]] = arith.andi %[[V3]], %[[MASK1]] : vector<3xi16>
-  // CHECK: %[[SHL3:.*]] = arith.shli %[[A3]], %[[SHL3_CST]] : vector<3xi16>
-  // CHECK: %[[O3:.*]] = arith.ori %[[O2]], %[[SHL3]]  : vector<3xi16>
-  // CHECK: %[[V4:.*]] = vector.shuffle %[[A]], %[[A]] [4, 9, 14] : vector<16xi16>, vector<16xi16>
-  // CHECK: %[[A4:.*]] = arith.andi %[[V4]], %[[MASK1]] : vector<3xi16>
-  // CHECK: %[[SHL4:.*]] = arith.shli %[[A4]], %[[SHL4_CST]] : vector<3xi16>
-  // CHECK: %[[O4:.*]] = arith.ori %[[O3]], %[[SHL4]]  : vector<3xi16>
-  // CHECK: %[[V5:.*]] = vector.shuffle %[[A]], %[[A]] [5, 10, 15] : vector<16xi16>, vector<16xi16>
-  // CHECK: %[[A5:.*]] = arith.andi %[[V5]], %[[MASK2]] : vector<3xi16>
-  // CHECK: %[[SHL5:.*]] = arith.shli %[[A5]], %[[SHL5_CST]] : vector<3xi16>
-  // CHECK: %[[O5:.*]] = arith.ori %[[O4]], %[[SHL5]]  : vector<3xi16>
-  /// No trunci needed as the result is already in i16.
-  // CHECK-NOT: arith.trunci
-  // CHECK-NOT: bitcast
-  %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
-  %1 = vector.bitcast %0 : vector<16xi3> to vector<3xi16>
-  return %1 : vector<3xi16>
-}
+//   //
+//   // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 5, 10] : vector<16xi16>, vector<16xi16>
+//   // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<3xi16>
+//   // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<3xi16>
+//   // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 6, 11] : vector<16xi16>, vector<16xi16>
+//   // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<3xi16>
+//   // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<3xi16>
+//   // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<3xi16>
+//   // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [2, 7, 12] : vector<16xi16>, vector<16xi16>
+//   // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK1]] : vector<3xi16>
+//   // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<3xi16>
+//   // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<3xi16>
+//   // CHECK: %[[V3:.*]] = vector.shuffle %[[A]], %[[A]] [3, 8, 13] : vector<16xi16>, vector<16xi16>
+//   // CHECK: %[[A3:.*]] = arith.andi %[[V3]], %[[MASK1]] : vector<3xi16>
+//   // CHECK: %[[SHL3:.*]] = arith.shli %[[A3]], %[[SHL3_CST]] : vector<3xi16>
+//   // CHECK: %[[O3:.*]] = arith.ori %[[O2]], %[[SHL3]]  : vector<3xi16>
+//   // CHECK: %[[V4:.*]] = vector.shuffle %[[A]], %[[A]] [4, 9, 14] : vector<16xi16>, vector<16xi16>
+//   // CHECK: %[[A4:.*]] = arith.andi %[[V4]], %[[MASK1]] : vector<3xi16>
+//   // CHECK: %[[SHL4:.*]] = arith.shli %[[A4]], %[[SHL4_CST]] : vector<3xi16>
+//   // CHECK: %[[O4:.*]] = arith.ori %[[O3]], %[[SHL4]]  : vector<3xi16>
+//   // CHECK: %[[V5:.*]] = vector.shuffle %[[A]], %[[A]] [5, 10, 15] : vector<16xi16>, vector<16xi16>
+//   // CHECK: %[[A5:.*]] = arith.andi %[[V5]], %[[MASK2]] : vector<3xi16>
+//   // CHECK: %[[SHL5:.*]] = arith.shli %[[A5]], %[[SHL5_CST]] : vector<3xi16>
+//   // CHECK: %[[O5:.*]] = arith.ori %[[O4]], %[[SHL5]]  : vector<3xi16>
+//   /// No trunci needed as the result is already in i16.
+//   // CHECK-NOT: arith.trunci
+//   // CHECK-NOT: bitcast
+//   %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
+//   %1 = vector.bitcast %0 : vector<16xi3> to vector<3xi16>
+//   return %1 : vector<3xi16>
+// }
 
-/// This pattern requires an extui 16 -> 32 and not a trunci.
-// CHECK-LABEL: func.func @f3(
-func.func @f3(%a: vector<16xi16>) -> vector<2xi32> {
-  /// Rewriting this standalone pattern is about 25x faster on skylake-ax512 according to llvm-mca.
-  /// Benefit further increases when mixed with other compute ops.
-  ///
-  // CHECK-NOT: arith.trunci
-  // CHECK-NOT: bitcast
-  //     CHECK: arith.extui
-  %0 = arith.trunci %a : vector<16xi16> to vector<16xi4>
-  %1 = vector.bitcast %0 : vector<16xi4> to vector<2xi32>
-  return %1 : vector<2xi32>
-}
+// /// This pattern requires an extui 16 -> 32 and not a trunci.
+// // CHECK-LABEL: func.func @f3(
+// func.func @f3(%a: vector<16xi16>) -> vector<2xi32> {
//   /// Rewriting this standalone pattern is about 25x faster on skylake-avx512 according to llvm-mca.
+//   /// Benefit further increases when mixed with other compute ops.
+//   ///
+//   // CHECK-NOT: arith.trunci
+//   // CHECK-NOT: bitcast
+//   //     CHECK: arith.extui
+//   %0 = arith.trunci %a : vector<16xi16> to vector<16xi4>
+//   %1 = vector.bitcast %0 : vector<16xi4> to vector<2xi32>
+//   return %1 : vector<2xi32>
+// }
+
+// /// This pattern is not rewritten as the result i6 is not a multiple of i8.
+// // CHECK-LABEL: func.func @f4(
+// func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
+//   // CHECK: trunci
+//   // CHECK: bitcast
+//   // CHECK-NOT: shuffle
+//   // CHECK-NOT: andi
+//   // CHECK-NOT: ori
+//   %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
+//   %1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6>
+//   return %1 : vector<8xi6>
+// }
+
+
+// // CHECK-LABEL: func.func @f1ext(
// //  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<20xi8>) -> vector<32xi64> {
+// func.func @f1ext(%a: vector<20xi8>) -> vector<32xi64> {
+//   %0 = vector.bitcast %a : vector<20xi8> to vector<32xi5>
+//   %1 = arith.extui %0 : vector<32xi5> to vector<32xi64>
+//   return %1 : vector<32xi64>
+// }
+
+// // CHECK-LABEL: func.func @f2ext(
+// //  CHECK-SAME:   %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<16xi16>
+// func.func @f2ext(%a: vector<3xi16>) -> vector<16xi16> {
//   %0 = vector.bitcast %a : vector<3xi16> to vector<16xi3>
+//   %1 = arith.extui %0 : vector<16xi3> to vector<16xi16>
+//   return %1 : vector<16xi16>
+// }
 
-/// This pattern is not rewritten as the result i6 is not a multiple of i8.
-// CHECK-LABEL: func.func @f4(
-func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
-  // CHECK: trunci
-  // CHECK: bitcast
-  // CHECK-NOT: shuffle
-  // CHECK-NOT: andi
-  // CHECK-NOT: ori
-  %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
-  %1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6>
-  return %1 : vector<8xi6>
+func.func @fext(%a: vector<5xi8>) -> vector<8xi16> {
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+  return %1 : vector<8xi16>
 }
 
 transform.sequence failures(propagate) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index 44c608726f13530..7d15e2e2e3ef5e4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -124,6 +124,47 @@ func.func @f3(%v: vector<2xi48>) {
   return
 }
 
+func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
+  %bitsi40 = vector.bitcast %v : vector<8xi5> to vector<40xi1>
+  vector.print %bitsi40 : vector<40xi1>
+  return
+}
+
+func.func @print_as_i1_8xi16(%v : vector<8xi16>) {
+  %bitsi128 = vector.bitcast %v : vector<8xi16> to vector<128xi1>
+  vector.print %bitsi128 : vector<128xi1>
+  return
+}
+
+func.func @fext(%a: vector<5xi8>) {
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  func.call @print_as_i1_8xi5(%0) : (vector<8xi5>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 0,
+  // CHECK-SAME: 1, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 1, 0, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 1, 1,
+  // CHECK-SAME: 0, 1, 1, 1, 0,
+  // CHECK-SAME: 0, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 1, 1, 1, 0,
+  // CHECK-SAME: 1, 0, 1, 1, 1 )
+
+  %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+  func.call @print_as_i1_8xi16(%1) : (vector<8xi16>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  return
+}
+
+
 func.func @entry() {
   %v = arith.constant dense<[
     0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
@@ -141,6 +182,11 @@ func.func @entry() {
   ]> : vector<2xi48>
   func.call @f3(%v3) : (vector<2xi48>) -> ()
 
+  %v4 = arith.constant dense<[
+    0xef, 0xee, 0xed, 0xec, 0xeb
+  ]> : vector<5xi8>
+  func.call @fext(%v4) : (vector<5xi8>) -> ()
+
   return
 }
 


