[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision bitcast… (PR #66387)

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 15 06:21:13 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/66387

>From 5132376dc1088dcab4d3ba4008e73f469dd840fa Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH] [mlir][Vector] Add a rewrite pattern for better low-precision
 bitcast(trunci) expansion

This revision adds a rewrite for sequences of vector `bitcast(trunci)` to use a more efficient sequence of
vector operations comprising `shuffle` and `bitwise` ops.

Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The rewrite enumerates each bit in the result vector and determines its provenance
in the pre-trunci vector. The enumeration is then used to generate the proper sequence of
`shuffle`, `andi`, `shrui`, `shli` and `ori` ops, followed by an optional final `trunci`/`extui`.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the bitwidth of the final
vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the
resulting element type does not span a whole number of bytes, further complexities arise that this pattern
does not improve: the heavy lifting still needs to be done by LLVM.
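
For illustration, a minimal case (not one of the tests added here, SSA names are illustrative and the
exact shuffle indices and constants come from the bit enumeration described above) rewrites roughly as
follows, once the zero shifts introduced by the rewrite have been folded away:

  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
  %1 = vector.bitcast %0 : vector<8xi4> to vector<4xi8>

becomes approximately:

  %mask = arith.constant dense<15> : vector<4xi8>
  %shl = arith.constant dense<4> : vector<4xi8>
  %s0 = vector.shuffle %a, %a [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
  %lo = arith.andi %s0, %mask : vector<4xi8>
  %s1 = vector.shuffle %a, %a [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
  %hi0 = arith.andi %s1, %mask : vector<4xi8>
  %hi = arith.shli %hi0, %shl : vector<4xi8>
  %1 = arith.ori %lo, %hi : vector<4xi8>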
---
 .../Vector/TransformOps/VectorTransformOps.td |  13 +
 .../Vector/Transforms/VectorRewritePatterns.h |  15 +-
 mlir/include/mlir/IR/BuiltinTypes.h           |  10 +
 .../TransformOps/VectorTransformOps.cpp       |   5 +
 .../Transforms/VectorEmulateNarrowType.cpp    | 237 +++++++++++++++++-
 .../Vector/vector-rewrite-narrow-types.mlir   | 157 ++++++++++++
 .../Vector/CPU/test-rewrite-narrow-types.mlir | 155 ++++++++++++
 7 files changed, 587 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 9e718a0c80bbf3b..133ee4e030f01e5 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -292,6 +292,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.rewrite_narrow_types",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that patterns rewriting vector operations over narrow types into
+    ops over wider types should be applied.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
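+
+    A minimal illustrative use, where `%f` is a handle to e.g. a `func.func`
+    payload op (this mirrors the tests added in this revision):
+
+      transform.apply_patterns to %f {
+        transform.apply_patterns.vector.rewrite_narrow_types
+      } : !transform.any_op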
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.split_transfer_full_partial",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c644090d8c78cd0..8652fc7f5e5c640 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -24,6 +24,7 @@ class RewritePatternSet;
 
 namespace arith {
 class NarrowTypeEmulationConverter;
+class TruncIOp;
 } // namespace arith
 
 namespace vector {
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 
 /// Patterns that remove redundant vector broadcasts.
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+                                         PatternBenefit benefit = 1);
 
 /// Populate `patterns` with the following patterns.
 ///
@@ -301,6 +302,18 @@ void populateVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns);
 
+/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
+/// vector operations comprising `shuffle` and `bitwise` ops.
+FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
+                                        vector::BitCastOp bitCastOp,
+                                        arith::TruncIOp truncOp,
+                                        vector::BroadcastOp maybeBroadcastOp);
+
+/// Appends patterns for rewriting vector operations over narrow types with
+/// ops over wider types.
+void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..9df5548cd5d939c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,6 +357,16 @@ class VectorType::Builder {
     return *this;
   }
 
+  /// Set the dimension of the shape at position `pos` to `val`.
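+  /// For example (illustrative),
+  /// `VectorType::Builder(vt).setDim(vt.getRank() - 1, 8)` yields the type of
+  /// `vt` with its most minor dimension resized to 8.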
+  Builder &setDim(unsigned pos, int64_t val) {
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
+    storage[pos] = val;
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
   operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index b388deaa46a7917..37127ea70f1e5af 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorNarrowTypeRewritePatterns(patterns);
+}
+
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b2b7bfc5e4437c1..eaf0b8849e36ebf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,7 +7,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -15,13 +14,23 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
-#include <cassert>
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "vector-narrow-type-emulation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -155,6 +164,221 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteBitCastOfTruncI
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Helper struct to keep track of the provenance of a contiguous set of bits
+/// in a source vector.
+struct SourceElementRange {
+  int64_t sourceElement;
+  int64_t sourceBitBegin;
+  int64_t sourceBitEnd;
+};
+
+struct SourceElementRangeList : public SmallVector<SourceElementRange> {
+  /// Given the index of a SourceElementRange in the SourceElementRangeList,
+  /// compute the number of bits that need to be shifted to the left to get the
+  /// bits in their final location. This shift amount is simply the sum of the
+  /// sizes of the ranges *before* `shuffleIdx` (i.e. the bits of
+  /// `shuffleIdx = 0` are always the LSBs, the bits of `shuffleIdx = 1` come
+  /// next, etc.).
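+  /// For example, for the row { 1: b@[3..5) } { 2: b@[0..5) } { 3: b@[0..1) },
+  /// the left shift amounts are 0, 2 and 7 respectively.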
+  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
+    int64_t res = 0;
+    for (int64_t i = 0; i < shuffleIdx; ++i)
+      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
+    return res;
+  }
+};
+
+/// Helper struct to enumerate the source elements and bit ranges that are
+/// involved in a bitcast operation.
+/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
+/// any 1-D vector shape and any source/target bitwidths.
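+/// For example, bitcasting vector<4xi8> to vector<2xi16> enumerates to:
+///   result element 0: { 0: b@[0..8) } { 1: b@[0..8) }
+///   result element 1: { 2: b@[0..8) } { 3: b@[0..8) }
+/// i.e. each i16 result element takes its bits from 2 full i8 source elements.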
+struct BitCastBitsEnumerator {
+  BitCastBitsEnumerator(VectorType sourceVectorType,
+                        VectorType targetVectorType);
+
+  int64_t getMaxNumberOfEntries() {
+    int64_t numVectors = 0;
+    for (const auto &l : sourceElementRanges)
+      numVectors = std::max(numVectors, (int64_t)l.size());
+    return numVectors;
+  }
+
+  VectorType sourceVectorType;
+  VectorType targetVectorType;
+  SmallVector<SourceElementRangeList> sourceElementRanges;
+};
+
+} // namespace
+
+static raw_ostream &operator<<(raw_ostream &os,
+                               const SmallVector<SourceElementRangeList> &vec) {
+  for (const auto &l : vec) {
+    for (auto it : llvm::enumerate(l)) {
+      os << "{ " << it.value().sourceElement << ": b@["
+         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
+         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
+    }
+    os << "\n";
+  }
+  return os;
+}
+
+BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
+                                             VectorType targetVectorType)
+    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+
+  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
+         "requires -D non-scalable vector type");
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
+  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
+  LDBG("targetVectorType: " << targetVectorType);
+
+  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
+  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
+         "source and target bitwidths must match");
+
+  // Prepopulate one source element range per target element.
+  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
+  for (int64_t resultBit = 0; resultBit < bitwidth;) {
+    int64_t resultElement = resultBit / targetBitWidth;
+    int64_t resultBitInElement = resultBit % targetBitWidth;
+    int64_t sourceElement = resultBit / sourceBitWidth;
+    int64_t sourceBitInElement = resultBit % sourceBitWidth;
+    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
+                            targetBitWidth - resultBitInElement);
+    sourceElementRanges[resultElement].push_back(
+        {sourceElement, sourceBitInElement, sourceBitInElement + step});
+    resultBit += step;
+  }
+}
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information instead of leaving LLVM to recover the
+/// pattern with peephole optimizations.
+struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    // The source must be a trunc op.
+    auto truncOp =
+        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
+    if (!truncOp)
+      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
+
+    VectorType targetVectorType = bitCastOp.getResultVectorType();
+    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
+    // TODO: consider relaxing this restriction in the future if we find ways to
+    // really work with subbyte elements across the MLIR/LLVM boundary.
+    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+    if (resultBitwidth % 8 != 0)
+      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
+    LDBG("\n" << be.sourceElementRanges);
+
+    Value initialValue = truncOp.getIn();
+    auto initialVectorType = initialValue.getType().cast<VectorType>();
+    auto initialElementType = initialVectorType.getElementType();
+    auto initialElementBitWidth = initialElementType.getIntOrFloatBitWidth();
+
+    // BitCastBitsEnumerator encodes for each element of the target vector the
+    // provenance of the bits in the source vector. We can "transpose" this
+    // information to build a sequence of shuffles and bitwise ops that will
+    // produce the desired result.
+    // The algorithm proceeds as follows:
+    //   1. there are as many shuffles as max entries in BitCastBitsEnumerator;
+    //   2. for each shuffle:
+    //     a. collect the source elements that participate in this shuffle, one
+    //     source element per result element. If a row has no entry at
+    //     `shuffleIdx`, take element 0.
+    //     b. collect the bitrange of each source element to use as a mask. If
+    //     a row has no entry at `shuffleIdx`, take an all-zero mask.
+    //     c. compute the number of bits to shift right to align the source
+    //     bitrange at position 0. This is exactly the low end of the bitrange.
+    //     d. compute the number of bits to shift left to align the bits at the
+    //     desired position in the result element.
+    // Then build the sequence:
+    //   (shuffle -> and -> shiftright -> shiftleft -> or) to iteratively update
+    // the result vector (i.e. the "shiftright -> shiftleft -> or" part) with
+    // the bits extracted from the source vector (i.e. the "shuffle -> and"
+    // part).
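+    // For example, in the i16 -> i3 -> i16 test added in this revision, the
+    // first row of the enumerator is:
+    //   { 0: b@[0..3) } { 1: b@[0..3) } { 2: b@[0..3) } { 3: b@[0..3) }
+    //   { 4: b@[0..3) } { 5: b@[0..1) }
+    // i.e. result element 0 gathers the 3 LSBs of source elements 0..4 and the
+    // LSB of source element 5; with at most 6 entries per row this yields 6
+    // shuffle/and/shift/or rounds.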
+    Value res;
+    for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
+         ++shuffleIdx) {
+      SmallVector<int64_t> shuffles;
+      SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+      for (auto &l : be.sourceElementRanges) {
+        int64_t sourceElement =
+            (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceElement : 0;
+        shuffles.push_back(sourceElement);
+
+        int64_t bitLo =
+            (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitBegin : 0;
+        int64_t bitHi =
+            (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitEnd : 0;
+        IntegerAttr mask = IntegerAttr::get(
+            rewriter.getIntegerType(initialElementBitWidth),
+            llvm::APInt::getBitsSet(initialElementBitWidth, bitLo, bitHi));
+        masks.push_back(mask);
+
+        int64_t shiftRight = bitLo;
+        shiftRightAmounts.push_back(IntegerAttr::get(
+            rewriter.getIntegerType(initialElementBitWidth), shiftRight));
+
+        int64_t shiftLeft = l.computeLeftShiftAmount(shuffleIdx);
+        shiftLeftAmounts.push_back(IntegerAttr::get(
+            rewriter.getIntegerType(initialElementBitWidth), shiftLeft));
+      }
+
+      // Shuffle the source elements that hold this bit range into position,
+      // mask out the relevant bits and shift them into place.
+      auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+          bitCastOp.getLoc(), initialValue, initialValue, shuffles);
+
+      VectorType vt = VectorType::Builder(initialVectorType)
+                          .setDim(initialVectorType.getRank() - 1, masks.size());
+      auto constOp = rewriter.create<arith::ConstantOp>(
+          bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
+      Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
+                                                      shuffleOp, constOp);
+
+      auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
+      Value shiftedRight = rewriter.create<arith::ShRUIOp>(
+          bitCastOp.getLoc(), andValue, shiftRightConstantOp);
+
+      auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
+      Value shiftedLeft = rewriter.create<arith::ShLIOp>(
+          bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
+
+      res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
+                                                shiftedLeft)
+                : shiftedLeft;
+    }
+
+    bool narrowing = resultBitwidth <= initialElementBitWidth;
+    if (narrowing) {
+      if (res.getType() == bitCastOp.getResultVectorType()) {
+        // The result is already of the requested type: no final trunci needed.
+        rewriter.replaceOp(bitCastOp, res);
+      } else {
+        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
+            bitCastOp, bitCastOp.getResultVectorType(), res);
+      }
+    } else {
+      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
+          bitCastOp, bitCastOp.getResultVectorType(), res);
+    }
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
@@ -167,3 +391,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
   patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
 }
+
+void vector::populateVectorNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RewriteBitCastOfTruncI>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
new file mode 100644
index 000000000000000..ba6efde40f36c2b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+/// Note: Inspect generated assembly and llvm-mca stats:
+/// ====================================================
+/// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 --function-sections -filetype=obj > /tmp/a.out; objdump -d --disassemble=f1 --no-addresses --no-show-raw-insn -M att /tmp/a.out | ./build/bin/llvm-mca -mcpu=skylake-avx512
+
+// CHECK-LABEL: func.func @f1(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<32xi64>) -> vector<20xi8>
+func.func @f1(%a: vector<32xi64>) -> vector<20xi8> {
+  /// Rewriting this standalone pattern is about 2x faster on skylake-avx512 according to llvm-mca.
+  /// Benefit further increases when mixed with other compute ops.
+  ///
+  /// The 20x8 bits of the result take their values from the following bits of
+  /// the source vector:
+  // { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 }
+  // { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 }
+  // { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 }
+  // { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 }
+  // { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 }
+  // { 8: b@[0..5) lshl: 0 } { 9: b@[0..3) lshl: 5 }
+  // { 9: b@[3..5) lshl: 0 } { 10: b@[0..5) lshl: 2 } { 11: b@[0..1) lshl: 7 }
+  // { 11: b@[1..5) lshl: 0 } { 12: b@[0..4) lshl: 4 }                      
+  // { 12: b@[4..5) lshl: 0 } { 13: b@[0..5) lshl: 1 } { 14: b@[0..2) lshl: 6 }
+  // { 14: b@[2..5) lshl: 0 } { 15: b@[0..5) lshl: 3 }                      
+  // { 16: b@[0..5) lshl: 0 } { 17: b@[0..3) lshl: 5 }                      
+  // { 17: b@[3..5) lshl: 0 } { 18: b@[0..5) lshl: 2 } { 19: b@[0..1) lshl: 7 }
+  // { 19: b@[1..5) lshl: 0 } { 20: b@[0..4) lshl: 4 }                      
+  // { 20: b@[4..5) lshl: 0 } { 21: b@[0..5) lshl: 1 } { 22: b@[0..2) lshl: 6 }
+  // { 22: b@[2..5) lshl: 0 } { 23: b@[0..5) lshl: 3 }                      
+  // { 24: b@[0..5) lshl: 0 } { 25: b@[0..3) lshl: 5 }                      
+  // { 25: b@[3..5) lshl: 0 } { 26: b@[0..5) lshl: 2 } { 27: b@[0..1) lshl: 7 }
+  // { 27: b@[1..5) lshl: 0 } { 28: b@[0..4) lshl: 4 }                      
+  // { 28: b@[4..5) lshl: 0 } { 29: b@[0..5) lshl: 1 } { 30: b@[0..2) lshl: 6 }
+  // { 30: b@[2..5) lshl: 0 } { 31: b@[0..5) lshl: 3 }  
+  /// This results in 3 shuffles + 1 shr + 2 shl + 3 and + 2 or.
+  /// The third vector is empty for positions 0, 2, 4, 5, 7, 9, 10, 12, 14, 15,
+  /// 17 and 19 (i.e. there are only 2 entries in that row).
+  /// 
+  ///                             0: b@[0..5), 1: b@[3..5), etc
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28]> : vector<20xi64>
+  ///                             1: b@[0..3), 2: b@[0..5), etc
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31]> :  vector<20xi64>
+  ///                             empty, 3: b@[0..1), empty etc
+  // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0]> : vector<20xi64>
+  // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2]> : vector<20xi64>
+  // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3]> : vector<20xi64>
+  // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8]> : vector<20xi64>
+  //
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 3, 4, 6, 8, 9, 11, 12, 14, 16, 17, 19, 20, 22, 24, 25, 27, 28, 30] : vector<32xi64>, vector<32xi64>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<20xi64>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<20xi64>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 2, 4, 5, 7, 9, 10, 12, 13, 15, 17, 18, 20, 21, 23, 25, 26, 28, 29, 31] : vector<32xi64>, vector<32xi64>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<20xi64>
+  // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<20xi64>
+  // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<20xi64>
+  // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [0, 3, 0, 6, 0, 0, 11, 0, 14, 0, 0, 19, 0, 22, 0, 0, 27, 0, 30, 0] : vector<32xi64>, vector<32xi64>
+  // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK2]] : vector<20xi64>
+  // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<20xi64>
+  // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<20xi64>
+  // CHECK: %[[TR:.*]] = arith.trunci %[[O2]] : vector<20xi64> to vector<20xi8>
+  // CHECK-NOT: bitcast
+  %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
+  %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+  return %1 : vector<20xi8>
+}
+
+// CHECK-LABEL: func.func @f2(
+//  CHECK-SAME:   %[[A:[0-9a-z]*]]: vector<16xi16>) -> vector<3xi16>
+func.func @f2(%a: vector<16xi16>) -> vector<3xi16> {
+  /// Rewriting this standalone pattern is about 1.8x faster on skylake-avx512 according to llvm-mca.
+  /// Benefit further increases when mixed with other compute ops.
+  ///
+  // { 0: b@[0..3) lshl: 0 } { 1: b@[0..3) lshl: 3 } { 2: b@[0..3) lshl: 6 } { 3: b@[0..3) lshl: 9 } { 4: b@[0..3) lshl: 12 } { 5: b@[0..1) lshl: 15 } 
+  // { 5: b@[1..3) lshl: 0 } { 6: b@[0..3) lshl: 2 } { 7: b@[0..3) lshl: 5 } { 8: b@[0..3) lshl: 8 } { 9: b@[0..3) lshl: 11 } { 10: b@[0..2) lshl: 14 } 
+  // { 10: b@[2..3) lshl: 0 } { 11: b@[0..3) lshl: 1 } { 12: b@[0..3) lshl: 4 } { 13: b@[0..3) lshl: 7 } { 14: b@[0..3) lshl: 10 } { 15: b@[0..3) lshl: 13 }
+  ///                                             0: b@[0..3), 5: b@[1..3), 10: b@[2..3)
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[7, 6, 4]> : vector<3xi16>
+  ///                                             1: b@[0..3), 6: b@[0..3), 11: b@[0..3)
+  ///                                             ...
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<7> : vector<3xi16>
+  ///                                             5: b@[0..1), 10: b@[0..2), 15: b@[0..3)
+  // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[1, 3, 7]> : vector<3xi16>
+  // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi16>
+  // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[3, 2, 1]> : vector<3xi16>
+  // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[6, 5, 4]> : vector<3xi16>
+  // CHECK-DAG: %[[SHL3_CST:.*]] = arith.constant dense<[9, 8, 7]> : vector<3xi16>
+  // CHECK-DAG: %[[SHL4_CST:.*]] = arith.constant dense<[12, 11, 10]> : vector<3xi16>
+  // CHECK-DAG: %[[SHL5_CST:.*]] = arith.constant dense<[15, 14, 13]> : vector<3xi16>
+
+  //
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 5, 10] : vector<16xi16>, vector<16xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<3xi16>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<3xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 6, 11] : vector<16xi16>, vector<16xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<3xi16>
+  // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<3xi16>
+  // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<3xi16>
+  // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [2, 7, 12] : vector<16xi16>, vector<16xi16>
+  // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK1]] : vector<3xi16>
+  // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<3xi16>
+  // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<3xi16>
+  // CHECK: %[[V3:.*]] = vector.shuffle %[[A]], %[[A]] [3, 8, 13] : vector<16xi16>, vector<16xi16>
+  // CHECK: %[[A3:.*]] = arith.andi %[[V3]], %[[MASK1]] : vector<3xi16>
+  // CHECK: %[[SHL3:.*]] = arith.shli %[[A3]], %[[SHL3_CST]] : vector<3xi16>
+  // CHECK: %[[O3:.*]] = arith.ori %[[O2]], %[[SHL3]]  : vector<3xi16>
+  // CHECK: %[[V4:.*]] = vector.shuffle %[[A]], %[[A]] [4, 9, 14] : vector<16xi16>, vector<16xi16>
+  // CHECK: %[[A4:.*]] = arith.andi %[[V4]], %[[MASK1]] : vector<3xi16>
+  // CHECK: %[[SHL4:.*]] = arith.shli %[[A4]], %[[SHL4_CST]] : vector<3xi16>
+  // CHECK: %[[O4:.*]] = arith.ori %[[O3]], %[[SHL4]]  : vector<3xi16>
+  // CHECK: %[[V5:.*]] = vector.shuffle %[[A]], %[[A]] [5, 10, 15] : vector<16xi16>, vector<16xi16>
+  // CHECK: %[[A5:.*]] = arith.andi %[[V5]], %[[MASK2]] : vector<3xi16>
+  // CHECK: %[[SHL5:.*]] = arith.shli %[[A5]], %[[SHL5_CST]] : vector<3xi16>
+  // CHECK: %[[O5:.*]] = arith.ori %[[O4]], %[[SHL5]]  : vector<3xi16>
+  /// No trunci needed as the result is already in i16.
+  // CHECK-NOT: arith.trunci
+  // CHECK-NOT: bitcast
+  %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
+  %1 = vector.bitcast %0 : vector<16xi3> to vector<3xi16>
+  return %1 : vector<3xi16>
+}
+
+/// This pattern requires an extui 16 -> 32 and not a trunci.
+// CHECK-LABEL: func.func @f3(
+func.func @f3(%a: vector<16xi16>) -> vector<2xi32> {
+  /// Rewriting this standalone pattern is about 25x faster on skylake-avx512 according to llvm-mca.
+  /// Benefit further increases when mixed with other compute ops.
+  ///
+  // CHECK-NOT: arith.trunci
+  // CHECK-NOT: bitcast
+  //     CHECK: arith.extui
+  %0 = arith.trunci %a : vector<16xi16> to vector<16xi4>
+  %1 = vector.bitcast %0 : vector<16xi4> to vector<2xi32>
+  return %1 : vector<2xi32>
+}
+
+/// This pattern is not rewritten as the result element type i6 is not a multiple of 8 bits.
+// CHECK-LABEL: func.func @f4(
+func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
+  // CHECK: trunci
+  // CHECK: bitcast
+  // CHECK-NOT: shuffle
+  // CHECK-NOT: andi
+  // CHECK-NOT: ori
+  %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
+  %1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6>
+  return %1 : vector<8xi6>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.rewrite_narrow_types
+  } : !transform.any_op
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
new file mode 100644
index 000000000000000..44c608726f13530
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -0,0 +1,155 @@
+/// Run once without applying the pattern and check the source of truth.
+// RUN: mlir-opt %s --test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+/// Run once with the pattern and compare.
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @print_as_i1_16xi5(%v : vector<16xi5>) {
+  %bitsi80 = vector.bitcast %v : vector<16xi5> to vector<80xi1>
+  vector.print %bitsi80 : vector<80xi1>
+  return
+}
+
+func.func @print_as_i1_10xi8(%v : vector<10xi8>) {
+  %bitsi80 = vector.bitcast %v : vector<10xi8> to vector<80xi1>
+  vector.print %bitsi80 : vector<80xi1>
+  return
+}
+
+func.func @f(%v: vector<16xi16>) {
+  %trunc = arith.trunci %v : vector<16xi16> to vector<16xi5>
+  func.call @print_as_i1_16xi5(%trunc) : (vector<16xi5>) -> ()
+  //      CHECK: ( 
+  // CHECK-SAME: 1, 1, 1, 1, 1,
+  // CHECK-SAME: 0, 1, 1, 1, 1,
+  // CHECK-SAME: 1, 0, 1, 1, 1,
+  // CHECK-SAME: 0, 0, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 1, 1,
+  // CHECK-SAME: 0, 1, 0, 1, 1,
+  // CHECK-SAME: 1, 0, 0, 1, 1,
+  // CHECK-SAME: 0, 0, 0, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 0, 1,
+  // CHECK-SAME: 0, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 0, 1, 0, 1,
+  // CHECK-SAME: 0, 0, 1, 0, 1,
+  // CHECK-SAME: 1, 1, 0, 0, 1,
+  // CHECK-SAME: 0, 1, 0, 0, 1,
+  // CHECK-SAME: 1, 0, 0, 0, 1,
+  // CHECK-SAME: 0, 0, 0, 0, 1 )
+
+  %bitcast = vector.bitcast %trunc : vector<16xi5> to vector<10xi8>
+  func.call @print_as_i1_10xi8(%bitcast) : (vector<10xi8>) -> ()
+  //      CHECK: ( 
+  // CHECK-SAME: 1, 1, 1, 1, 1, 0, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 0, 1, 1, 1, 0,
+  // CHECK-SAME: 0, 1, 1, 1, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 0, 1, 0, 1, 1, 1, 0,
+  // CHECK-SAME: 0, 1, 1, 0, 0, 0, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 0, 1, 0, 1, 1,
+  // CHECK-SAME: 0, 1, 1, 0, 1, 0, 1, 0,
+  // CHECK-SAME: 0, 1, 0, 1, 1, 1, 0, 0,
+  // CHECK-SAME: 1, 0, 1, 0, 0, 1, 1, 0,
+  // CHECK-SAME: 0, 0, 1, 0, 0, 0, 0, 1 )
+
+  return
+}
+
+func.func @print_as_i1_8xi3(%v : vector<8xi3>) {
+  %bitsi24 = vector.bitcast %v : vector<8xi3> to vector<24xi1>
+  vector.print %bitsi24 : vector<24xi1>
+  return
+}
+
+func.func @print_as_i1_3xi8(%v : vector<3xi8>) {
+  %bitsi24 = vector.bitcast %v : vector<3xi8> to vector<24xi1>
+  vector.print %bitsi24 : vector<24xi1>
+  return
+}
+
+func.func @f2(%v: vector<8xi32>) {
+  %trunc = arith.trunci %v : vector<8xi32> to vector<8xi3>
+  func.call @print_as_i1_8xi3(%trunc) : (vector<8xi3>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 1, 1, 1,
+  // CHECK-SAME: 0, 1, 1,
+  // CHECK-SAME: 1, 0, 1,
+  // CHECK-SAME: 0, 0, 1,
+  // CHECK-SAME: 1, 1, 0,
+  // CHECK-SAME: 0, 1, 0,
+  // CHECK-SAME: 1, 0, 0,
+  // CHECK-SAME: 0, 0, 0 )
+
+  %bitcast = vector.bitcast %trunc : vector<8xi3> to vector<3xi8>
+  func.call @print_as_i1_3xi8(%bitcast) : (vector<3xi8>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 1, 1, 1, 0, 1, 1, 1, 0,
+  // CHECK-SAME: 1, 0, 0, 1, 1, 1, 0, 0,
+  // CHECK-SAME: 1, 0, 1, 0, 0, 0, 0, 0 )
+
+  return
+}
+
+func.func @print_as_i1_2xi24(%v : vector<2xi24>) {
+  %bitsi48 = vector.bitcast %v : vector<2xi24> to vector<48xi1>
+  vector.print %bitsi48 : vector<48xi1>
+  return
+}
+
+func.func @print_as_i1_3xi16(%v : vector<3xi16>) {
+  %bitsi48 = vector.bitcast %v : vector<3xi16> to vector<48xi1>
+  vector.print %bitsi48 : vector<48xi1>
+  return
+}
+
+func.func @f3(%v: vector<2xi48>) {
+  %trunc = arith.trunci %v : vector<2xi48> to vector<2xi24>
+  func.call @print_as_i1_2xi24(%trunc) : (vector<2xi24>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+  // CHECK-SAME: 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0 )
+
+  %bitcast = vector.bitcast %trunc : vector<2xi24> to vector<3xi16>
+  func.call @print_as_i1_3xi16(%bitcast) : (vector<3xi16>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0,
+  // CHECK-SAME: 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0 )
+
+  return
+}
+
+func.func @entry() {
+  %v = arith.constant dense<[
+    0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
+    0xfff7, 0xfff6, 0xfff5, 0xfff4, 0xfff3, 0xfff2, 0xfff1, 0xfff0
+  ]> : vector<16xi16>
+  func.call @f(%v) : (vector<16xi16>) -> ()
+
+  %v2 = arith.constant dense<[
+    0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8
+  ]> : vector<8xi32>
+  func.call @f2(%v2) : (vector<8xi32>) -> ()
+
+  %v3 = arith.constant dense<[
+    0xf345aeffffff, 0xffff015f345a
+  ]> : vector<2xi48>
+  func.call @f3(%v3) : (vector<2xi48>) -> ()
+
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.rewrite_narrow_types
+  } : !transform.any_op
+}
