[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #65774)
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Sep 18 05:50:36 PDT 2023
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/65774
>From 5132376dc1088dcab4d3ba4008e73f469dd840fa Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH 1/3] [mlir][Vector] Add a rewrite pattern for better
low-precision bitcast(trunci) expansion
This revision adds a rewrite for sequences of vector `bitcast(trunci)` to use a more efficient sequence of
vector operations comprising `shuffle` and `bitwise` ops.
Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.
The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance
in the pre-trunci vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori`
followed by an optional final `trunci`/`extui`.
The rewrite currently only applies to 1-D non-scalable vectors and bails out if the bitwidth of the final vector element type is
not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting element type
does not occupy a whole number of bytes, further complexities arise that are not improved by this pattern:
the heavy lifting still needs to be done by LLVM.
---
.../Vector/TransformOps/VectorTransformOps.td | 13 +
.../Vector/Transforms/VectorRewritePatterns.h | 15 +-
mlir/include/mlir/IR/BuiltinTypes.h | 10 +
.../TransformOps/VectorTransformOps.cpp | 5 +
.../Transforms/VectorEmulateNarrowType.cpp | 237 +++++++++++++++++-
.../Vector/vector-rewrite-narrow-types.mlir | 157 ++++++++++++
.../Vector/CPU/test-rewrite-narrow-types.mlir | 155 ++++++++++++
7 files changed, 587 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 9e718a0c80bbf3b..133ee4e030f01e5 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -292,6 +292,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
}];
}
+// Transform op whose `populatePatterns` hook adds the patterns from
+// vector::populateVectorNarrowTypeRewritePatterns (bitcast(trunci) rewrite).
+def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.rewrite_narrow_types",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector narrow rewrite operations should be applied.
+
+ This is usually a late step that is run after bufferization as part of the
+ process of lowering to e.g. LLVM or NVVM.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.split_transfer_full_partial",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c644090d8c78cd0..8652fc7f5e5c640 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -24,6 +24,7 @@ class RewritePatternSet;
namespace arith {
class NarrowTypeEmulationConverter;
+class TruncIOp;
} // namespace arith
namespace vector {
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// Patterns that remove redundant vector broadcasts.
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 1);
/// Populate `patterns` with the following patterns.
///
@@ -301,6 +302,18 @@ void populateVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
+/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
+/// vector operations comprising `shuffle` and `bitwise` ops.
+/// NOTE(review): this patch adds no definition for this function (the rewrite
+/// is only reachable through populateVectorNarrowTypeRewritePatterns), and the
+/// role of `maybeBroadcastOp` is undocumented. Define it or drop the
+/// declaration; as is, any caller would fail to link.
+FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
+ vector::BitCastOp bitCastOp,
+ arith::TruncIOp truncOp,
+ vector::BroadcastOp maybeBroadcastOp);
+
+/// Appends patterns for rewriting vector operations over narrow types with
+/// ops over wider types.
+/// In this patch that is a single pattern rewriting `vector.bitcast` of an
+/// `arith.trunci` into shuffles and bitwise ops, registered with `benefit`.
+void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..9df5548cd5d939c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,6 +357,16 @@ class VectorType::Builder {
return *this;
}
+  /// Overwrite the size of dimension `pos` with `val`.
+  /// If the builder does not own a copy of the shape yet (`storage` is empty),
+  /// the current shape is copied into `storage` first; `shape` is then
+  /// re-pointed at that owned copy.
+  Builder &setDim(unsigned pos, int64_t val) {
+    if (storage.empty())
+      storage.assign(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
+    storage[pos] = val;
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
operator VectorType() {
return VectorType::get(shape, elementType, scalableDims);
}
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index b388deaa46a7917..37127ea70f1e5af 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
}
}
+void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // Delegate to the vector dialect's population entry point, keeping its
+  // default pattern benefit.
+  vector::populateVectorNarrowTypeRewritePatterns(patterns);
+}
+
void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b2b7bfc5e4437c1..eaf0b8849e36ebf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,7 +7,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -15,13 +14,23 @@
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
-#include <cassert>
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
using namespace mlir;
+#define DEBUG_TYPE "vector-narrow-type-emulation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
namespace {
//===----------------------------------------------------------------------===//
@@ -155,6 +164,221 @@ struct ConvertVectorTransferRead final
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// RewriteBitCastOfTruncI
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Helper struct to keep track of the provenance of a contiguous set of bits
+/// in a source vector.
+struct SourceElementRange {
+  /// Index of the source vector element the bits are taken from.
+  int64_t sourceElement;
+  /// Half-open range [sourceBitBegin, sourceBitEnd) of bits within that
+  /// source element.
+  int64_t sourceBitBegin;
+  int64_t sourceBitEnd;
+};
+
+/// The source bit ranges that make up one target element, ordered from the
+/// least significant to the most significant bits of that target element.
+struct SourceElementRangeList : public SmallVector<SourceElementRange> {
+  /// Given the index of a SourceElementRange in the SourceElementRangeList,
+  /// compute the number of bits that need to be shifted to the left to get the
+  /// bits in their final location. This shift amount is simply the sum of the
+  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
+  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
+  /// `shuffleIdx == size()` is allowed and yields the total number of bits.
+  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
+    assert(shuffleIdx >= 0 && shuffleIdx <= static_cast<int64_t>(size()) &&
+           "shuffleIdx out of range");
+    int64_t res = 0;
+    for (int64_t i = 0; i < shuffleIdx; ++i)
+      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
+    return res;
+  }
+};
+
+/// Helper struct to enumerate the source elements and bit ranges that are
+/// involved in a bitcast operation.
+/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
+/// any 1-D vector shape and any source/target bitwidths.
+/// `sourceElementRanges[j]` lists the source bit ranges that make up target
+/// element `j`, from its LSBs to its MSBs.
+struct BitCastBitsEnumerator {
+  BitCastBitsEnumerator(VectorType sourceVectorType,
+                        VectorType targetVectorType);
+
+  /// Return the largest number of source ranges contributing to a single
+  /// target element, i.e. the number of shuffles the rewrite needs.
+  int64_t getMaxNumberOfEntries() const {
+    int64_t numVectors = 0;
+    for (const SourceElementRangeList &l : sourceElementRanges)
+      numVectors = std::max(numVectors, static_cast<int64_t>(l.size()));
+    return numVectors;
+  }
+
+  VectorType sourceVectorType;
+  VectorType targetVectorType;
+  SmallVector<SourceElementRangeList> sourceElementRanges;
+};
+
+} // namespace
+
+/// Debug printer: one line per target element, listing every contributing
+/// source range as `{ srcIdx: b@[begin..end) lshl: shift }`.
+static raw_ostream &operator<<(raw_ostream &os,
+                               const SmallVector<SourceElementRangeList> &vec) {
+  for (const SourceElementRangeList &ranges : vec) {
+    for (size_t idx = 0, e = ranges.size(); idx < e; ++idx) {
+      const SourceElementRange &range = ranges[idx];
+      os << "{ " << range.sourceElement << ": b@[" << range.sourceBitBegin
+         << ".." << range.sourceBitEnd
+         << ") lshl: " << ranges.computeLeftShiftAmount(idx) << " } ";
+    }
+    os << "\n";
+  }
+  return os;
+}
+
+/// Build, for every element of the 1-D `targetVectorType`, the list of source
+/// bit ranges it is made of. Bits are enumerated over the flattened vector,
+/// element 0 first and low bits first; a new range starts whenever a source or
+/// a target element boundary is crossed.
+BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
+                                             VectorType targetVectorType)
+    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+
+  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
+         "requires 1-D non-scalable vector type");
+  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
+         "requires 1-D non-scalable vector type");
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
+  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
+  LDBG("targetVectorType: " << targetVectorType);
+
+  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
+  // Only used by the assert below; avoid unused-variable warnings in NDEBUG.
+  (void)mostMinorSourceDim;
+  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
+         "source and target bitwidths must match");
+
+  // Prepopulate one source element range list per target element.
+  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
+  for (int64_t resultBit = 0; resultBit < bitwidth;) {
+    int64_t resultElement = resultBit / targetBitWidth;
+    int64_t resultBitInElement = resultBit % targetBitWidth;
+    int64_t sourceElement = resultBit / sourceBitWidth;
+    int64_t sourceBitInElement = resultBit % sourceBitWidth;
+    // Advance to the nearest source or target element boundary.
+    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
+                            targetBitWidth - resultBitInElement);
+    sourceElementRanges[resultElement].push_back(
+        {sourceElement, sourceBitInElement, sourceBitInElement + step});
+    resultBit += step;
+  }
+}
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+///
+/// All the bit manipulation is carried out in an integer type at least as wide
+/// as the bitcast result element type:
+///   - if the result elements are wider than the trunci input elements, the
+///     trunci input is zero-extended first, because `arith.shli` by an amount
+///     >= the bitwidth produces poison;
+///   - if they are narrower, a final `arith.trunci` is emitted;
+///   - if they have the same width, the computed value is used as is
+///     (`arith.trunci` requires a strictly narrower result type).
+struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    // The source must be a trunc op.
+    auto truncOp = bitCastOp.getSource().getDefiningOp<arith::TruncIOp>();
+    if (!truncOp)
+      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
+
+    VectorType targetVectorType = bitCastOp.getResultVectorType();
+    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
+    // TODO: consider relaxing this restriction in the future if we find ways to
+    // really work with subbyte elements across the MLIR/LLVM boundary.
+    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+    if (resultBitwidth % 8 != 0)
+      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
+    LDBG("\n" << be.sourceElementRanges);
+
+    Location loc = bitCastOp.getLoc();
+    Value initialValue = truncOp.getIn();
+    auto initialVectorType = cast<VectorType>(initialValue.getType());
+    unsigned workBitWidth =
+        initialVectorType.getElementType().getIntOrFloatBitWidth();
+
+    // Widening case, e.g. vector<16xi16> -> vector<16xi4> -> vector<2xi32>:
+    // left shift amounts reach 28, which is poison on i16. Zero-extend the
+    // trunci input to the result element type first; this preserves the low
+    // bits, which are the only ones the masks below ever select.
+    if (resultBitwidth > workBitWidth) {
+      workBitWidth = static_cast<unsigned>(resultBitwidth);
+      initialVectorType = VectorType::get(initialVectorType.getShape(),
+                                          rewriter.getIntegerType(workBitWidth));
+      initialValue =
+          rewriter.create<arith::ExtUIOp>(loc, initialVectorType, initialValue);
+    }
+    IntegerType workElementType = rewriter.getIntegerType(workBitWidth);
+    // Every intermediate vector has one element per target element.
+    VectorType vt = VectorType::Builder(initialVectorType)
+                        .setDim(initialVectorType.getRank() - 1,
+                                be.sourceElementRanges.size());
+
+    // BitCastBitsEnumerator encodes for each element of the target vector the
+    // provenance of the bits in the source vector. We can "transpose" this
+    // information to build a sequence of shuffles and bitwise ops that will
+    // produce the desired result.
+    // The algorithm proceeds as follows:
+    //   1. there are as many shuffles as max entries in BitCastBitsEnumerator
+    //   2. for each shuffle:
+    //     a. collect the source element that participates in this shuffle,
+    //        one per target element of the shuffle. If overflow, take 0.
+    //     b. the bitrange in the source element as a mask. If overflow, take 0.
+    //     c. the number of bits to shift right to align the source bitrange at
+    //        position 0. This is exactly the low end of the bitrange.
+    //     d. number of bits to shift left to align to the desired position in
+    //        the result element vector.
+    // Then build the sequence:
+    //   (shuffle -> and -> shiftright -> shiftleft -> or) to iteratively update
+    //   the result vector (i.e. the "shiftright -> shiftleft -> or" part) with
+    //   the bits extracted from the source vector (i.e. the "shuffle -> and"
+    //   part).
+    Value res;
+    for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
+         ++shuffleIdx) {
+      SmallVector<int64_t> shuffles;
+      SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+      for (const SourceElementRangeList &l : be.sourceElementRanges) {
+        // Target elements with no range at `shuffleIdx` still need an entry:
+        // take source element 0 with an empty mask so they contribute nothing.
+        bool hasEntry = shuffleIdx < static_cast<int64_t>(l.size());
+        int64_t sourceElement = hasEntry ? l[shuffleIdx].sourceElement : 0;
+        int64_t bitLo = hasEntry ? l[shuffleIdx].sourceBitBegin : 0;
+        int64_t bitHi = hasEntry ? l[shuffleIdx].sourceBitEnd : 0;
+        // An empty slot is 0 after masking, so any non-poison shift is fine:
+        // keep the row's total width unless it would reach the working
+        // bitwidth, where `arith.shli` would be poison.
+        int64_t shiftLeft = hasEntry ? l.computeLeftShiftAmount(shuffleIdx)
+                                     : l.computeLeftShiftAmount(l.size());
+        if (!hasEntry && shiftLeft >= workBitWidth)
+          shiftLeft = 0;
+
+        shuffles.push_back(sourceElement);
+        masks.push_back(IntegerAttr::get(
+            workElementType,
+            llvm::APInt::getBitsSet(workBitWidth, bitLo, bitHi)));
+        shiftRightAmounts.push_back(IntegerAttr::get(workElementType, bitLo));
+        shiftLeftAmounts.push_back(IntegerAttr::get(workElementType, shiftLeft));
+      }
+
+      // Gather, mask, align to bit 0, then move into the final position.
+      auto shuffleOp = rewriter.create<vector::ShuffleOp>(loc, initialValue,
+                                                          initialValue, shuffles);
+      auto constOp = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(vt, masks));
+      Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+      auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(vt, shiftRightAmounts));
+      Value shiftedRight =
+          rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+      auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(vt, shiftLeftAmounts));
+      Value shiftedLeft =
+          rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+      res = res ? rewriter.create<arith::OrIOp>(loc, res, shiftedLeft)
+                : shiftedLeft;
+    }
+    assert(res && "expected at least one rewrite step");
+
+    // `res` has `workBitWidth`-bit elements and workBitWidth >= resultBitwidth.
+    // `arith.trunci` must strictly narrow, so only emit it when widths differ.
+    VectorType resultVectorType = bitCastOp.getResultVectorType();
+    if (res.getType() == resultVectorType)
+      rewriter.replaceOp(bitCastOp, res);
+    else
+      rewriter.replaceOpWithNewOp<arith::TruncIOp>(bitCastOp, resultVectorType,
+                                                   res);
+    return success();
+  }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//
@@ -167,3 +391,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
}
+
+/// Registers the bitcast(trunci) rewrite with the requested benefit.
+void vector::populateVectorNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<RewriteBitCastOfTruncI>(ctx, benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
new file mode 100644
index 000000000000000..ba6efde40f36c2b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+/// Note: Inspect generated assembly and llvm-mca stats:
+/// ====================================================
+/// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 --function-sections -filetype=obj > /tmp/a.out; objdump -d --disassemble=f1 --no-addresses --no-show-raw-insn -M att /tmp/a.out | ./build/bin/llvm-mca -mcpu=skylake-avx512
+
// CHECK-LABEL: func.func @f1(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<32xi64>) -> vector<20xi8>
func.func @f1(%a: vector<32xi64>) -> vector<20xi8> {
+ /// Rewriting this standalone pattern is about 2x faster on skylake-avx512 according to llvm-mca.
+ /// Benefit further increases when mixed with other compute ops.
+ ///
+ /// The 20x8 bits of the result are taken from the following bits of the
+ /// source vector:
+ // { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 }
+ // { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 }
+ // { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 }
+ // { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 }
+ // { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 }
+ // { 8: b@[0..5) lshl: 0 } { 9: b@[0..3) lshl: 5 }
+ // { 9: b@[3..5) lshl: 0 } { 10: b@[0..5) lshl: 2 } { 11: b@[0..1) lshl: 7 }
+ // { 11: b@[1..5) lshl: 0 } { 12: b@[0..4) lshl: 4 }
+ // { 12: b@[4..5) lshl: 0 } { 13: b@[0..5) lshl: 1 } { 14: b@[0..2) lshl: 6 }
+ // { 14: b@[2..5) lshl: 0 } { 15: b@[0..5) lshl: 3 }
+ // { 16: b@[0..5) lshl: 0 } { 17: b@[0..3) lshl: 5 }
+ // { 17: b@[3..5) lshl: 0 } { 18: b@[0..5) lshl: 2 } { 19: b@[0..1) lshl: 7 }
+ // { 19: b@[1..5) lshl: 0 } { 20: b@[0..4) lshl: 4 }
+ // { 20: b@[4..5) lshl: 0 } { 21: b@[0..5) lshl: 1 } { 22: b@[0..2) lshl: 6 }
+ // { 22: b@[2..5) lshl: 0 } { 23: b@[0..5) lshl: 3 }
+ // { 24: b@[0..5) lshl: 0 } { 25: b@[0..3) lshl: 5 }
+ // { 25: b@[3..5) lshl: 0 } { 26: b@[0..5) lshl: 2 } { 27: b@[0..1) lshl: 7 }
+ // { 27: b@[1..5) lshl: 0 } { 28: b@[0..4) lshl: 4 }
+ // { 28: b@[4..5) lshl: 0 } { 29: b@[0..5) lshl: 1 } { 30: b@[0..2) lshl: 6 }
+ // { 30: b@[2..5) lshl: 0 } { 31: b@[0..5) lshl: 3 }
+ /// This results in 3 shuffles + 1 shr + 2 shl + 3 and + 2 or.
+ /// (The shifts by all-zero amounts, i.e. the shl of the first step and the
+ /// shr of the other steps, fold away.)
+ /// The third vector is empty for positions 0, 2, 4, 5, 7, 9, 10, 12, 14, 15,
+ /// 17 and 19 (i.e. there are only 2 entries in that row).
+ ///
+ /// 0: b@[0..5), 1: b@[3..5), etc
+ // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28]> : vector<20xi64>
+ /// 1: b@[0..3), 2: b@[0..5), etc
+ // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31]> : vector<20xi64>
+ /// empty, 3: b@[0..1), empty etc
+ // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0]> : vector<20xi64>
+ // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2]> : vector<20xi64>
+ // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3]> : vector<20xi64>
+ // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8]> : vector<20xi64>
+ //
+ // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 3, 4, 6, 8, 9, 11, 12, 14, 16, 17, 19, 20, 22, 24, 25, 27, 28, 30] : vector<32xi64>, vector<32xi64>
+ // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<20xi64>
+ // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<20xi64>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 2, 4, 5, 7, 9, 10, 12, 13, 15, 17, 18, 20, 21, 23, 25, 26, 28, 29, 31] : vector<32xi64>, vector<32xi64>
+ // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<20xi64>
+ // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<20xi64>
+ // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<20xi64>
+ // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [0, 3, 0, 6, 0, 0, 11, 0, 14, 0, 0, 19, 0, 22, 0, 0, 27, 0, 30, 0] : vector<32xi64>, vector<32xi64>
+ // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK2]] : vector<20xi64>
+ // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<20xi64>
+ // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<20xi64>
+ // CHECK: %[[TR:.*]] = arith.trunci %[[O2]] : vector<20xi64> to vector<20xi8>
+ // CHECK-NOT: bitcast
+ %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
+ %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+ return %1 : vector<20xi8>
+}
+
// CHECK-LABEL: func.func @f2(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<16xi16>) -> vector<3xi16>
func.func @f2(%a: vector<16xi16>) -> vector<3xi16> {
+ /// Rewriting this standalone pattern is about 1.8x faster on skylake-avx512 according to llvm-mca.
+ /// Benefit further increases when mixed with other compute ops.
+ ///
+ /// Provenance of the 3x16 result bits (one line per result element):
+ // { 0: b@[0..3) lshl: 0 } { 1: b@[0..3) lshl: 3 } { 2: b@[0..3) lshl: 6 } { 3: b@[0..3) lshl: 9 } { 4: b@[0..3) lshl: 12 } { 5: b@[0..1) lshl: 15 }
+ // { 5: b@[1..3) lshl: 0 } { 6: b@[0..3) lshl: 2 } { 7: b@[0..3) lshl: 5 } { 8: b@[0..3) lshl: 8 } { 9: b@[0..3) lshl: 11 } { 10: b@[0..2) lshl: 14 }
+ // { 10: b@[2..3) lshl: 0 } { 11: b@[0..3) lshl: 1 } { 12: b@[0..3) lshl: 4 } { 13: b@[0..3) lshl: 7 } { 14: b@[0..3) lshl: 10 } { 15: b@[0..3) lshl: 13 }
+ /// 0: b@[0..3), 5: b@[1..3), 10: b@[2..3)
+ // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[7, 6, 4]> : vector<3xi16>
+ /// 1: b@[0..3), 6: b@[0..3), 11: b@[0..3)
+ /// ...
+ // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<7> : vector<3xi16>
+ /// 5: b@[0..1), 10: b@[0..2), 15: b@[0..3)
+ // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[1, 3, 7]> : vector<3xi16>
+ // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi16>
+ // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[3, 2, 1]> : vector<3xi16>
+ // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[6, 5, 4]> : vector<3xi16>
+ // CHECK-DAG: %[[SHL3_CST:.*]] = arith.constant dense<[9, 8, 7]> : vector<3xi16>
+ // CHECK-DAG: %[[SHL4_CST:.*]] = arith.constant dense<[12, 11, 10]> : vector<3xi16>
+ // CHECK-DAG: %[[SHL5_CST:.*]] = arith.constant dense<[15, 14, 13]> : vector<3xi16>
+
+ //
+ // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 5, 10] : vector<16xi16>, vector<16xi16>
+ // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<3xi16>
+ // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<3xi16>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 6, 11] : vector<16xi16>, vector<16xi16>
+ // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<3xi16>
+ // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<3xi16>
+ // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<3xi16>
+ // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [2, 7, 12] : vector<16xi16>, vector<16xi16>
+ // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK1]] : vector<3xi16>
+ // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<3xi16>
+ // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<3xi16>
+ // CHECK: %[[V3:.*]] = vector.shuffle %[[A]], %[[A]] [3, 8, 13] : vector<16xi16>, vector<16xi16>
+ // CHECK: %[[A3:.*]] = arith.andi %[[V3]], %[[MASK1]] : vector<3xi16>
+ // CHECK: %[[SHL3:.*]] = arith.shli %[[A3]], %[[SHL3_CST]] : vector<3xi16>
+ // CHECK: %[[O3:.*]] = arith.ori %[[O2]], %[[SHL3]] : vector<3xi16>
+ // CHECK: %[[V4:.*]] = vector.shuffle %[[A]], %[[A]] [4, 9, 14] : vector<16xi16>, vector<16xi16>
+ // CHECK: %[[A4:.*]] = arith.andi %[[V4]], %[[MASK1]] : vector<3xi16>
+ // CHECK: %[[SHL4:.*]] = arith.shli %[[A4]], %[[SHL4_CST]] : vector<3xi16>
+ // CHECK: %[[O4:.*]] = arith.ori %[[O3]], %[[SHL4]] : vector<3xi16>
+ // CHECK: %[[V5:.*]] = vector.shuffle %[[A]], %[[A]] [5, 10, 15] : vector<16xi16>, vector<16xi16>
+ // CHECK: %[[A5:.*]] = arith.andi %[[V5]], %[[MASK2]] : vector<3xi16>
+ // CHECK: %[[SHL5:.*]] = arith.shli %[[A5]], %[[SHL5_CST]] : vector<3xi16>
+ // CHECK: %[[O5:.*]] = arith.ori %[[O4]], %[[SHL5]] : vector<3xi16>
+ /// No trunci needed as the result is already in i16.
+ // CHECK-NOT: arith.trunci
+ // CHECK-NOT: bitcast
+ %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
+ %1 = vector.bitcast %0 : vector<16xi3> to vector<3xi16>
+ return %1 : vector<3xi16>
+}
+
+/// This pattern requires an extui 16 -> 32 and not a trunci.
+/// NOTE(review): only the presence of the extui is checked here, not the
+/// generated values.
// CHECK-LABEL: func.func @f3(
func.func @f3(%a: vector<16xi16>) -> vector<2xi32> {
+ /// Rewriting this standalone pattern is about 25x faster on skylake-avx512 according to llvm-mca.
+ /// Benefit further increases when mixed with other compute ops.
+ ///
+ // CHECK-NOT: arith.trunci
+ // CHECK-NOT: bitcast
+ // CHECK: arith.extui
+ %0 = arith.trunci %a : vector<16xi16> to vector<16xi4>
+ %1 = vector.bitcast %0 : vector<16xi4> to vector<2xi32>
+ return %1 : vector<2xi32>
+}
+
+/// This pattern is not rewritten as the result element type i6 is not a
+/// multiple of 8 bits: the trunci and bitcast are left untouched.
// CHECK-LABEL: func.func @f4(
func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
+ // CHECK: trunci
+ // CHECK: bitcast
+ // CHECK-NOT: shuffle
+ // CHECK-NOT: andi
+ // CHECK-NOT: ori
+ %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
+ %1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6>
+ return %1 : vector<8xi6>
+}
+
+// Apply the vector narrow-type rewrite patterns to every func.func in the
+// payload module.
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.rewrite_narrow_types
+ } : !transform.any_op
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
new file mode 100644
index 000000000000000..44c608726f13530
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -0,0 +1,155 @@
+/// Run once without applying the pattern and check the source of truth.
+// RUN: mlir-opt %s --test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+/// Run once with the pattern and compare.
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
func.func @print_as_i1_16xi5(%v : vector<16xi5>) {
+ // Reinterpret the 16 x 5 bits as 80 i1s and print them.
+ %bits = vector.bitcast %v : vector<16xi5> to vector<80xi1>
+ vector.print %bits : vector<80xi1>
+ return
+}
+
func.func @print_as_i1_10xi8(%v : vector<10xi8>) {
+ // Reinterpret the 10 x 8 bits as 80 i1s and print them.
+ %bits = vector.bitcast %v : vector<10xi8> to vector<80xi1>
+ vector.print %bits : vector<80xi1>
+ return
+}
+
func.func @f(%v: vector<16xi16>) {
+ // Keep the low 5 bits of every element; each group below lists an element's
+ // bits LSB first (0xffff -> 1, 1, 1, 1, 1; 0xfffe -> 0, 1, 1, 1, 1; ...).
+ %narrow = arith.trunci %v : vector<16xi16> to vector<16xi5>
+ func.call @print_as_i1_16xi5(%narrow) : (vector<16xi5>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 1, 1,
+ // CHECK-SAME: 0, 1, 1, 1, 1,
+ // CHECK-SAME: 1, 0, 1, 1, 1,
+ // CHECK-SAME: 0, 0, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 0, 1, 1,
+ // CHECK-SAME: 0, 1, 0, 1, 1,
+ // CHECK-SAME: 1, 0, 0, 1, 1,
+ // CHECK-SAME: 0, 0, 0, 1, 1,
+ // CHECK-SAME: 1, 1, 1, 0, 1,
+ // CHECK-SAME: 0, 1, 1, 0, 1,
+ // CHECK-SAME: 1, 0, 1, 0, 1,
+ // CHECK-SAME: 0, 0, 1, 0, 1,
+ // CHECK-SAME: 1, 1, 0, 0, 1,
+ // CHECK-SAME: 0, 1, 0, 0, 1,
+ // CHECK-SAME: 1, 0, 0, 0, 1,
+ // CHECK-SAME: 0, 0, 0, 0, 1 )
+
+ // Repack the same 80 bits as 10 bytes; the printed bit sequence is the same
+ // as above, with or without the rewrite applied.
+ %packed = vector.bitcast %narrow : vector<16xi5> to vector<10xi8>
+ func.call @print_as_i1_10xi8(%packed) : (vector<10xi8>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 1, 1, 0, 1, 1,
+ // CHECK-SAME: 1, 1, 1, 0, 1, 1, 1, 0,
+ // CHECK-SAME: 0, 1, 1, 1, 1, 1, 0, 1,
+ // CHECK-SAME: 1, 0, 1, 0, 1, 1, 1, 0,
+ // CHECK-SAME: 0, 1, 1, 0, 0, 0, 1, 1,
+ // CHECK-SAME: 1, 1, 1, 0, 1, 0, 1, 1,
+ // CHECK-SAME: 0, 1, 1, 0, 1, 0, 1, 0,
+ // CHECK-SAME: 0, 1, 0, 1, 1, 1, 0, 0,
+ // CHECK-SAME: 1, 0, 1, 0, 0, 1, 1, 0,
+ // CHECK-SAME: 0, 0, 1, 0, 0, 0, 0, 1 )
+
+ return
+}
+
func.func @print_as_i1_8xi3(%v : vector<8xi3>) {
+ // Reinterpret the 8 x 3 bits as 24 i1s and print them.
+ %bits = vector.bitcast %v : vector<8xi3> to vector<24xi1>
+ vector.print %bits : vector<24xi1>
+ return
+}
+
func.func @print_as_i1_3xi8(%v : vector<3xi8>) {
+ // Reinterpret the 3 x 8 bits as 24 i1s and print them.
+ %bits = vector.bitcast %v : vector<3xi8> to vector<24xi1>
+ vector.print %bits : vector<24xi1>
+ return
+}
+
func.func @f2(%v: vector<8xi32>) {
+ // Keep the low 3 bits of every element (printed LSB first).
+ %narrow = arith.trunci %v : vector<8xi32> to vector<8xi3>
+ func.call @print_as_i1_8xi3(%narrow) : (vector<8xi3>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1,
+ // CHECK-SAME: 0, 1, 1,
+ // CHECK-SAME: 1, 0, 1,
+ // CHECK-SAME: 0, 0, 1,
+ // CHECK-SAME: 1, 1, 0,
+ // CHECK-SAME: 0, 1, 0,
+ // CHECK-SAME: 1, 0, 0,
+ // CHECK-SAME: 0, 0, 0 )
+
+ // Repack the same 24 bits as 3 bytes; the bit sequence must not change.
+ %packed = vector.bitcast %narrow : vector<8xi3> to vector<3xi8>
+ func.call @print_as_i1_3xi8(%packed) : (vector<3xi8>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 0, 1, 1, 1, 0,
+ // CHECK-SAME: 1, 0, 0, 1, 1, 1, 0, 0,
+ // CHECK-SAME: 1, 0, 1, 0, 0, 0, 0, 0 )
+
+ return
+}
+
func.func @print_as_i1_2xi24(%v : vector<2xi24>) {
+ // Reinterpret the 2 x 24 bits as 48 i1s and print them.
+ %bits = vector.bitcast %v : vector<2xi24> to vector<48xi1>
+ vector.print %bits : vector<48xi1>
+ return
+}
+
func.func @print_as_i1_3xi16(%v : vector<3xi16>) {
+ // Reinterpret the 3 x 16 bits as 48 i1s and print them.
+ %bits = vector.bitcast %v : vector<3xi16> to vector<48xi1>
+ vector.print %bits : vector<48xi1>
+ return
+}
+
func.func @f3(%v: vector<2xi48>) {
+ // Keep the low 24 bits of every element (printed LSB first).
+ %narrow = arith.trunci %v : vector<2xi48> to vector<2xi24>
+ func.call @print_as_i1_2xi24(%narrow) : (vector<2xi24>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ // CHECK-SAME: 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0 )
+
+ // Repack the same 48 bits as 3 x i16; the bit sequence must not change.
+ %packed = vector.bitcast %narrow : vector<2xi24> to vector<3xi16>
+ func.call @print_as_i1_3xi16(%packed) : (vector<3xi16>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0,
+ // CHECK-SAME: 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0 )
+
+ return
+}
+
func.func @entry() {
+ // Drive the three cases above with fixed inputs.
+ // NOTE(review): in every case the bitcast result element type is no wider
+ // than the trunci input element type (i8 <= i16, i8 <= i32, i16 <= i48), so
+ // the widening (extui) path of the rewrite is not exercised at runtime.
+ %v = arith.constant dense<[
+ 0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
+ 0xfff7, 0xfff6, 0xfff5, 0xfff4, 0xfff3, 0xfff2, 0xfff1, 0xfff0
+ ]> : vector<16xi16>
+ func.call @f(%v) : (vector<16xi16>) -> ()
+
+ %v2 = arith.constant dense<[
+ 0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8
+ ]> : vector<8xi32>
+ func.call @f2(%v2) : (vector<8xi32>) -> ()
+
+ %v3 = arith.constant dense<[
+ 0xf345aeffffff, 0xffff015f345a
+ ]> : vector<2xi48>
+ func.call @f3(%v3) : (vector<2xi48>) -> ()
+
+ return
+}
+
+// Only used by the second RUN line: rewrite the narrow-type bitcasts before
+// lowering, so that its output can be compared against the same CHECK lines.
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.rewrite_narrow_types
+ } : !transform.any_op
+}
>From fcff465783749aaafb0ba81ed09daeacb1c22a7a Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Mon, 18 Sep 2023 11:09:48 +0200
Subject: [PATCH 2/3] Address comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 121 +++++++++++++-----
1 file changed, 90 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index eaf0b8849e36ebf..8c981c129dbe18e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -173,7 +173,9 @@ namespace {
/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector.
struct SourceElementRange {
- int64_t sourceElement;
+ /// The index of the source vector element that contributes bits to *this.
+ int64_t sourceElementIdx;
+ /// The range of bits in the source vector element that contribute to *this.
int64_t sourceBitBegin;
int64_t sourceBitEnd;
};
@@ -196,6 +198,16 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
/// involved in a bitcast operation.
/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
/// any 1-D vector shape and any source/target bitwidths.
+/// This creates and holds a mapping of the form:
+/// [dstVectorElementJ] ==
+/// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
+/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
+/// [0] = {0, [0-8)}
+/// [1] = {0, [8-16)}
+/// [2] = {0, [16-24)}
+/// and `vector<2xi15> to vector<3xi10>` is decomposed as: [0] = {0, [0-10)}
+/// [1] = {0, [10-15)}, {1, [0-5)}
+/// [2] = {1, [5-15)}
struct BitCastBitsEnumerator {
BitCastBitsEnumerator(VectorType sourceVectorType,
VectorType targetVectorType);
@@ -218,7 +230,7 @@ static raw_ostream &operator<<(raw_ostream &os,
const SmallVector<SourceElementRangeList> &vec) {
for (const auto &l : vec) {
for (auto it : llvm::enumerate(l)) {
- os << "{ " << it.value().sourceElement << ": b@["
+ os << "{ " << it.value().sourceElementIdx << ": b@["
<< it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
<< ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
}
@@ -231,6 +243,8 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
VectorType targetVectorType)
: sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+ assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
+ "requires -D non-scalable vector type");
assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
"requires -D non-scalable vector type");
int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
@@ -250,12 +264,12 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
for (int64_t resultBit = 0; resultBit < bitwidth;) {
int64_t resultElement = resultBit / targetBitWidth;
int64_t resultBitInElement = resultBit % targetBitWidth;
- int64_t sourceElement = resultBit / sourceBitWidth;
+ int64_t sourceElementIdx = resultBit / sourceBitWidth;
int64_t sourceBitInElement = resultBit % sourceBitWidth;
int64_t step = std::min(sourceBitWidth - sourceBitInElement,
targetBitWidth - resultBitInElement);
sourceElementRanges[resultElement].push_back(
- {sourceElement, sourceBitInElement, sourceBitInElement + step});
+ {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
resultBit += step;
}
}
@@ -264,6 +278,67 @@ namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
+
+// BitCastBitsEnumerator encodes for each element of the target vector the
+// provenance of the bits in the source vector. We can "transpose" this
+// information to build a sequence of shuffles and bitwise ops that will
+// produce the desired result.
+//
+// Let's take the following motivating example to explain the algorithm:
+// ```
+// %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
+// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+// ```
+//
+// BitCastBitsEnumerator contains the following information:
+// ```
+// { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 }
+// { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 }
+// { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 }
+// { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 }
+// { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 }
+// { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 }
+// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 }
+// { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 }
+// { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 }
+// { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
+// { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
+// { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
+// { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
+// { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6}
+// { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 }
+// { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 }
+// { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 }
+// { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
+// { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
+// { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 }
+// ```
+//
+// In the above, each row represents one target vector element and each
+// column represents one bit contribution from a source vector element.
+// The algorithm creates vector.shuffle operations (in this case there are 3
+// shuffles (i.e. the max number of columns in BitCastBitsEnumerator), as
+// follows:
+// 1. for each vector.shuffle, collect the source vectors that participate in
+// this shuffle. One source vector per target element of the resulting
+// vector.shuffle. If there is no source element contributing bits for the
+// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
+// 2 columns).
+// 2. represent the bitrange in the source vector as a mask. If there is no
+// source element contributing bits for the current vector.shuffle, take 0.
+// 3. shift right by the proper amount to align the source bitrange at
+// position 0. This is exactly the low end of the bitrange. For instance,
+// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
+// shift right by 3 to get the bits contributed by the source element #1
+// into position 0.
+// 4. shift left by the proper amount to to align to the desired position in
+// the result element vector. For instance, the contribution of the second
+// source element for the first row needs to be shifted by `5` to form the
+// first i8 result element.
+// Eventually, we end up building the sequence
+// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update the
+// result vector (i.e. the `shiftright -> shiftleft -> or` part) with the bits
+// extracted from the source vector (i.e. the `shuffle -> and` part).
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
using OpRewritePattern::OpRewritePattern;
@@ -278,8 +353,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
VectorType targetVectorType = bitCastOp.getResultVectorType();
if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
- // TODO: consider relaxing this restriction in the future if we find ways to
- // really work with subbyte elements across the MLIR/LLVM boundary.
+ // TODO: consider relaxing this restriction in the future if we find ways
+ // to really work with subbyte elements across the MLIR/LLVM boundary.
int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
if (resultBitwidth % 8 != 0)
return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
@@ -293,34 +368,18 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
auto initalElementType = initalVectorType.getElementType();
auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
- // BitCastBitsEnumerator encodes for each element of the target vector the
- // provenance of the bits in the source vector. We can "transpose" this
- // information to build a sequence of shuffles and bitwise ops that will
- // produce the desired result.
- // The algorithm proceeds as follows:
- // 1. there are as many shuffles as max entries in BitCastBitsEnumerator
- // 2. for each shuffle:
- // a. collect the source vectors that participate in this shuffle. One
- // source vector per target element of the shuffle. If overflow, take 0.
- // b. the bitrange in the source vector as a mask. If overflow, take 0.
- // c. the number of bits to shift right to align the source bitrange at
- // position 0. This is exactly the low end of the bitrange.
- // d. number of bits to shift left to align to the desired position in
- // the result element vector.
- // Then build the sequence:
- // (shuffle -> and -> shiftright -> shiftleft -> or) to iteratively update
- // the result vector (i.e. the "shiftright -> shiftleft -> or" part) with
- // the bits extracted from the source vector (i.e. the "shuffle -> and"
- // part).
Value res;
for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
++shuffleIdx) {
SmallVector<int64_t> shuffles;
SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+
+ // Create the attribute quantities for the shuffle / mask / shift ops.
for (auto &l : be.sourceElementRanges) {
- int64_t sourceElement =
- (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceElement : 0;
- shuffles.push_back(sourceElement);
+ int64_t sourceElementIdx = (shuffleIdx < (int64_t)l.size())
+ ? l[shuffleIdx].sourceElementIdx
+ : 0;
+ shuffles.push_back(sourceElementIdx);
int64_t bitLo =
(shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitBegin : 0;
@@ -340,17 +399,17 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
}
- //
+ // Create vector.shuffle #shuffleIdx.
auto shuffleOp = rewriter.create<vector::ShuffleOp>(
bitCastOp.getLoc(), initialValue, initialValue, shuffles);
-
+ // And with the mask.
VectorType vt = VectorType::Builder(initalVectorType)
.setDim(initalVectorType.getRank() - 1, masks.size());
auto constOp = rewriter.create<arith::ConstantOp>(
bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
shuffleOp, constOp);
-
+ // Align right on 0.
auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
Value shiftedRight = rewriter.create<arith::ShRUIOp>(
>From 1ad1f91691130870309b07d0500d753554d36a93 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH 3/3] [mlir][Vector] Add a rewrite pattern for better
low-precision ext(bitcast) expansion
This revision adds a rewrite for sequences of vector `ext(bitcast)` to use a more efficient sequence of
vector operations comprising `shuffle` and `bitwise` ops.
Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.
The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance
in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori`
and shift (`shrui`, `shli`) ops, followed by a final `trunci` or `ext` op.
The rewrite currently only applies to 1-D non-scalable vectors and bails out if the bitwidth of the final vector element type is
not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting element type
is not a whole number of bytes, further complexities arise that are not improved by this pattern:
the heavy lifting still needs to be done by LLVM.
---
.../Vector/Transforms/VectorRewritePatterns.h | 7 +
.../Transforms/VectorEmulateNarrowType.cpp | 406 ++++++++++++------
.../Vector/vector-rewrite-narrow-types.mlir | 295 +++++++------
.../Vector/CPU/test-rewrite-narrow-types.mlir | 46 ++
4 files changed, 482 insertions(+), 272 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 8652fc7f5e5c640..eb561ba3b23557a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
class RewritePatternSet;
namespace arith {
+class AndIOp;
class NarrowTypeEmulationConverter;
class TruncIOp;
} // namespace arith
@@ -309,6 +310,12 @@ FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
arith::TruncIOp truncOp,
vector::BroadcastOp maybeBroadcastOp);
+/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
+/// vector operations comprising `shuffle` and `bitwise` ops.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+ vector::BitCastOp bitCastOp,
+ vector::BroadcastOp maybeBroadcastOp);
+
/// Appends patterns for rewriting vector operations over narrow types with
/// ops over wider types.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 8c981c129dbe18e..cea78ddababdf94 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,11 +18,15 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
+#include <numeric>
using namespace mlir;
@@ -224,6 +228,98 @@ struct BitCastBitsEnumerator {
SmallVector<SourceElementRangeList> sourceElementRanges;
};
+/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+/// BitCastBitsEnumerator encodes for each element of the target vector the
+/// provenance of the bits in the source vector. We can "transpose" this
+/// information to build a sequence of shuffles and bitwise ops that will
+/// produce the desired result.
+///
+/// Consider the following motivating example:
+/// ```
+/// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+/// ```
+///
+/// BitCastBitsEnumerator contains the following information:
+/// ```
+/// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
+/// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
+/// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
+/// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
+/// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
+/// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
+/// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
+/// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
+/// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
+/// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
+/// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
+/// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
+/// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
+/// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
+/// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
+/// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
+/// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
+/// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
+/// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
+/// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
+/// ```
+///
+/// In the above, each row represents one target vector element and each
+/// column represents one bit contribution from a source vector element.
+/// The algorithm creates vector.shuffle operations (in this case there are 3
+/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator), as
+/// follows:
+/// 1. for each vector.shuffle, collect the source elements participating in
+/// this shuffle. One source element per target element of the resulting
+/// vector.shuffle. If there is no source element contributing bits for the
+/// current vector.shuffle, take 0 (e.g. row 0 in the above example has only
+/// 2 columns).
+/// 2. represent the bitrange in the source vector as a mask. If there is no
+/// source element contributing bits for the current vector.shuffle, take 0.
+/// 3. shift right by the proper amount to align the source bitrange at
+/// position 0. This is exactly the low end of the bitrange. For instance,
+/// the first element of row 1 is `{ 1: b@[3..5) lshl: 0}` and one needs to
+/// shift right by 3 to get the bits contributed by the source element #1
+/// into position 0.
+/// 4. shift left by the proper amount to align to the desired position in
+/// the result element. For instance, the second entry of row 0
+/// (`{ 1: b@[0..3) lshl: 5}`) needs to be shifted left by `5` to form the
+/// first i8 result element.
+///
+/// Eventually, we end up building the sequence
+/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
+/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
+/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
+struct BitCastRewriter {
+ /// Helper metadata struct to hold the static quantities for the rewrite.
+ struct Metadata {
+ SmallVector<int64_t> shuffles;
+ SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+ };
+
+ BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
+
+ /// Verify that the preconditions for the rewrite are met.
+ LogicalResult precondition(PatternRewriter &rewriter,
+ VectorType targetVectorType, Operation *op);
+
+ /// Precompute the metadata for the rewrite.
+ SmallVector<BitCastRewriter::Metadata>
+ precomputeMetadata(IntegerType shuffledElementType);
+
+ /// Rewrite one step of the sequence:
+ /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
+ Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
+ Value runningResult,
+ const BitCastRewriter::Metadata &metadata);
+
+private:
+ /// Underlying enumerator that encodes the provenance of the bits in each
+ /// element of the result vector.
+ BitCastBitsEnumerator enumerator;
+};
+
} // namespace
static raw_ostream &operator<<(raw_ostream &os,
@@ -274,71 +370,104 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
}
}
+BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
+ VectorType targetVectorType)
+ : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
+ LDBG("\n" << enumerator.sourceElementRanges);
+}
+
+LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
+ VectorType targetVectorType,
+ Operation *op) {
+ if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+ return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+
+ // TODO: consider relaxing this restriction in the future if we find ways
+ // to really work with subbyte elements across the MLIR/LLVM boundary.
+ int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+ if (resultBitwidth % 8 != 0)
+ return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
+
+ return success();
+}
+
+SmallVector<BitCastRewriter::Metadata>
+BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+ SmallVector<BitCastRewriter::Metadata> result;
+ for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
+ shuffleIdx < e; ++shuffleIdx) {
+ SmallVector<int64_t> shuffles;
+ SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+
+ // Create the attribute quantities for the shuffle / mask / shift ops.
+ for (auto &l : enumerator.sourceElementRanges) {
+ int64_t sourceElement =
+ (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceElementIdx : 0;
+ shuffles.push_back(sourceElement);
+
+ int64_t bitLo =
+ (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitBegin : 0;
+ int64_t bitHi =
+ (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitEnd : 0;
+ IntegerAttr mask = IntegerAttr::get(
+ shuffledElementType,
+ llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
+ bitLo, bitHi));
+ masks.push_back(mask);
+
+ int64_t shiftRight = bitLo;
+ shiftRightAmounts.push_back(
+ IntegerAttr::get(shuffledElementType, shiftRight));
+
+ int64_t shiftLeft = l.computeLeftShiftAmount(shuffleIdx);
+ shiftLeftAmounts.push_back(
+ IntegerAttr::get(shuffledElementType, shiftLeft));
+ }
+
+ result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
+ }
+ return result;
+}
+
+Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
+ Value initialValue, Value runningResult,
+ const BitCastRewriter::Metadata &metadata) {
+ // Create vector.shuffle from the metadata.
+ auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, initialValue, initialValue, metadata.shuffles);
+
+ // Intersect with the mask.
+ VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+ auto constOp = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+ Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+ // Align right on 0.
+ auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+ loc,
+ DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
+ Value shiftedRight =
+ rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+ // Shift bits left into their final position.
+ auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+ loc,
+ DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
+ Value shiftedLeft =
+ rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+ runningResult =
+ runningResult
+ ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
+ : shiftedLeft;
+
+ return runningResult;
+}
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
-
-// BitCastBitsEnumerator encodes for each element of the target vector the
-// provenance of the bits in the source vector. We can "transpose" this
-// information to build a sequence of shuffles and bitwise ops that will
-// produce the desired result.
-//
-// Let's take the following motivating example to explain the algorithm:
-// ```
-// %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
-// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
-// ```
-//
-// BitCastBitsEnumerator contains the following information:
-// ```
-// { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 }
-// { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 }
-// { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 }
-// { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 }
-// { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 }
-// { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 }
-// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 }
-// { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 }
-// { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 }
-// { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
-// { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
-// { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
-// { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
-// { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6}
-// { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 }
-// { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 }
-// { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 }
-// { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
-// { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
-// { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 }
-// ```
-//
-// In the above, each row represents one target vector element and each
-// column represents one bit contribution from a source vector element.
-// The algorithm creates vector.shuffle operations (in this case there are 3
-// shuffles (i.e. the max number of columns in BitCastBitsEnumerator), as
-// follows:
-// 1. for each vector.shuffle, collect the source vectors that participate in
-// this shuffle. One source vector per target element of the resulting
-// vector.shuffle. If there is no source element contributing bits for the
-// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
-// 2 columns).
-// 2. represent the bitrange in the source vector as a mask. If there is no
-// source element contributing bits for the current vector.shuffle, take 0.
-// 3. shift right by the proper amount to align the source bitrange at
-// position 0. This is exactly the low end of the bitrange. For instance,
-// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
-// shift right by 3 to get the bits contributed by the source element #1
-// into position 0.
-// 4. shift left by the proper amount to to align to the desired position in
-// the result element vector. For instance, the contribution of the second
-// source element for the first row needs to be shifted by `5` to form the
-// first i8 result element.
-// Eventually, we end up building the sequence
-// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update the
-// result vector (i.e. the `shiftright -> shiftleft -> or` part) with the bits
-// extracted from the source vector (i.e. the `shuffle -> and` part).
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
using OpRewritePattern::OpRewritePattern;
@@ -350,89 +479,92 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
if (!truncOp)
return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
+ // Set up the BitCastRewriter and verify the precondition.
+ VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
- if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
- return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
- // TODO: consider relaxing this restriction in the future if we find ways
- // to really work with subbyte elements across the MLIR/LLVM boundary.
- int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
- if (resultBitwidth % 8 != 0)
- return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+ BitCastRewriter bcr(sourceVectorType, targetVectorType);
+ if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+ return failure();
- VectorType sourceVectorType = bitCastOp.getSourceVectorType();
- BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
- LDBG("\n" << be.sourceElementRanges);
-
- Value initialValue = truncOp.getIn();
- auto initalVectorType = initialValue.getType().cast<VectorType>();
- auto initalElementType = initalVectorType.getElementType();
- auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
-
- Value res;
- for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
- ++shuffleIdx) {
- SmallVector<int64_t> shuffles;
- SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
-
- // Create the attribute quantities for the shuffle / mask / shift ops.
- for (auto &l : be.sourceElementRanges) {
- int64_t sourceElementIdx = (shuffleIdx < (int64_t)l.size())
- ? l[shuffleIdx].sourceElementIdx
- : 0;
- shuffles.push_back(sourceElementIdx);
-
- int64_t bitLo =
- (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitBegin : 0;
- int64_t bitHi =
- (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitEnd : 0;
- IntegerAttr mask = IntegerAttr::get(
- rewriter.getIntegerType(initalElementBitWidth),
- llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi));
- masks.push_back(mask);
-
- int64_t shiftRight = bitLo;
- shiftRightAmounts.push_back(IntegerAttr::get(
- rewriter.getIntegerType(initalElementBitWidth), shiftRight));
-
- int64_t shiftLeft = l.computeLeftShiftAmount(shuffleIdx);
- shiftLeftAmounts.push_back(IntegerAttr::get(
- rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
- }
-
- // Create vector.shuffle #shuffleIdx.
- auto shuffleOp = rewriter.create<vector::ShuffleOp>(
- bitCastOp.getLoc(), initialValue, initialValue, shuffles);
- // And with the mask.
- VectorType vt = VectorType::Builder(initalVectorType)
- .setDim(initalVectorType.getRank() - 1, masks.size());
- auto constOp = rewriter.create<arith::ConstantOp>(
- bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
- Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
- shuffleOp, constOp);
- // Align right on 0.
- auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
- bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
- Value shiftedRight = rewriter.create<arith::ShRUIOp>(
- bitCastOp.getLoc(), andValue, shiftRightConstantOp);
-
- auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
- bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
- Value shiftedLeft = rewriter.create<arith::ShLIOp>(
- bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
-
- res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
- shiftedLeft)
- : shiftedLeft;
+ // Perform the rewrite.
+ Value truncValue = truncOp.getIn();
+ auto shuffledElementType =
+ cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
+ Value runningResult;
+ for (const BitCastRewriter ::Metadata &metadata :
+ bcr.precomputeMetadata(shuffledElementType)) {
+ runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+ runningResult, metadata);
}
- bool narrowing = resultBitwidth <= initalElementBitWidth;
+ // Finalize the rewrite.
+ bool narrowing = targetVectorType.getElementTypeBitWidth() <=
+ shuffledElementType.getIntOrFloatBitWidth();
if (narrowing) {
rewriter.replaceOpWithNewOp<arith::TruncIOp>(
- bitCastOp, bitCastOp.getResultVectorType(), res);
+ bitCastOp, bitCastOp.getResultVectorType(), runningResult);
} else {
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
- bitCastOp, bitCastOp.getResultVectorType(), res);
+ bitCastOp, bitCastOp.getResultVectorType(), runningResult);
}
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite ext(bitcast) to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+ using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+ RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+ : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(ExtOpType extOp,
+ PatternRewriter &rewriter) const override {
+ // The source must be a bitcast op.
+ auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+ if (!bitCastOp)
+ return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+ // Set up the BitCastRewriter and verify the precondition.
+ VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+ VectorType targetVectorType = bitCastOp.getResultVectorType();
+ BitCastRewriter bcr(sourceVectorType, targetVectorType);
+ if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+ return failure();
+
+ // Perform the rewrite.
+ Value runningResult;
+ Value sourceValue = bitCastOp.getSource();
+ auto shuffledElementType =
+ cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
+ for (const BitCastRewriter::Metadata &metadata :
+ bcr.precomputeMetadata(shuffledElementType)) {
+ runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
+ sourceValue, runningResult, metadata);
+ }
+
+ // Finalize the rewrite.
+ bool narrowing =
+ cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
+ shuffledElementType.getIntOrFloatBitWidth();
+ if (narrowing) {
+ rewriter.replaceOpWithNewOp<arith::TruncIOp>(
+ extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+ } else {
+ rewriter.replaceOpWithNewOp<ExtOpType>(
+ extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+ }
+
return success();
}
};
@@ -453,5 +585,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
void vector::populateVectorNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<RewriteBitCastOfTruncI>(patterns.getContext(), benefit);
+  // Register the bitcast(trunci) rewrite, then the ext(bitcast) rewrite for
+  // both the zero- and the sign-extending flavor, all at the same benefit.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<RewriteBitCastOfTruncI>(ctx, benefit);
+  patterns.add<RewriteExtOfBitCast<arith::ExtUIOp>>(ctx, benefit);
+  patterns.add<RewriteExtOfBitCast<arith::ExtSIOp>>(ctx, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index ba6efde40f36c2b..e5e1cc6d37b041b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -4,146 +4,169 @@
/// ====================================================
/// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 --function-sections -filetype=obj > /tmp/a.out; objdump -d --disassemble=f1 --no-addresses --no-show-raw-insn -M att /tmp/a.out | ./build/bin/llvm-mca -mcpu=skylake-avx512
// CHECK-LABEL: func.func @f1(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<32xi64>) -> vector<20xi8>
func.func @f1(%a: vector<32xi64>) -> vector<20xi8> {
/// Rewriting this standalone pattern is about 2x faster on skylake-ax512 according to llvm-mca.
/// Benefit further increases when mixed with other compute ops.
///
/// The provenance of the 20x8 bits of the result are the following bits in the
/// source vector:
// { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 }
// { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 }
// { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 }
// { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 }
// { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 }
// { 8: b@[0..5) lshl: 0 } { 9: b@[0..3) lshl: 5 }
// { 9: b@[3..5) lshl: 0 } { 10: b@[0..5) lshl: 2 } { 11: b@[0..1) lshl: 7 }
// { 11: b@[1..5) lshl: 0 } { 12: b@[0..4) lshl: 4 }
// { 12: b@[4..5) lshl: 0 } { 13: b@[0..5) lshl: 1 } { 14: b@[0..2) lshl: 6 }
// { 14: b@[2..5) lshl: 0 } { 15: b@[0..5) lshl: 3 }
// { 16: b@[0..5) lshl: 0 } { 17: b@[0..3) lshl: 5 }
// { 17: b@[3..5) lshl: 0 } { 18: b@[0..5) lshl: 2 } { 19: b@[0..1) lshl: 7 }
// { 19: b@[1..5) lshl: 0 } { 20: b@[0..4) lshl: 4 }
// { 20: b@[4..5) lshl: 0 } { 21: b@[0..5) lshl: 1 } { 22: b@[0..2) lshl: 6 }
// { 22: b@[2..5) lshl: 0 } { 23: b@[0..5) lshl: 3 }
// { 24: b@[0..5) lshl: 0 } { 25: b@[0..3) lshl: 5 }
// { 25: b@[3..5) lshl: 0 } { 26: b@[0..5) lshl: 2 } { 27: b@[0..1) lshl: 7 }
// { 27: b@[1..5) lshl: 0 } { 28: b@[0..4) lshl: 4 }
// { 28: b@[4..5) lshl: 0 } { 29: b@[0..5) lshl: 1 } { 30: b@[0..2) lshl: 6 }
// { 30: b@[2..5) lshl: 0 } { 31: b@[0..5) lshl: 3 }
/// This results in 3 shuffles + 1 shr + 2 shl + 3 and + 2 or.
/// The third vector is empty for positions 0, 2, 4, 5, 7, 9, 10, 12, 14, 15,
/// 17 and 19 (i.e. there are only 2 entries in that row).
///
/// 0: b@[0..5), 1: b@[3..5), etc
// CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28]> : vector<20xi64>
/// 1: b@[0..3), 2: b@[0..5), etc
// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31]> : vector<20xi64>
/// empty, 3: b@[0..1), empty etc
// CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0]> : vector<20xi64>
// CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2]> : vector<20xi64>
// CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3]> : vector<20xi64>
// CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8]> : vector<20xi64>
//
// CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 3, 4, 6, 8, 9, 11, 12, 14, 16, 17, 19, 20, 22, 24, 25, 27, 28, 30] : vector<32xi64>, vector<32xi64>
// CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<20xi64>
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<20xi64>
// CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 2, 4, 5, 7, 9, 10, 12, 13, 15, 17, 18, 20, 21, 23, 25, 26, 28, 29, 31] : vector<32xi64>, vector<32xi64>
// CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<20xi64>
// CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<20xi64>
// CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<20xi64>
// CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [0, 3, 0, 6, 0, 0, 11, 0, 14, 0, 0, 19, 0, 22, 0, 0, 27, 0, 30, 0] : vector<32xi64>, vector<32xi64>
// CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK2]] : vector<20xi64>
// CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<20xi64>
// CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<20xi64>
// CHECK: %[[TR:.*]] = arith.trunci %[[O2]] : vector<20xi64> to vector<20xi8>
// CHECK-NOT: bitcast
%0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
%1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
return %1 : vector<20xi8>
}
// CHECK-LABEL: func.func @f2(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<16xi16>) -> vector<3xi16>
func.func @f2(%a: vector<16xi16>) -> vector<3xi16> {
/// Rewriting this standalone pattern is about 1.8x faster on skylake-ax512 according to llvm-mca.
/// Benefit further increases when mixed with other compute ops.
///
// { 0: b@[0..3) lshl: 0 } { 1: b@[0..3) lshl: 3 } { 2: b@[0..3) lshl: 6 } { 3: b@[0..3) lshl: 9 } { 4: b@[0..3) lshl: 12 } { 5: b@[0..1) lshl: 15 }
// { 5: b@[1..3) lshl: 0 } { 6: b@[0..3) lshl: 2 } { 7: b@[0..3) lshl: 5 } { 8: b@[0..3) lshl: 8 } { 9: b@[0..3) lshl: 11 } { 10: b@[0..2) lshl: 14 }
// { 10: b@[2..3) lshl: 0 } { 11: b@[0..3) lshl: 1 } { 12: b@[0..3) lshl: 4 } { 13: b@[0..3) lshl: 7 } { 14: b@[0..3) lshl: 10 } { 15: b@[0..3) lshl: 13 }
/// 0: b@[0..3), 5: b@[1..3), 10: b@[2..3)
// CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[7, 6, 4]> : vector<3xi16>
/// 1: b@[0..3), 6: b@[0..3), 11: b@[0..3)
/// ...
// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<7> : vector<3xi16>
/// 5: b@[0..1), 10: b@[0..2), 15: b@[0..3)
// CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[1, 3, 7]> : vector<3xi16>
// CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi16>
// CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[3, 2, 1]> : vector<3xi16>
// CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[6, 5, 4]> : vector<3xi16>
// CHECK-DAG: %[[SHL3_CST:.*]] = arith.constant dense<[9, 8, 7]> : vector<3xi16>
// CHECK-DAG: %[[SHL4_CST:.*]] = arith.constant dense<[12, 11, 10]> : vector<3xi16>
// CHECK-DAG: %[[SHL5_CST:.*]] = arith.constant dense<[15, 14, 13]> : vector<3xi16>
//
// CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 5, 10] : vector<16xi16>, vector<16xi16>
// CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<3xi16>
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<3xi16>
// CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 6, 11] : vector<16xi16>, vector<16xi16>
// CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<3xi16>
// CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<3xi16>
// CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<3xi16>
// CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [2, 7, 12] : vector<16xi16>, vector<16xi16>
// CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK1]] : vector<3xi16>
// CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<3xi16>
// CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<3xi16>
// CHECK: %[[V3:.*]] = vector.shuffle %[[A]], %[[A]] [3, 8, 13] : vector<16xi16>, vector<16xi16>
// CHECK: %[[A3:.*]] = arith.andi %[[V3]], %[[MASK1]] : vector<3xi16>
// CHECK: %[[SHL3:.*]] = arith.shli %[[A3]], %[[SHL3_CST]] : vector<3xi16>
// CHECK: %[[O3:.*]] = arith.ori %[[O2]], %[[SHL3]] : vector<3xi16>
// CHECK: %[[V4:.*]] = vector.shuffle %[[A]], %[[A]] [4, 9, 14] : vector<16xi16>, vector<16xi16>
// CHECK: %[[A4:.*]] = arith.andi %[[V4]], %[[MASK1]] : vector<3xi16>
// CHECK: %[[SHL4:.*]] = arith.shli %[[A4]], %[[SHL4_CST]] : vector<3xi16>
// CHECK: %[[O4:.*]] = arith.ori %[[O3]], %[[SHL4]] : vector<3xi16>
// CHECK: %[[V5:.*]] = vector.shuffle %[[A]], %[[A]] [5, 10, 15] : vector<16xi16>, vector<16xi16>
// CHECK: %[[A5:.*]] = arith.andi %[[V5]], %[[MASK2]] : vector<3xi16>
// CHECK: %[[SHL5:.*]] = arith.shli %[[A5]], %[[SHL5_CST]] : vector<3xi16>
// CHECK: %[[O5:.*]] = arith.ori %[[O4]], %[[SHL5]] : vector<3xi16>
/// No trunci needed as the result is already in i16.
// CHECK-NOT: arith.trunci
// CHECK-NOT: bitcast
%0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
%1 = vector.bitcast %0 : vector<16xi3> to vector<3xi16>
return %1 : vector<3xi16>
}
/// This pattern requires an extui 16 -> 32 and not a trunci.
// CHECK-LABEL: func.func @f3(
func.func @f3(%a: vector<16xi16>) -> vector<2xi32> {
/// Rewriting this standalone pattern is about 25x faster on skylake-ax512 according to llvm-mca.
/// Benefit further increases when mixed with other compute ops.
///
// CHECK-NOT: arith.trunci
// CHECK-NOT: bitcast
// CHECK: arith.extui
%0 = arith.trunci %a : vector<16xi16> to vector<16xi4>
%1 = vector.bitcast %0 : vector<16xi4> to vector<2xi32>
return %1 : vector<2xi32>
}
/// This pattern is not rewritten as the result i6 is not a multiple of i8.
// CHECK-LABEL: func.func @f4(
func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
// CHECK: trunci
// CHECK: bitcast
// CHECK-NOT: shuffle
// CHECK-NOT: andi
// CHECK-NOT: ori
%0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
%1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6>
return %1 : vector<8xi6>
}
+
+/// ext(bitcast) inputs for RewriteExtOfBitCast. Only the signatures are
+/// pinned for now.
+// TODO: pin the expected shuffle/andi/ori/ext sequence for each case.
+// CHECK-LABEL: func.func @f1ext(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<20xi8>) -> vector<32xi64>
+func.func @f1ext(%a: vector<20xi8>) -> vector<32xi64> {
+  %0 = vector.bitcast %a : vector<20xi8> to vector<32xi5>
+  %1 = arith.extui %0 : vector<32xi5> to vector<32xi64>
+  return %1 : vector<32xi64>
+}
+
+/// The ext result has the same bitwidth as the bitcast source element.
+// CHECK-LABEL: func.func @f2ext(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<16xi16>
+func.func @f2ext(%a: vector<3xi16>) -> vector<16xi16> {
+  %0 = vector.bitcast %a : vector<3xi16> to vector<16xi3>
+  %1 = arith.extui %0 : vector<16xi3> to vector<16xi16>
+  return %1 : vector<16xi16>
+}
+
+// CHECK-LABEL: func.func @fext(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16>
+func.func @fext(%a: vector<5xi8>) -> vector<8xi16> {
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+  return %1 : vector<8xi16>
+}
transform.sequence failures(propagate) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index 44c608726f13530..7d15e2e2e3ef5e4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -124,6 +124,47 @@ func.func @f3(%v: vector<2xi48>) {
return
}
+func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
+  // View the 8 x 5 = 40 packed bits as i1s so that every bit gets printed.
+  %bits = vector.bitcast %v : vector<8xi5> to vector<40xi1>
+  vector.print %bits : vector<40xi1>
+  return
+}
+
+func.func @print_as_i1_8xi16(%v : vector<8xi16>) {
+  // View the 8 x 16 = 128 bits as i1s so that every bit gets printed.
+  %bits = vector.bitcast %v : vector<8xi16> to vector<128xi1>
+  vector.print %bits : vector<128xi1>
+  return
+}
+
+func.func @fext(%a: vector<5xi8>) {
+ %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+ func.call @print_as_i1_8xi5(%0) : (vector<8xi5>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 1, 0,
+ // CHECK-SAME: 1, 1, 1, 0, 1,
+ // CHECK-SAME: 1, 1, 0, 1, 1,
+ // CHECK-SAME: 1, 1, 0, 1, 1,
+ // CHECK-SAME: 0, 1, 1, 1, 0,
+ // CHECK-SAME: 0, 1, 1, 0, 1,
+ // CHECK-SAME: 1, 1, 1, 1, 0,
+ // CHECK-SAME: 1, 0, 1, 1, 1 )
+
+ %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+ func.call @print_as_i1_8xi16(%1) : (vector<8xi16>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+ // Sign extension of the same elements: every element whose bit 4 (the last
+ // of its 5 bits printed above) is set must be padded with ones.
+ %2 = arith.extsi %0 : vector<8xi5> to vector<8xi16>
+ func.call @print_as_i1_8xi16(%2) : (vector<8xi16>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+
+ return
+}
+
+
func.func @entry() {
%v = arith.constant dense<[
0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
@@ -141,6 +182,11 @@ func.func @entry() {
]> : vector<2xi48>
func.call @f3(%v3) : (vector<2xi48>) -> ()
+ %v4 = arith.constant dense<[
+ 0xef, 0xee, 0xed, 0xec, 0xeb
+ ]> : vector<5xi8>
+ func.call @fext(%v4) : (vector<5xi8>) -> ()
+
return
}
More information about the Mlir-commits
mailing list