[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #65774)

Nicolas Vasilache llvmlistbot at llvm.org
Tue Sep 12 08:27:15 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/65774:

>From 11caf7cacb71874a78474bef25ecc27d68e642e2 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH] [mlir][Vector] Add a rewrite pattern for better low-precision
 ext(bitcast) expansion

This revision adds a rewrite that turns sequences of vector `ext(maybe_broadcast(bitcast))`
into a more efficient sequence of vector operations comprising `shuffle`, shift and
bitwise logical ops.

The rewrite uses an intermediate bitwidth equal to the lcm of the bitwidths of
the element types of the source and result types of `bitCastOp`. This
intermediate type may be smaller or greater than the desired elemental type of
the `ext`, in which case appropriate `ext` or `trunc` operations are inserted.

The rewrite fails if the intermediate bitwidth is greater than `64` or if the
involved vector types fail to meet basic divisibility requirements. In other
words, this rewrite does not handle partial vector boundaries and leaves
this part of the heavy lifting to LLVM.

In the future, it may be relevant to give control over the size of the intermediate type.
For now, empirical evidence shows that capping the intermediate bitwidth at `64` results
in much better assembly when piping through `llvm-mca`.
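
As an illustration (a hand-written sketch mirroring the `f2` test added below; SSA
names are illustrative and the exact shuffle indices and mask/shift constants are
elided), a sequence such as

  %b = vector.bitcast %v : vector<16xi8> to vector<32xi4>
  %e = arith.extui %b : vector<32xi4> to vector<32xi32>

is rewritten into a form along the lines of

  %cast    = vector.bitcast %v : vector<16xi8> to vector<4xi32>
  %shuffle = vector.shuffle %cast, %cast [0, 0, ..., 3, 3] : vector<4xi32>, vector<4xi32>
  %masked  = arith.andi %shuffle, %mask : vector<32xi32>
  %shifted = arith.shrui %masked, %shift : vector<32xi32>

where `%mask` and `%shift` are `arith.constant` vectors holding the per-lane nibble
masks and shift amounts, followed by an `ext` or `trunc` when the intermediate
element type differs from the desired result element type.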
---
 .../Vector/TransformOps/VectorTransformOps.td |  13 +
 .../Vector/Transforms/VectorRewritePatterns.h |  22 +-
 mlir/include/mlir/IR/BuiltinTypes.h           |  10 +
 .../TransformOps/VectorTransformOps.cpp       |   5 +
 .../Transforms/VectorEmulateNarrowType.cpp    | 280 +++++++++++++++++-
 mlir/test/Dialect/LLVM/transform-e2e.mlir     |  21 --
 .../Vector/vector-rewrite-narrow-types.mlir   | 205 +++++++++++++
 7 files changed, 530 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2b8c95a94257e6c..d1cef91f8e27525 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -281,6 +281,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.rewrite_narrow_types",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector narrow-type rewrite patterns should be applied.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.split_transfer_full_partial",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c644090d8c78cd0..20c33921f9de24e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 } // namespace arith
 
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 
 /// Patterns that remove redundant vector broadcasts.
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+                                         PatternBenefit benefit = 1);
 
 /// Populate `patterns` with the following patterns.
 ///
@@ -301,6 +302,25 @@ void populateVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns);
 
+/// Rewrite vector ext(maybe_broadcast(bitcast)) to use a more efficient
+/// sequence of vector operations comprising shuffles, shifts and bitwise
+/// logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of
+/// the bitwidths of the element types of the source and result types of
+/// `bitCastOp`. This intermediate type may be smaller or greater than the
+/// desired elemental type of the extOp, in which case appropriate ext or trunc
+/// operations are inserted. The rewrite fails if the intermediate bitwidth is
+/// greater than 64 or if the involved vector types fail to meet basic
+/// divisibility requirements. In other words, this rewrite does not handle
+/// partial vector boundaries and leaves this part of the heavy lifting to LLVM.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
+/// Appends patterns for rewriting vector operations over narrow types with
+/// ops over wider types.
+void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..9df5548cd5d939c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,6 +357,16 @@ class VectorType::Builder {
     return *this;
   }
 
+  /// Set the dimension of the shape at position `pos` to `val`.
+  Builder &setDim(unsigned pos, int64_t val) {
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
+    storage[pos] = val;
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
   operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 94f19e59669eafd..0fdeded436a9773 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -154,6 +154,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorNarrowTypeRewritePatterns(patterns);
+}
+
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b2b7bfc5e4437c1..d10994f3709e390 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,7 +7,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -15,13 +14,29 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
-#include <cassert>
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "vector-narrow-type-emulation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -155,6 +170,256 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+/// Create a vector of bit masks, where the i-th mask selects bits
+/// `i * step .. (i + 1) * step - 1`, and replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x3, 0xc, 0x3, 0xc, 0x3, 0xc].
+static SmallVector<Attribute> computeExtOfBitCastMasks(MLIRContext *ctx,
+                                                       int64_t bitwidth,
+                                                       int64_t step,
+                                                       int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpMasks;
+  tmpMasks.reserve(bitwidth / step);
+  // Create a vector of bit masks: `idx .. idx + step - 1`.
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    LDBG("Mask bits " << idx << " .. " << idx + step - 1 << " out of "
+                      << bitwidth);
+    IntegerAttr mask = IntegerAttr::get(
+        interimIntType, llvm::APInt::getBitsSet(bitwidth, idx, idx + step));
+    tmpMasks.push_back(mask);
+  }
+  // Replicate the vector of bit masks to the desired size.
+  SmallVector<Attribute> masks;
+  masks.reserve(numOccurrences * tmpMasks.size());
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(masks, tmpMasks);
+  return masks;
+}
+
+/// Create a vector of shift amounts `0, step, 2 * step, ..., bitwidth - step`
+/// and replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x2, 0x0, 0x2, 0x0, 0x2].
+static SmallVector<Attribute>
+computeExtOfBitCastShifts(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                          int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpShifts;
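+  // Create the shift amounts: 0, step, ..., bitwidth - step.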
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    IntegerAttr shift = IntegerAttr::get(interimIntType, idx);
+    tmpShifts.push_back(shift);
+  }
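+  // Replicate the vector of shift amounts to the desired size.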
+  SmallVector<Attribute> shifts;
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shifts, tmpShifts);
+  return shifts;
+}
+
+/// Create a vector of shuffle indices in which each index `idx` in
+/// `0 .. numOccurrences - 1` is repeated `bitwidth / step` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x0, 0x1, 0x1, 0x2, 0x2].
+static SmallVector<int64_t>
+computeExtOfBitCastShuffles(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                            int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  SmallVector<int64_t> shuffles;
+  int64_t n = floorDiv(bitwidth, step);
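+  // Each interim element of `bitwidth` bits yields `n` narrow elements; repeat
+  // its index `n` times so that they all read from the same interim element.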
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shuffles, SmallVector<int64_t>(n, idx));
+  return shuffles;
+}
+
+/// Compute the bitwidth of the integer element type of the intermediate vector
+/// type. This bitwidth must:
+///   1. not exceed 64 (TODO: in the future we may want target-specific
+///   control);
+///   2. divide sourceBitWidth * mostMinorSourceDim evenly.
+/// Returns 0 when no such bitwidth exists.
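+/// For example, for a bitcast from vector<10xi8> to vector<16xi5>
+/// (sourceBitWidth = 8, mostMinorSourceDim = 10, targetBitWidth = 5), the
+/// candidates lcm(32, lcm(8, 5)) = 160 and lcm(16, lcm(8, 5)) = 80 exceed 64,
+/// so the chosen bitwidth is lcm(8, lcm(8, 5)) = 40, which divides
+/// 8 * 10 = 80 evenly (this matches the `f3` test below).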
+static int64_t computeExtOfBitCastBitWidth(int64_t sourceBitWidth,
+                                           int64_t mostMinorSourceDim,
+                                           int64_t targetBitWidth) {
+  for (int64_t mult : {32, 16, 8, 4, 2, 1}) {
+    int64_t interimBitWidth =
+        std::lcm(mult, std::lcm(sourceBitWidth, targetBitWidth));
+    if (interimBitWidth > 64)
+      continue;
+    if ((sourceBitWidth * mostMinorSourceDim) % interimBitWidth != 0)
+      continue;
+    return interimBitWidth;
+  }
+  return 0;
+}
+
+FailureOr<Value>
+mlir::vector::rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                  vector::BitCastOp bitCastOp,
+                                  vector::BroadcastOp maybeBroadcastOp) {
+  assert(
+      (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
+      "unsupported op");
+
+  // The bitcast op is the load-bearing part, capture the source and bitCast
+  // types as well as bitwidth and most minor dimension.
+  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  VectorType bitCastVectorType = bitCastOp.getResultVectorType();
+  int64_t targetBitWidth = bitCastVectorType.getElementTypeBitWidth();
+  LDBG("bitCastVectorType: " << bitCastVectorType);
+
+  int64_t interimBitWidth = computeExtOfBitCastBitWidth(
+      sourceBitWidth, mostMinorSourceDim, targetBitWidth);
+  LDBG("interimBitWidth: " << interimBitWidth);
+  if (!interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "heuristic could not find a reasonable interim bitwidth");
+  }
+  if (sourceBitWidth == interimBitWidth || targetBitWidth == interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "interim bitwidth is equal to source or target, nothing to do");
+  }
+
+  int64_t interimMostMinorDim =
+      sourceBitWidth * mostMinorSourceDim / interimBitWidth;
+  LDBG("interimMostMinorDim: " << interimMostMinorDim);
+
+  Location loc = extOp->getLoc();
+  MLIRContext *ctx = extOp->getContext();
+
+  VectorType interimVectorType =
+      VectorType::Builder(sourceVectorType)
+          .setDim(sourceVectorType.getRank() - 1, interimMostMinorDim)
+          .setElementType(IntegerType::get(ctx, interimBitWidth));
+  LDBG("interimVectorType: " << interimVectorType);
+
+  IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
+  VectorType vt =
+      VectorType::Builder(bitCastVectorType).setElementType(interimIntType);
+
+  // Rewrite the original bitcast to the interim vector type and shuffle to
+  // broadcast to the desired size.
+  auto newBitCastOp = rewriter.create<vector::BitCastOp>(loc, interimVectorType,
+                                                         bitCastOp.getSource());
+  SmallVector<int64_t> shuffles = computeExtOfBitCastShuffles(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(loc, newBitCastOp,
+                                                      newBitCastOp, shuffles);
+  LDBG("shuffle: " << shuffleOp);
+
+  // Compute the constants for masking.
+  SmallVector<Attribute> masks = computeExtOfBitCastMasks(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, masks));
+  LDBG("maskConstant: " << maskConstantOp);
+  auto andOp = rewriter.create<arith::AndIOp>(loc, shuffleOp, maskConstantOp);
+  LDBG("andOp: " << andOp);
+
+  // Preserve the intermediate type: dropping it can have serious consequences
+  // for the backend's ability to generate efficient vector operations.
+  // For instance on x86, converting to f16 without going through i32 has severe
+  // performance implications.
+  // As a consequence, this pattern must preserve the original behavior.
+  VectorType resultType = cast<VectorType>(extOp->getResultTypes().front());
+  Type resultElementType = getElementTypeOrSelf(resultType);
+  SmallVector<Attribute> shifts = computeExtOfBitCastShifts(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shiftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, shifts));
+  LDBG("shiftConstant: " << shiftConstantOp);
+  Value newResult =
+      TypeSwitch<Operation *, Value>(extOp)
+          .template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRSIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtSIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRUIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtUIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .Default([&](Operation *op) {
+            llvm_unreachable("unexpected op type");
+            return nullptr;
+          });
+
+  if (maybeBroadcastOp) {
+    newResult =
+        rewriter.create<vector::BroadcastOp>(loc, resultType, newResult);
+  }
+
+  return newResult;
+}
+
+namespace {
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+  using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(ExtOpType extOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = dyn_cast<VectorType>(extOp.getType());
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(extOp, "not a vector type");
+
+    int64_t elementalBitWidth = resultTy.getElementTypeBitWidth();
+    if (elementalBitWidth & (elementalBitWidth - 1)) {
+      return rewriter.notifyMatchFailure(
+          extOp, "result bitwidth must be a power of 2");
+    }
+
+    // Provision for a potential broadcast op that will be rewritten later.
+    auto maybeBroadcastOp =
+        extOp.getIn().template getDefiningOp<vector::BroadcastOp>();
+
+    // The source must be a bitcast op.
+    auto bitCastOp =
+        maybeBroadcastOp
+            ? maybeBroadcastOp.getSource()
+                  .template getDefiningOp<vector::BitCastOp>()
+            : extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+    if (!bitCastOp)
+      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+    // Try to rewrite.
+    FailureOr<Value> result =
+        rewriteExtOfBitCast(rewriter, extOp, bitCastOp, maybeBroadcastOp);
+    if (failed(result))
+      return failure();
+
+    rewriter.replaceOp(extOp, *result);
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
@@ -167,3 +432,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
   patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
 }
+
+void vector::populateVectorNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RewriteExtOfBitCast<arith::ExtSIOp>,
+               RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
+                                                    benefit);
+}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 777de75b1a47acc..2cb753a3d7fb8f3 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -29,33 +29,12 @@ transform.sequence failures(propagate) {
   // lowering TD macros.
   transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_permutation_patterns
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_shape_cast
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
   } : !transform.any_op
 }
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
new file mode 100644
index 000000000000000..0550381bf2e2e16
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -0,0 +1,205 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+/// Inspect generated assembly and llvm-mca stats
+/// =============================================
+/// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 | llvm-mca -mcpu=skylake-avx512 | less
+
+/// The rewritten sequence is about 10x faster on skylake-avx512 according to llvm-mca.
+
+!source_element_t = i8
+!quantized_element_t = i3
+!expanded_element_t = i16
+!final_element_t = f32
+!mst = memref<1234x!source_element_t>
+!vst = vector<40x!source_element_t>
+!vst_i1 = vector<320xi1>
+!vqt_i1 = vector<288xi1>
+!vqt = vector<96x!quantized_element_t>
+!bvqt = vector<2x3x96x!quantized_element_t>
+!bvtt = vector<2x3x96x!expanded_element_t>
+!vct = vector<96x!final_element_t>
+!bvct = vector<2x3x96x!final_element_t>
+!mtt = memref<1234x!final_element_t>
+
+// CHECK-LABEL: @f1(
+//  CHECK-SAME: %[[A:[a-z0-9]+]]: memref<1234xi8>
+//  CHECK-SAME: %[[IDX:[a-z0-9]+]]: index
+//  CHECK-SAME: %[[B:[a-z0-9]+]]: memref<1234xf32>
+func.func @f1(%m: !mst, %idx : index, %mf: !mtt) {
+
+//              CHECK: %[[MASK:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-6: 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832
+//         CHECK-SAME: ]> : vector<96xi48>                             
+//              CHECK: %[[SHIFT:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-6: 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45
+//         CHECK-SAME: ]> : vector<96xi48>
+//              CHECK: %[[LOADED:.*]] = vector.load %[[A]][%[[IDX]]] : memref<1234xi8>, vector<40xi8>                                                                    
+//              CHECK: %[[CAST:.*]] = vector.bitcast %[[LOADED]] : vector<40xi8> to vector<320xi1>
+//              CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[CAST]] {offsets = [0], sizes = [288], strides = [1]} : vector<320xi1> to vector<288xi1>
+//              CHECK: %[[EXTRACTED_COMMON:.*]] = vector.bitcast %[[EXTRACTED]] : vector<288xi1> to vector<6xi48>
+//              CHECK: %[[SHUFFLED:.*]] = vector.shuffle %[[EXTRACTED_COMMON]], %[[EXTRACTED_COMMON]] [
+// CHECK-SAME-COUNT-16: 0,
+// CHECK-SAME-COUNT-16: 1,
+// CHECK-SAME-COUNT-16: 2,
+// CHECK-SAME-COUNT-16: 3,
+// CHECK-SAME-COUNT-16: 4,
+// CHECK-SAME-COUNT-16: 5,
+//          CHECK-SAME: : vector<6xi48>, vector<6xi48>
+//              CHECK: %[[MASKED:.*]] = arith.andi %[[SHUFFLED]], %[[MASK]] : vector<96xi48>
+//              CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MASKED]], %[[SHIFT]] : vector<96xi48>
+//              CHECK: %[[TRUNCATED:.*]] = arith.trunci %[[SHIFTED]] : vector<96xi48> to vector<96xi16>
+//              CHECK: %[[BCAST:.*]] = vector.broadcast %[[TRUNCATED]] : vector<96xi16> to vector<2x3x96xi16>
+//              CHECK: %[[CONVERTED:.*]] = arith.uitofp %[[BCAST]] : vector<2x3x96xi16> to vector<2x3x96xf32>
+
+  %v = vector.load %m[%idx] : !mst, !vst
+  %bi1 = vector.bitcast %v : !vst to !vst_i1
+  %ei1 = vector.extract_strided_slice %bi1 {offsets = [0], sizes = [288], strides = [1]} : !vst_i1 to !vqt_i1
+  %b = vector.bitcast %ei1 : !vqt_i1 to !vqt
+  %bb = vector.broadcast %b : !vqt to !bvqt
+  %be = arith.extui %bb : !bvqt to !bvtt
+  %bf = arith.uitofp %be : !bvtt to !bvct
+  
+// CHECK: vector.extract %{{.*}}[0, 0] : vector<2x3x96xf32>
+// CHECK: vector.load %{{.*}} : memref<1234xf32>, vector<96xf32>
+// CHECK: arith.addf %{{.*}} : vector<96xf32>
+// CHECK: vector.store %{{.*}} : memref<1234xf32>, vector<96xf32>
+// CHECK: return
+  %f = vector.extract %bf[0, 0] : !bvct
+  %vf = vector.load %mf[%idx] : !mtt, !vct
+  %res = arith.addf %vf, %f : !vct
+  vector.store %res, %mf[%idx] : !mtt, !vct
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.rewrite_narrow_types
+  } : !transform.any_op
+}
+
+// -----
+
+
+/// The rewritten sequence is about 8x faster on skylake-avx512 according to llvm-mca.
+
+!source_element_t = i8
+!quantized_element_t = i4
+!expanded_element_t = i32
+!final_element_t = f16
+!mst = memref<1234x!source_element_t>
+!vst = vector<16x!source_element_t>
+!vqt = vector<32x!quantized_element_t>
+!vct = vector<32x!final_element_t>
+!vtt = vector<32x!expanded_element_t>
+!mtt = memref<1234x!final_element_t>
+
+// CHECK-LABEL: @f2(
+//  CHECK-SAME: %[[A:[a-z0-9]+]]: memref<1234xi8>
+//  CHECK-SAME: %[[IDX:[a-z0-9]+]]: index
+//  CHECK-SAME: %[[B:[a-z0-9]+]]: memref<1234xf16>
+func.func @f2(%m: !mst, %idx : index, %mf: !mtt) {
+
+//              CHECK: %[[MASK:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-4: 15, 240, 3840, 61440, 983040, 15728640, 251658240, -268435456
+//         CHECK-SAME: : vector<32xi32>  
+//              CHECK: %[[SHIFT:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-4: 0, 4, 8, 12, 16, 20, 24, 28
+//         CHECK-SAME: : vector<32xi32>
+//              CHECK: %[[LOADED:.*]] = vector.load %[[A]][%[[IDX]]] : memref<1234xi8>, vector<16xi8>
+//              CHECK: %[[CAST:.*]] = vector.bitcast %[[LOADED]] : vector<16xi8> to vector<4xi32>
+//              CHECK: %[[SHUFFLED:.*]] = vector.shuffle %[[CAST]], %[[CAST]] [
+// CHECK-SAME-COUNT-8: 0
+// CHECK-SAME-COUNT-8: 1
+// CHECK-SAME-COUNT-8: 2
+// CHECK-SAME-COUNT-8: 3
+//         CHECK-SAME: : vector<4xi32>, vector<4xi32>
+//              CHECK: %[[MASKED:.*]] = arith.andi %[[SHUFFLED]], %[[MASK]] : vector<32xi32>
+//              CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MASKED]], %[[SHIFT]] : vector<32xi32>
+//              CHECK: %[[TRUNC:.*]] = arith.trunci %[[SHIFTED]] : vector<32xi32> to vector<32xi16>
+//              CHECK: %[[CONVERTED:.*]] = arith.uitofp %[[TRUNC]] : vector<32xi16> to vector<32xf16>
+  %v = vector.load %m[%idx] : !mst, !vst
+  %b = vector.bitcast %v : !vst to !vqt
+  %be = arith.extui %b : !vqt to !vtt
+  %bf = arith.uitofp %be : !vtt to !vct
+  %vf = vector.load %mf[%idx] : !mtt, !vct
+  %res = arith.addf %vf, %bf : !vct
+  vector.store %res, %mf[%idx] : !mtt, !vct
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.rewrite_narrow_types
+  } : !transform.any_op
+}
+
+
+// -----
+
+/// The rewritten sequence is about 6x faster on skylake-avx512 according to llvm-mca.
+
+!source_element_t = i8
+!quantized_element_t = i5
+!expanded_element_t = i64
+!final_element_t = f32
+!mst = memref<1234x!source_element_t>
+!vst = vector<10x!source_element_t>
+!vqt = vector<16x!quantized_element_t>
+!bvqt = vector<2x3x16x!quantized_element_t>
+!bvtt = vector<2x3x16x!expanded_element_t>
+!vct = vector<16x!final_element_t>
+!bvct = vector<2x3x16x!final_element_t>
+!mtt = memref<1234x!final_element_t>
+
+// CHECK-LABEL: @f3(
+//  CHECK-SAME: %[[A:[a-z0-9]+]]: memref<1234xi8>
+//  CHECK-SAME: %[[IDX:[a-z0-9]+]]: index
+//  CHECK-SAME: %[[B:[a-z0-9]+]]: memref<1234xf32>
+func.func @f3(%m: !mst, %idx : index, %mf: !mtt) {
+
+//              CHECK: %[[MASK:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-2: 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368
+//         CHECK-SAME: : vector<16xi40>  
+//              CHECK: %[[SHIFT:.*]] = arith.constant dense<[
+// CHECK-SAME-COUNT-2: 0, 5, 10, 15, 20, 25, 30, 35
+//         CHECK-SAME: : vector<16xi40>
+//              CHECK: %[[LOADED:.*]] = vector.load %[[A]][%[[IDX]]] : memref<1234xi8>, vector<10xi8>
+//              CHECK: %[[CAST:.*]] = vector.bitcast %[[LOADED]] : vector<10xi8> to vector<2xi40>
+//              CHECK: %[[SHUFFLED:.*]] = vector.shuffle %[[CAST]], %[[CAST]] [
+// CHECK-SAME-COUNT-8: 0
+// CHECK-SAME-COUNT-8: 1
+//         CHECK-SAME: : vector<2xi40>, vector<2xi40>
+//              CHECK: %[[MASKED:.*]] = arith.andi %[[SHUFFLED]], %[[MASK]] : vector<16xi40>
+//              CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MASKED]], %[[SHIFT]] : vector<16xi40>
+//              CHECK: %[[EXT:.*]] = arith.extui %[[SHIFTED]] : vector<16xi40> to vector<16xi64>
+//              CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXT]] : vector<16xi64> to vector<2x3x16xi64>
+//              CHECK: %[[CONVERTED:.*]] = arith.uitofp %[[BCAST]] : vector<2x3x16xi64> to vector<2x3x16xf32>
+  %v = vector.load %m[%idx] : !mst, !vst
+  %b = vector.bitcast %v : !vst to !vqt
+  %bb = vector.broadcast %b : !vqt to !bvqt
+  %be = arith.extui %bb : !bvqt to !bvtt
+  %bf = arith.uitofp %be : !bvtt to !bvct
+  %f = vector.extract %bf[0, 0] : !bvct
+  %vf = vector.load %mf[%idx] : !mtt, !vct
+  %res = arith.addf %vf, %f : !vct
+  vector.store %res, %mf[%idx] : !mtt, !vct
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.rewrite_narrow_types
+  } : !transform.any_op
+}



More information about the Mlir-commits mailing list