[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #65774)
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Sep 12 08:27:15 PDT 2023
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/65774:
>From 11caf7cacb71874a78474bef25ecc27d68e642e2 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH] [mlir][Vector] Add a rewrite pattern for better low-precision
ext(bitcast) expansion
This revision adds a rewrite that lowers sequences of vector `ext(maybe_broadcast(bitcast))`
to a more efficient sequence of vector operations comprising `shuffle`, `shift` and
`bitwise` ops.
The rewrite uses an intermediate bitwidth equal to the lcm of
the element bitwidths of the source and result types of `bitCastOp`. This
intermediate type may be smaller or greater than the desired element type of
the `ext`, in which case appropriate `ext` or `trunc` operations are inserted.
The rewrite fails if the intermediate bitwidth is greater than `64` or if the
involved vector types fail to meet basic divisibility requirements. In other
words, this rewrite does not handle partial vector boundaries and leaves
this part of the heavy lifting to LLVM.
In the future, it may be relevant to give control over the size of the intermediate type.
For now, it is empirically determined that capping it at `64` results in much better assembly
being produced when piping through `llvm-mca`.
---
.../Vector/TransformOps/VectorTransformOps.td | 13 +
.../Vector/Transforms/VectorRewritePatterns.h | 22 +-
mlir/include/mlir/IR/BuiltinTypes.h | 10 +
.../TransformOps/VectorTransformOps.cpp | 5 +
.../Transforms/VectorEmulateNarrowType.cpp | 280 +++++++++++++++++-
mlir/test/Dialect/LLVM/transform-e2e.mlir | 21 --
.../Vector/vector-rewrite-narrow-types.mlir | 205 +++++++++++++
7 files changed, 530 insertions(+), 26 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2b8c95a94257e6c..d1cef91f8e27525 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -281,6 +281,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.rewrite_narrow_types",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector narrow type rewrite patterns should be applied.
+
+ This is usually a late step that is run after bufferization as part of the
+ process of lowering to e.g. LLVM or NVVM.
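+
+ Example usage (illustrative; mirrors the tests added in this patch):
+
+ ```mlir
+ transform.apply_patterns to %f {
+   transform.apply_patterns.vector.rewrite_narrow_types
+ } : !transform.any_op
+ ```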
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.split_transfer_full_partial",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c644090d8c78cd0..20c33921f9de24e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
class RewritePatternSet;
namespace arith {
+class AndIOp;
class NarrowTypeEmulationConverter;
} // namespace arith
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// Patterns that remove redundant vector broadcasts.
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 1);
/// Populate `patterns` with the following patterns.
///
@@ -301,6 +302,25 @@ void populateVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
+/// Rewrite vector ext(maybe_broadcast(bitcast)) to use a more efficient
+/// sequence of vector operations comprising shuffles, shifts and bitwise
+/// logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of
+/// the element bitwidths of the source and result types of `bitCastOp`. This
+/// intermediate type may be smaller or greater than the desired element type of
+/// the extOp, in which case appropriate ext or trunc operations are inserted.
+/// The rewrite fails if the intermediate bitwidth is greater than 64 or if the
+/// involved vector types fail to meet basic divisibility requirements. In other
+/// words, this rewrite does not handle partial vector boundaries and leaves
+/// this part of the heavy lifting to LLVM.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+ vector::BitCastOp bitCastOp,
+ vector::BroadcastOp maybeBroadcastOp);
+
+/// Appends patterns for rewriting vector operations over narrow types with
+/// ops over wider types.
+void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..9df5548cd5d939c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,6 +357,16 @@ class VectorType::Builder {
return *this;
}
+ /// Set the dim of the shape at position `pos` to `val`.
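+ /// For example (illustration): seeding a builder with `vector<4x8xf32>` and
+ /// calling `setDim(1, 16)` yields `vector<4x16xf32>`.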
+ Builder &setDim(unsigned pos, int64_t val) {
+ if (storage.empty())
+ storage.append(shape.begin(), shape.end());
+ assert(pos < storage.size() && "overflow");
+ storage[pos] = val;
+ shape = {storage.data(), storage.size()};
+ return *this;
+ }
+
operator VectorType() {
return VectorType::get(shape, elementType, scalableDims);
}
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 94f19e59669eafd..0fdeded436a9773 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -154,6 +154,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
}
}
+void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ populateVectorNarrowTypeRewritePatterns(patterns);
+}
+
void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b2b7bfc5e4437c1..d10994f3709e390 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,7 +7,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -15,13 +14,29 @@
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
-#include <cassert>
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <numeric>
+#include <type_traits>
using namespace mlir;
+#define DEBUG_TYPE "vector-narrow-type-emulation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
namespace {
//===----------------------------------------------------------------------===//
@@ -155,6 +170,256 @@ struct ConvertVectorTransferRead final
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+/// Create a vector of bit masks, each selecting bits `idx .. idx + step - 1`
+/// of a `bitwidth`-bit integer, and replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x3, 0xc, 0x3, 0xc, 0x3, 0xc].
+static SmallVector<Attribute> computeExtOfBitCastMasks(MLIRContext *ctx,
+ int64_t bitwidth,
+ int64_t step,
+ int64_t numOccurrences) {
+ assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+ IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+ SmallVector<Attribute> tmpMasks;
+ tmpMasks.reserve(bitwidth / step);
+ // Create a vector of bit masks: `idx .. idx + step - 1`.
+ for (int64_t idx = 0; idx < bitwidth; idx += step) {
+ LDBG("Mask bits " << idx << " .. " << idx + step - 1 << " out of "
+ << bitwidth);
+ IntegerAttr mask = IntegerAttr::get(
+ interimIntType, llvm::APInt::getBitsSet(bitwidth, idx, idx + step));
+ tmpMasks.push_back(mask);
+ }
+ // Replicate the vector of bit masks to the desired size.
+ SmallVector<Attribute> masks;
+ masks.reserve(numOccurrences * tmpMasks.size());
+ for (int64_t idx = 0; idx < numOccurrences; ++idx)
+ llvm::append_range(masks, tmpMasks);
+ return masks;
+}
+
+/// Create a vector of shift amounts `0, step, 2 * step, ..., bitwidth - step`
+/// and replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x2, 0x0, 0x2, 0x0, 0x2].
+static SmallVector<Attribute>
+computeExtOfBitCastShifts(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+ int64_t numOccurrences) {
+ assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+ IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+ SmallVector<Attribute> tmpShifts;
+ for (int64_t idx = 0; idx < bitwidth; idx += step) {
+ IntegerAttr shift = IntegerAttr::get(interimIntType, idx);
+ tmpShifts.push_back(shift);
+ }
+ SmallVector<Attribute> shifts;
+ for (int64_t idx = 0; idx < numOccurrences; ++idx)
+ llvm::append_range(shifts, tmpShifts);
+ return shifts;
+}
+
+/// Create a vector of shuffle indices in which each index `idx` in
+/// `[0, numOccurrences)` is repeated `bitwidth / step` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x0, 0x1, 0x1, 0x2, 0x2].
+static SmallVector<int64_t>
+computeExtOfBitCastShuffles(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+ int64_t numOccurrences) {
+ assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+ SmallVector<int64_t> shuffles;
+ int64_t n = floorDiv(bitwidth, step);
+ for (int64_t idx = 0; idx < numOccurrences; ++idx)
+ llvm::append_range(shuffles, SmallVector<int64_t>(n, idx));
+ return shuffles;
+}
+
+/// Compute the bitwidth of the intermediate vector type. Its elemental type
+/// must be an integer with a bitwidth that:
+/// 1. is at most 64 (TODO: in the future we may want target-specific
+/// control).
+/// 2. divides sourceBitWidth * mostMinorSourceDim evenly.
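+/// For example, for the i1 -> i3 -> i16 case exercised in the tests added with
+/// this pattern, sourceBitWidth = 1, mostMinorSourceDim = 288 and
+/// targetBitWidth = 3: mult = 32 gives lcm(32, lcm(1, 3)) = 96 > 64 and is
+/// skipped, while mult = 16 gives 48, which divides 1 * 288 = 288, so 48 is
+/// returned (i.e. vector<6xi48>).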
+static int64_t computeExtOfBitCastBitWidth(int64_t sourceBitWidth,
+ int64_t mostMinorSourceDim,
+ int64_t targetBitWidth) {
+ for (int64_t mult : {32, 16, 8, 4, 2, 1}) {
+ int64_t interimBitWidth =
+ std::lcm(mult, std::lcm(sourceBitWidth, targetBitWidth));
+ if (interimBitWidth > 64)
+ continue;
+ if ((sourceBitWidth * mostMinorSourceDim) % interimBitWidth != 0)
+ continue;
+ return interimBitWidth;
+ }
+ return 0;
+}
+
+FailureOr<Value>
+mlir::vector::rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+ vector::BitCastOp bitCastOp,
+ vector::BroadcastOp maybeBroadcastOp) {
+ assert(
+ (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
+ "unsupported op");
+
+ // The bitcast op is the load-bearing part, capture the source and bitCast
+ // types as well as bitwidth and most minor dimension.
+ VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+ int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+ int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+ LDBG("sourceVectorType: " << sourceVectorType);
+
+ VectorType bitCastVectorType = bitCastOp.getResultVectorType();
+ int64_t targetBitWidth = bitCastVectorType.getElementTypeBitWidth();
+ LDBG("bitCastVectorType: " << bitCastVectorType);
+
+ int64_t interimBitWidth = computeExtOfBitCastBitWidth(
+ sourceBitWidth, mostMinorSourceDim, targetBitWidth);
+ LDBG("interimBitWidth: " << interimBitWidth);
+ if (!interimBitWidth) {
+ return rewriter.notifyMatchFailure(
+ extOp, "heuristic could not find a reasonable interim bitwidth");
+ }
+ if (sourceBitWidth == interimBitWidth || targetBitWidth == interimBitWidth) {
+ return rewriter.notifyMatchFailure(
+ extOp, "interim bitwidth is equal to source or target, nothing to do");
+ }
+
+ int64_t interimMostMinorDim =
+ sourceBitWidth * mostMinorSourceDim / interimBitWidth;
+ LDBG("interimMostMinorDim: " << interimMostMinorDim);
+
+ Location loc = extOp->getLoc();
+ MLIRContext *ctx = extOp->getContext();
+
+ VectorType interimVectorType =
+ VectorType::Builder(sourceVectorType)
+ .setDim(sourceVectorType.getRank() - 1, interimMostMinorDim)
+ .setElementType(IntegerType::get(ctx, interimBitWidth));
+ LDBG("interimVectorType: " << interimVectorType);
+
+ IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
+ VectorType vt =
+ VectorType::Builder(bitCastVectorType).setElementType(interimIntType);
+
+ // Rewrite the original bitcast to the interim vector type and shuffle to
+ // broadcast to the desired size.
+ auto newBitCastOp = rewriter.create<vector::BitCastOp>(loc, interimVectorType,
+ bitCastOp.getSource());
+ SmallVector<int64_t> shuffles = computeExtOfBitCastShuffles(
+ ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+ auto shuffleOp = rewriter.create<vector::ShuffleOp>(loc, newBitCastOp,
+ newBitCastOp, shuffles);
+ LDBG("shuffle: " << shuffleOp);
+
+ // Compute the constants for masking.
+ SmallVector<Attribute> masks = computeExtOfBitCastMasks(
+ ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+ auto maskConstantOp = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(vt, masks));
+ LDBG("maskConstant: " << maskConstantOp);
+ auto andOp = rewriter.create<arith::AndIOp>(loc, shuffleOp, maskConstantOp);
+ LDBG("andOp: " << andOp);
+
+ // Preserve the element type requested by the original ext: the intermediate
+ // types involved may have serious consequences on the backend's ability to
+ // generate efficient vector operations. For instance on x86, converting to
+ // f16 without going through i32 has severe performance implications.
+ // As a consequence, this pattern must preserve the original behavior.
+ VectorType resultType = cast<VectorType>(extOp->getResultTypes().front());
+ Type resultElementType = getElementTypeOrSelf(resultType);
+ SmallVector<Attribute> shifts = computeExtOfBitCastShifts(
+ ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+ auto shiftConstantOp = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(vt, shifts));
+ LDBG("shiftConstant: " << shiftConstantOp);
+ Value newResult =
+ TypeSwitch<Operation *, Value>(extOp)
+ .template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
+ Value shifted =
+ rewriter.create<arith::ShRSIOp>(loc, andOp, shiftConstantOp);
+ auto vt = shifted.getType().cast<VectorType>();
+ VectorType extVt =
+ VectorType::Builder(vt).setElementType(resultElementType);
+ Operation *res =
+ (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+ ? rewriter.create<arith::ExtSIOp>(loc, extVt, shifted)
+ : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+ return res->getResult(0);
+ })
+ .template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
+ Value shifted =
+ rewriter.create<arith::ShRUIOp>(loc, andOp, shiftConstantOp);
+ auto vt = shifted.getType().cast<VectorType>();
+ VectorType extVt =
+ VectorType::Builder(vt).setElementType(resultElementType);
+ Operation *res =
+ (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+ ? rewriter.create<arith::ExtUIOp>(loc, extVt, shifted)
+ : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+ return res->getResult(0);
+ })
+ .Default([&](Operation *op) {
+ llvm_unreachable("unexpected op type");
+ return nullptr;
+ });
+
+ if (maybeBroadcastOp) {
+ newResult =
+ rewriter.create<vector::BroadcastOp>(loc, resultType, newResult);
+ }
+
+ return newResult;
+}
+
+namespace {
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+ using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+ RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+ : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(ExtOpType extOp,
+ PatternRewriter &rewriter) const override {
+ VectorType resultTy = dyn_cast<VectorType>(extOp.getType());
+ if (!resultTy)
+ return rewriter.notifyMatchFailure(extOp, "not a vector type");
+
+ int64_t elementalBitWidth = resultTy.getElementTypeBitWidth();
+ if (elementalBitWidth & (elementalBitWidth - 1)) {
+ return rewriter.notifyMatchFailure(
+ extOp, "result bitwidth must be a power of 2");
+ }
+
+ // Provision for a potential broadcast op that will be rewritten later.
+ auto maybeBroadcastOp =
+ extOp.getIn().template getDefiningOp<vector::BroadcastOp>();
+
+ // The source must be a bitcast op.
+ auto bitCastOp =
+ maybeBroadcastOp
+ ? maybeBroadcastOp.getSource()
+ .template getDefiningOp<vector::BitCastOp>()
+ : extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+ if (!bitCastOp)
+ return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+ // Try to rewrite.
+ FailureOr<Value> result =
+ rewriteExtOfBitCast(rewriter, extOp, bitCastOp, maybeBroadcastOp);
+ if (failed(result))
+ return failure();
+
+ rewriter.replaceOp(extOp, *result);
+ return success();
+ }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//
@@ -167,3 +432,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
}
+
+void vector::populateVectorNarrowTypeRewritePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<RewriteExtOfBitCast<arith::ExtSIOp>,
+ RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
+ benefit);
+}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 777de75b1a47acc..2cb753a3d7fb8f3 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -29,33 +29,12 @@ transform.sequence failures(propagate) {
// lowering TD macros.
transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- } : !transform.any_op
-
- transform.apply_patterns to %f {
transform.apply_patterns.vector.transfer_permutation_patterns
- } : !transform.any_op
-
- transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
- } : !transform.any_op
-
- transform.apply_patterns to %f {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
- } : !transform.any_op
-
- transform.apply_patterns to %f {
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
- } : !transform.any_op
-
- transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
- } : !transform.any_op
-
- transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_shape_cast
- } : !transform.any_op
-
- transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
} : !transform.any_op
}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
new file mode 100644
index 000000000000000..0550381bf2e2e16
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -0,0 +1,205 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+/// Inspect generated assembly and llvm-mca stats
+/// =============================================
+/// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 | llvm-mca -mcpu=skylake-avx512 | less
+
+/// The rewritten sequence is about 10x faster on skylake-avx512 according to llvm-mca.
+
+!source_element_t = i8
+!quantized_element_t = i3
+!expanded_element_t = i16
+!final_element_t = f32
+!mst = memref<1234x!source_element_t>
+!vst = vector<40x!source_element_t>
+!vst_i1 = vector<320xi1>
+!vqt_i1 = vector<288xi1>
+!vqt = vector<96x!quantized_element_t>
+!bvqt = vector<2x3x96x!quantized_element_t>
+!bvtt = vector<2x3x96x!expanded_element_t>
+!vct = vector<96x!final_element_t>
+!bvct = vector<2x3x96x!final_element_t>
+!mtt = memref<1234x!final_element_t>
+
+// CHECK-LABEL: @f1(
+// CHECK-SAME: %[[A:[a-z0-9]+]]: memref<1234xi8>
+// CHECK-SAME: %[[IDX:[a-z0-9]+]]: index
+// CHECK-SAME: %[[B:[a-z0-9]+]]: memref<1234xf32>
+func.func @f1(%m: !mst, %idx : index, %mf: !mtt) {
+
+// CHECK: %[[MASK:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-6: 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832
+// CHECK-SAME: ]> : vector<96xi48>
+// CHECK: %[[SHIFT:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-6: 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45
+// CHECK-SAME: ]> : vector<96xi48>
+// CHECK: %[[LOADED:.*]] = vector.load %[[A]][%[[IDX]]] : memref<1234xi8>, vector<40xi8>
+// CHECK: %[[CAST:.*]] = vector.bitcast %[[LOADED]] : vector<40xi8> to vector<320xi1>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[CAST]] {offsets = [0], sizes = [288], strides = [1]} : vector<320xi1> to vector<288xi1>
+// CHECK: %[[EXTRACTED_COMMON:.*]] = vector.bitcast %[[EXTRACTED]] : vector<288xi1> to vector<6xi48>
+// CHECK: %[[SHUFFLED:.*]] = vector.shuffle %[[EXTRACTED_COMMON]], %[[EXTRACTED_COMMON]] [
+// CHECK-SAME-COUNT-16: 0,
+// CHECK-SAME-COUNT-16: 1,
+// CHECK-SAME-COUNT-16: 2,
+// CHECK-SAME-COUNT-16: 3,
+// CHECK-SAME-COUNT-16: 4,
+// CHECK-SAME-COUNT-16: 5,
+// CHECK-SAME: : vector<6xi48>, vector<6xi48>
+// CHECK: %[[MASKED:.*]] = arith.andi %[[SHUFFLED]], %[[MASK]] : vector<96xi48>
+// CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MASKED]], %[[SHIFT]] : vector<96xi48>
+// CHECK: %[[TRUNCATED:.*]] = arith.trunci %[[SHIFTED]] : vector<96xi48> to vector<96xi16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[TRUNCATED]] : vector<96xi16> to vector<2x3x96xi16>
+// CHECK: %[[CONVERTED:.*]] = arith.uitofp %[[BCAST]] : vector<2x3x96xi16> to vector<2x3x96xf32>
+
+ %v = vector.load %m[%idx] : !mst, !vst
+ %bi1 = vector.bitcast %v : !vst to !vst_i1
+ %ei1 = vector.extract_strided_slice %bi1 {offsets = [0], sizes = [288], strides = [1]} : !vst_i1 to !vqt_i1
+ %b = vector.bitcast %ei1 : !vqt_i1 to !vqt
+ %bb = vector.broadcast %b : !vqt to !bvqt
+ %be = arith.extui %bb : !bvqt to !bvtt
+ %bf = arith.uitofp %be : !bvtt to !bvct
+
+// CHECK: vector.extract %{{.*}}[0, 0] : vector<2x3x96xf32>
+// CHECK: vector.load %{{.*}} : memref<1234xf32>, vector<96xf32>
+// CHECK: arith.addf %{{.*}} : vector<96xf32>
+// CHECK: vector.store %{{.*}} : memref<1234xf32>, vector<96xf32>
+// CHECK: return
+ %f = vector.extract %bf[0, 0] : !bvct
+ %vf = vector.load %mf[%idx] : !mtt, !vct
+ %res = arith.addf %vf, %f : !vct
+ vector.store %res, %mf[%idx] : !mtt, !vct
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.rewrite_narrow_types
+ } : !transform.any_op
+}
+
+// -----
+
+
+/// The rewritten sequence is about 8x faster on skylake-avx512 according to llvm-mca.
+
+!source_element_t = i8
+!quantized_element_t = i4
+!expanded_element_t = i32
+!final_element_t = f16
+!mst = memref<1234x!source_element_t>
+!vst = vector<16x!source_element_t>
+!vqt = vector<32x!quantized_element_t>
+!vct = vector<32x!final_element_t>
+!vtt = vector<32x!expanded_element_t>
+!mtt = memref<1234x!final_element_t>
+
+// CHECK-LABEL: @f2(
+// CHECK-SAME: %[[A:[a-z0-9]+]]: memref<1234xi8>
+// CHECK-SAME: %[[IDX:[a-z0-9]+]]: index
+// CHECK-SAME: %[[B:[a-z0-9]+]]: memref<1234xf16>
+func.func @f2(%m: !mst, %idx : index, %mf: !mtt) {
+
+// CHECK: %[[MASK:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-4: 5, 240, 3840, 61440, 983040, 15728640, 251658240, -268435456
+// CHECK-SAME: : vector<32xi32>
+// CHECK: %[[SHIFT:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-4: 0, 4, 8, 12, 16, 20, 24, 28
+// CHECK-SAME: : vector<32xi32>
+// CHECK: %[[LOADED:.*]] = vector.load %[[A]][%[[IDX]]] : memref<1234xi8>, vector<16xi8>
+// CHECK: %[[CAST:.*]] = vector.bitcast %[[LOADED]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[SHUFFLED:.*]] = vector.shuffle %[[CAST]], %[[CAST]] [
+// CHECK-SAME-COUNT-8: 0
+// CHECK-SAME-COUNT-8: 1
+// CHECK-SAME-COUNT-8: 2
+// CHECK-SAME-COUNT-8: 3
+// CHECK-SAME: : vector<4xi32>, vector<4xi32>
+// CHECK: %[[MASKED:.*]] = arith.andi %[[SHUFFLED]], %[[MASK]] : vector<32xi32>
+// CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MASKED]], %[[SHIFT]] : vector<32xi32>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[SHIFTED]] : vector<32xi32> to vector<32xi16>
+// CHECK: %[[CONVERTED:.*]] = arith.uitofp %[[TRUNC]] : vector<32xi16> to vector<32xf16>
+ %v = vector.load %m[%idx] : !mst, !vst
+ %b = vector.bitcast %v : !vst to !vqt
+ %be = arith.extui %b : !vqt to !vtt
+ %bf = arith.uitofp %be : !vtt to !vct
+ %vf = vector.load %mf[%idx] : !mtt, !vct
+ %res = arith.addf %vf, %bf : !vct
+ vector.store %res, %mf[%idx] : !mtt, !vct
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.rewrite_narrow_types
+ } : !transform.any_op
+}
+
+
+// -----
+
+/// The rewritten sequence is about 6x faster on skylake-avx512 according to llvm-mca.
+
+!source_element_t = i8
+!quantized_element_t = i5
+!expanded_element_t = i64
+!final_element_t = f32
+!mst = memref<1234x!source_element_t>
+!vst = vector<10x!source_element_t>
+!vqt = vector<16x!quantized_element_t>
+!bvqt = vector<2x3x16x!quantized_element_t>
+!bvtt = vector<2x3x16x!expanded_element_t>
+!vct = vector<16x!final_element_t>
+!bvct = vector<2x3x16x!final_element_t>
+!mtt = memref<1234x!final_element_t>
+
+// CHECK-LABEL: @f3(
+// CHECK-SAME: %[[A:[a-z0-9]+]]: memref<1234xi8>
+// CHECK-SAME: %[[IDX:[a-z0-9]+]]: index
+// CHECK-SAME: %[[B:[a-z0-9]+]]: memref<1234xf32>
+func.func @f3(%m: !mst, %idx : index, %mf: !mtt) {
+
+// CHECK: %[[MASK:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-2: 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368
+// CHECK-SAME: : vector<16xi40>
+// CHECK: %[[SHIFT:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-2: 0, 5, 10, 15, 20, 25, 30, 35
+// CHECK-SAME: : vector<16xi40>
+// CHECK: %[[LOADED:.*]] = vector.load %[[A]][%[[IDX]]] : memref<1234xi8>, vector<10xi8>
+// CHECK: %[[CAST:.*]] = vector.bitcast %[[LOADED]] : vector<10xi8> to vector<2xi40>
+// CHECK: %[[SHUFFLED:.*]] = vector.shuffle %[[CAST]], %[[CAST]] [
+// CHECK-SAME-COUNT-8: 0
+// CHECK-SAME-COUNT-8: 1
+// CHECK-SAME: : vector<2xi40>, vector<2xi40>
+// CHECK: %[[MASKED:.*]] = arith.andi %[[SHUFFLED]], %[[MASK]] : vector<16xi40>
+// CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MASKED]], %[[SHIFT]] : vector<16xi40>
+// CHECK: %[[EXT:.*]] = arith.extui %[[SHIFTED]] : vector<16xi40> to vector<16xi64>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXT]] : vector<16xi64> to vector<2x3x16xi64>
+// CHECK: %[[CONVERTED:.*]] = arith.uitofp %[[BCAST]] : vector<2x3x16xi64> to vector<2x3x16xf32>
+ %v = vector.load %m[%idx] : !mst, !vst
+ %b = vector.bitcast %v : !vst to !vqt
+ %bb = vector.broadcast %b : !vqt to !bvqt
+ %be = arith.extui %bb : !bvqt to !bvtt
+ %bf = arith.uitofp %be : !bvtt to !bvct
+ %f = vector.extract %bf[0, 0] : !bvct
+ %vf = vector.load %mf[%idx] : !mtt, !vct
+ %res = arith.addf %vf, %f : !vct
+ vector.store %res, %mf[%idx] : !mtt, !vct
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.rewrite_narrow_types
+ } : !transform.any_op
+}