[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #65774)

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 8 09:09:45 PDT 2023


https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/65774:

…cast) expansion

This revision adds a rewrite for sequences of vector ext(maybe_broadcast(bitcast)) to use a more efficient sequence of vector operations comprising shuffles, shifts and bitwise logical ops. The rewrite uses an intermediate bitwidth equal to the lcm (least common multiple) of the element bitwidths of the source and result types of `bitCastOp`. This intermediate type may be smaller or greater than the desired elemental type of the extOp, in which case appropriate ext or trunc operations are inserted. The rewrite fails if the intermediate type is greater than 64 or if the involved vector types fail to meet basic divisibility requirements. In other words, this rewrite does not handle partial vector boundaries and leaves this part of the heavy-lifting to LLVM.

>From fde457483108f47dacfbe001b665bdd277b3e7e8 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH] [mlir][Vector] Add a rewrite pattern for better low-precision
 ext(bitcast) expansion

This revision adds a rewrite for sequences of vector ext(maybe_broadcast(bitcast)) to use a more efficient
sequence of vector operations comprising shuffles, shifts and bitwise
logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of
the element bitwidths of the source and result types of `bitCastOp`. This
intermediate type may be smaller or greater than the desired elemental type
of the extOp, in which case appropriate ext or trunc operations are
inserted. The rewrite fails if the intermediate type is greater than 64 or
if the involved vector types fail to meet basic divisibility requirements.
In other words, this rewrite does not handle partial vector boundaries and
leaves this part of the heavy-lifting to LLVM.
---
 .../Vector/TransformOps/VectorTransformOps.td |  13 +
 .../Vector/Transforms/VectorRewritePatterns.h |  22 +-
 mlir/include/mlir/IR/BuiltinTypes.h           |  10 +
 .../TransformOps/VectorTransformOps.cpp       |   5 +
 .../Transforms/VectorEmulateNarrowType.cpp    | 279 +++++++++++++++++-
 mlir/test/Dialect/LLVM/transform-e2e.mlir     |  21 --
 .../Vector/vector-rewrite-narrow-types.mlir   | 119 ++++++++
 7 files changed, 443 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2b8c95a94257e6c..d1cef91f8e27525 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -281,6 +281,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.rewrite_narrow_types",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector narrow rewrite operations should be applied.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.split_transfer_full_partial",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c644090d8c78cd0..20c33921f9de24e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 } // namespace arith
 
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 
 /// Patterns that remove redundant vector broadcasts.
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+                                         PatternBenefit benefit = 1);
 
 /// Populate `patterns` with the following patterns.
 ///
@@ -301,6 +302,25 @@ void populateVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns);
 
+/// Rewrite vector ext(maybe_broadcast(bitcast)) to use a more efficient
+/// sequence of vector operations comprising shuffles, shifts and bitwise
+/// logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of
+/// the element bitwidths of the source and result types of `bitCastOp`. This
+/// intermediate type may be smaller or greater than the desired elemental type
+/// of the extOp, in which case appropriate ext or trunc operations are
+/// inserted. The rewrite fails if the intermediate type is greater than 64 or
+/// if the involved vector types fail to meet basic divisibility requirements.
+/// In other words, this rewrite does not handle partial vector boundaries and
+/// leaves this part of the heavy-lifting to LLVM.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
+/// Appends patterns for rewriting vector operations over narrow types with
+/// ops over wider types.
+void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..9df5548cd5d939c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,6 +357,16 @@ class VectorType::Builder {
     return *this;
   }
 
+  /// Set a dim in shape @pos to val.
+  Builder &setDim(unsigned pos, int64_t val) {
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
+    storage[pos] = val;
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
   operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 94f19e59669eafd..0fdeded436a9773 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -154,6 +154,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorNarrowTypeRewritePatterns(patterns);
+}
+
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b2b7bfc5e4437c1..27a595e7434b9b6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,7 +7,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -15,13 +14,29 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
-#include <cassert>
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "vector-narrow-type-emulation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -155,6 +170,255 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+/// Create a vector of bit masks, each selecting bits `idx .. idx + step - 1`
+/// for `idx = 0, step, 2 * step, ...`, and repeat it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x3, 0xc, 0x3, 0xc, 0x3, 0xc].
+static SmallVector<Attribute> computeExtOfBitCastMasks(MLIRContext *ctx,
+                                                       int64_t bitwidth,
+                                                       int64_t step,
+                                                       int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpMasks;
+  tmpMasks.reserve(bitwidth / step);
+  // Create a vector of bit masks: `idx .. idx + step - 1`.
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    LDBG("Mask bits " << idx << " .. " << idx + step - 1 << " out of "
+                      << bitwidth);
+    IntegerAttr mask = IntegerAttr::get(
+        interimIntType, llvm::APInt::getBitsSet(bitwidth, idx, idx + step));
+    tmpMasks.push_back(mask);
+  }
+  // Replicate the vector of bit masks to the desired size.
+  SmallVector<Attribute> masks;
+  masks.reserve(numOccurrences * tmpMasks.size());
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(masks, tmpMasks);
+  return masks;
+}
+
+/// Create a vector of bit shifts `0, step, 2 * step, ..., bitwidth - step` and
+/// repeat it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x2, 0x0, 0x2, 0x0, 0x2].
+static SmallVector<Attribute>
+computeExtOfBitCastShifts(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                          int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpShifts;
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    IntegerAttr shift = IntegerAttr::get(interimIntType, idx);
+    tmpShifts.push_back(shift);
+  }
+  SmallVector<Attribute> shifts;
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shifts, tmpShifts);
+  return shifts;
+}
+
+/// Create a vector of shuffle indices in which each index `idx` in
+/// `0 .. numOccurrences - 1` is repeated `bitwidth / step` consecutive times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x0, 0x1, 0x1, 0x2, 0x2].
+static SmallVector<int64_t>
+computeExtOfBitCastShuffles(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                            int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  SmallVector<int64_t> shuffles;
+  int64_t n = floorDiv(bitwidth, step);
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shuffles, SmallVector<int64_t>(n, idx));
+  return shuffles;
+}
+
+/// Compute the bitwidth of the intermediate vector's integer element type, or
+/// return 0 if no suitable bitwidth exists. The bitwidth must:
+///   1. divide sourceBitWidth * mostMinorSourceDim
+///   2. be at most 64.
+static int64_t computeExtOfBitCastBitWidth(int64_t sourceBitWidth,
+                                           int64_t mostMinorSourceDim,
+                                           int64_t targetBitWidth) {
+  for (int64_t mult : {32, 16, 8, 4, 2, 1}) {
+    int64_t interimBitWidth =
+        std::lcm(mult, std::lcm(sourceBitWidth, targetBitWidth));
+    if (interimBitWidth > 64)
+      continue;
+    if ((sourceBitWidth * mostMinorSourceDim) % interimBitWidth != 0)
+      continue;
+    return interimBitWidth;
+  }
+  return 0;
+}
+
+FailureOr<Value>
+mlir::vector::rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                  vector::BitCastOp bitCastOp,
+                                  vector::BroadcastOp maybeBroadcastOp) {
+  assert(
+      (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
+      "unsupported op");
+
+  // The bitcast op is the load-bearing part, capture the source and bitCast
+  // types as well as bitwidth and most minor dimension.
+  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  VectorType bitCastVectorType = bitCastOp.getResultVectorType();
+  int64_t targetBitWidth = bitCastVectorType.getElementTypeBitWidth();
+  LDBG("bitCastVectorType: " << bitCastVectorType);
+
+  int64_t interimBitWidth = computeExtOfBitCastBitWidth(
+      sourceBitWidth, mostMinorSourceDim, targetBitWidth);
+  LDBG("interimBitWidth: " << interimBitWidth);
+  if (!interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "heuristic could not find a reasonable interim bitwidth");
+  }
+  if (sourceBitWidth == interimBitWidth || targetBitWidth == interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "interim bitwidth is equal to source or target, nothing to do");
+  }
+
+  int64_t interimMostMinorDim =
+      sourceBitWidth * mostMinorSourceDim / interimBitWidth;
+  LDBG("interimMostMinorDim: " << interimMostMinorDim);
+
+  Location loc = extOp->getLoc();
+  MLIRContext *ctx = extOp->getContext();
+
+  VectorType interimVectorType =
+      VectorType::Builder(sourceVectorType)
+          .setDim(sourceVectorType.getRank() - 1, interimMostMinorDim)
+          .setElementType(IntegerType::get(ctx, interimBitWidth));
+  LDBG("interimVectorType: " << interimVectorType);
+
+  IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
+  VectorType vt =
+      VectorType::Builder(bitCastVectorType).setElementType(interimIntType);
+
+  // Rewrite the original bitcast to the interim vector type and shuffle to
+  // broadcast to the desired size.
+  auto newBitCastOp = rewriter.create<vector::BitCastOp>(loc, interimVectorType,
+                                                         bitCastOp.getSource());
+  SmallVector<int64_t> shuffles = computeExtOfBitCastShuffles(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(loc, newBitCastOp,
+                                                      newBitCastOp, shuffles);
+  LDBG("shuffle: " << shuffleOp);
+
+  // Compute the constants for masking.
+  SmallVector<Attribute> masks = computeExtOfBitCastMasks(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, masks));
+  LDBG("maskConstant: " << maskConstantOp);
+  auto andOp = rewriter.create<arith::AndIOp>(loc, shuffleOp, maskConstantOp);
+  LDBG("andOp: " << andOp);
+
+  // Preserve the intermediate type: this may have serious consequences on the
+  // backend's ability to generate efficient vector operations.
+  // For instance on x86, converting to f16 without going through i32 has severe
+  // performance implications.
+  // As a consequence, this pattern must preserve the original behavior.
+  VectorType resultType = cast<VectorType>(extOp->getResultTypes().front());
+  Type resultElementType = getElementTypeOrSelf(resultType);
+  SmallVector<Attribute> shifts = computeExtOfBitCastShifts(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shiftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, shifts));
+  LDBG("shiftConstant: " << shiftConstantOp);
+  Value newResult =
+      TypeSwitch<Operation *, Value>(extOp)
+          .template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRSIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtSIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRUIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtUIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .Default([&](Operation *op) {
+            llvm_unreachable("unexpected op type");
+            return nullptr;
+          });
+
+  if (maybeBroadcastOp) {
+    newResult =
+        rewriter.create<vector::BroadcastOp>(loc, resultType, newResult);
+  }
+
+  return newResult;
+}
+
+namespace {
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+  using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(ExtOpType extOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = dyn_cast<VectorType>(extOp.getType());
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(extOp, "not a vector type");
+
+    int64_t elementalBitWidth = resultTy.getElementTypeBitWidth();
+    if (elementalBitWidth & (elementalBitWidth - 1)) {
+      return rewriter.notifyMatchFailure(
+          extOp, "result bitwidth must be a power of 2");
+    }
+
+    // Provision for a potential broadcast op that will be rewritten late.
+    auto maybeBroadcastOp =
+        extOp.getIn().template getDefiningOp<vector::BroadcastOp>();
+
+    // The source must be a bitcast op.
+    auto bitCastOp =
+        maybeBroadcastOp
+            ? maybeBroadcastOp.getSource()
+                  .template getDefiningOp<vector::BitCastOp>()
+            : extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+    if (!bitCastOp)
+      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+    // Try to rewrite.
+    FailureOr<Value> result =
+        rewriteExtOfBitCast(rewriter, extOp, bitCastOp, maybeBroadcastOp);
+    if (failed(result))
+      return failure();
+
+    rewriter.replaceOp(extOp, *result);
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
@@ -167,3 +431,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
   patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
 }
+
+void vector::populateVectorNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RewriteExtOfBitCast<arith::ExtSIOp>,
+               RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
+                                                    benefit);
+}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 777de75b1a47acc..2cb753a3d7fb8f3 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -29,33 +29,12 @@ transform.sequence failures(propagate) {
   // lowering TD macros.
   transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_permutation_patterns
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_shape_cast
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
   } : !transform.any_op
 }
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
new file mode 100644
index 000000000000000..aed729b0e5d0d9f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// Inspect generated assembly and llvm-mca stats with:
+// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 | llvm-mca -mcpu=skylake-avx512 | less
+// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mtriple=aarch64-apple-ios --function-sections | llvm-mca -march=arm64 -mcpu=cyclone | less
+
+!source_element_t = i8
+!quantized_element_t = i3
+!expanded_element_t = i32
+!final_element_t = f16
+!mst = memref<1234x!source_element_t>
+!vst = vector<40x!source_element_t>
+!vst_i1 = vector<320xi1>
+!vqt_i1 = vector<288xi1>
+!vqt = vector<96x!quantized_element_t>
+!bvqt = vector<2x3x96x!quantized_element_t>
+!bvtt = vector<2x3x96x!expanded_element_t>
+!vct = vector<96x!final_element_t>
+!bvct = vector<2x3x96x!final_element_t>
+!mtt = memref<1234x!final_element_t>
+
+
+func.func @func(%m: !mst, %idx : index, %mf: !mtt) {
+//     %cst = arith.constant dense<[7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832, 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832, 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440
+// 512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832, 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832, 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -3518
+// 4372088832, 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832]> : vector<96xi48>                             
+//     %cst_0 = arith.constant dense<[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45]> : vector<96xi48> 
+//     %0 = vector.load %arg0[%arg1] : memref<1234xi8>, vector<40xi8>                                                                                                                                                         
+//     %1 = vector.bitcast %0 : vector<40xi8> to vector<320xi1>                                                                                                                                                               
+//     %2 = vector.extract_strided_slice %1 {offsets = [0], sizes = [288], strides = [1]} : vector<320xi1> to vector<288xi1>                                                                                                  
+//     %3 = vector.bitcast %2 : vector<288xi1> to vector<6xi48>                                                                                                                                                               
+//     %4 = vector.shuffle %3, %3 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5] : vector<6xi48>, vector<6xi48>                                                               
+//     %5 = arith.andi %4, %cst : vector<96xi48>                                                                                                                                                                              
+//     %6 = arith.shrui %5, %cst_0 : vector<96xi48>                                                                                                                                                                           
+//     %7 = arith.trunci %6 : vector<96xi48> to vector<96xi32>                                                                                                                                                                
+//     %8 = vector.broadcast %7 : vector<96xi32> to vector<2x3x96xi32>                                                                                                                                                        
+//     %9 = arith.uitofp %8 : vector<2x3x96xi32> to vector<2x3x96xf16>                                                                                                                                                        
+//     %10 = vector.extract %9[0, 0] : vector<2x3x96xf16>                                                                                                                                                                     
+//     %11 = vector.load %arg2[%arg1] : memref<1234xf16>, vector<96xf16>                                                                                                                                               
+//     %12 = arith.addf %11, %10 : vector<96xf16>
+//     vector.store %12, %arg2[%arg1] : memref<1234xf16>, vector<96xf16>
+//     return
+
+  %v = vector.load %m[%idx] : !mst, !vst
+  %bi1 = vector.bitcast %v : !vst to !vst_i1
+  %ei1 = vector.extract_strided_slice %bi1 {offsets = [0], sizes = [288], strides = [1]} : !vst_i1 to !vqt_i1
+  %b = vector.bitcast %ei1 : !vqt_i1 to !vqt
+  %bb = vector.broadcast %b : !vqt to !bvqt
+  %be = arith.extui %bb : !bvqt to !bvtt
+  %bf = arith.uitofp %be : !bvtt to !bvct
+  %f = vector.extract %bf[0, 0] : !bvct
+  %vf = vector.load %mf[%idx] : !mtt, !vct
+  %res = arith.addf %vf, %f : !vct
+  vector.store %res, %mf[%idx] : !mtt, !vct
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.rewrite_narrow_types
+  } : !transform.any_op
+}
+
+// -----
+
+!source_element_t = i8
+!quantized_element_t = i5
+!expanded_element_t = i8
+!final_element_t = f32
+!mst = memref<1234x!source_element_t>
+!vst = vector<40x!source_element_t>
+!vqt = vector<64x!quantized_element_t>
+!bvqt = vector<2x3x64x!quantized_element_t>
+!bvtt = vector<2x3x64x!expanded_element_t>
+!vct = vector<64x!final_element_t>
+!bvct = vector<2x3x64x!final_element_t>
+!mtt = memref<1234x!final_element_t>
+
+func.func @func(%m: !mst, %idx : index, %mf: !mtt) {
+  //   %cst = arith.constant dense<[31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368, 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368, 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368, 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368, 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368, 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368, 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368, 31, 992, 31744, 1015808, 32505856, 1040187392, 33285996544, -34359738368]> : vector<64xi40>
+  // %cst_0 = arith.constant dense<[0, 5, 10, 15, 20, 25, 30, 35, 0, 5, 10, 15, 20, 25, 30, 35, 0, 5, 10, 15, 20, 25, 30, 35, 0, 5, 10, 15, 20, 25, 30, 35, 0, 5, 10, 15, 20, 25, 30, 35, 0, 5, 10, 15, 20, 25, 30, 35, 0, 5, 10, 15, 20, 25, 30, 35, 0, 5, 10, 15, 20, 25, 30, 35]> : vector<64xi40>
+  // %0 = vector.load %arg0[%arg1] : memref<1234xi8>, vector<40xi8>
+  // %1 = vector.bitcast %0 : vector<40xi8> to vector<8xi40>
+  // %2 = vector.shuffle %1, %1 [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7] : vector<8xi40>, vector<8xi40>
+  // %3 = arith.andi %2, %cst : vector<64xi40>
+  // %4 = arith.shrui %3, %cst_0 : vector<64xi40>
+  // %5 = arith.trunci %4 : vector<64xi40> to vector<64xi8>
+  // %6 = vector.broadcast %5 : vector<64xi8> to vector<2x3x64xi8>
+  // %7 = arith.uitofp %6 : vector<2x3x64xi8> to vector<2x3x64xf32>
+  // %8 = vector.extract %7[0, 0] : vector<2x3x64xf32>
+  // %9 = vector.load %arg2[%arg1] : memref<1234xf32>, vector<64xf32>
+  // %10 = arith.addf %9, %8 : vector<64xf32>
+  // vector.store %10, %arg2[%arg1] : memref<1234xf32>, vector<64xf32>
+
+  %v = vector.load %m[%idx] : !mst, !vst
+  %b = vector.bitcast %v : !vst to !vqt
+  %bb = vector.broadcast %b : !vqt to !bvqt
+  %be = arith.extui %bb : !bvqt to !bvtt
+  %bf = arith.uitofp %be : !bvtt to !bvct
+  %f = vector.extract %bf[0, 0] : !bvct
+  %vf = vector.load %mf[%idx] : !mtt, !vct
+  %res = arith.addf %vf, %f : !vct
+  vector.store %res, %mf[%idx] : !mtt, !vct
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.rewrite_narrow_types
+  } : !transform.any_op
+}



More information about the Mlir-commits mailing list