[flang-commits] [llvm] [openmp] [flang] [mlir] [lldb] [mlir][Vector] Add patterns for efficient i4 -> i8 conversion emulation (PR #79494)

Diego Caballero via flang-commits flang-commits at lists.llvm.org
Thu Jan 25 16:57:38 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/79494

>From b8fb65dd1e65c36cfb2104e5f35179faa6011552 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 25 Jan 2024 02:39:14 +0000
Subject: [PATCH] [mlir][Vector] Add patterns for efficient i4 -> i8 conversion
 emulation

This PR adds new patterns to improve the generated vector code for the
emulation of any conversion that has to go through an i4 -> i8 type
extension (only signed extensions are supported for now). This will
impact any i4 -> i8/i16/i32/i64 signed extension as well as sitofp i4
-> f8/f16/f32/f64.

The asm code generated for the supported cases is significantly better
after this PR for both x86 and aarch64.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 176 ++++++++++++++++--
 .../Vector/vector-rewrite-narrow-types.mlir   |  33 ++++
 2 files changed, 189 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index a4a72754ccc250..8abd34fd246224 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -642,9 +642,9 @@ struct BitCastRewriter {
 
   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
 
-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);
 
   /// Precompute the metadata for the rewrite.
   SmallVector<BitCastRewriter::Metadata>
@@ -652,9 +652,9 @@ struct BitCastRewriter {
 
   /// Rewrite one step of the sequence:
   ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);
 
 private:
   /// Underlying enumerator that encodes the provenance of the bits in the each
@@ -719,21 +719,57 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
   LDBG("\n" << enumerator.sourceElementRanges);
 }
 
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!preconditionType || preconditionType.getRank() != 1 ||
+      preconditionType.isScalable())
     return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
   if (resultBitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
 }
 
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that source and destination element types meet the precondition for
+/// the supported aligned conversion cases. Alignment means that either the
+/// source element bitwidth is a multiple of the destination element bitwidth
+/// or the other way around.
+///
+/// NOTE: This function assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  if (!srcType || !dstType)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  unsigned byteBitwidth = 8;
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +811,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   return result;
 }
 
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
   // Create vector.shuffle from the metadata.
   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
       loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +846,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
   return runningResult;
 }
 
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +903,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
       return failure();
 
     // Perform the rewrite.
@@ -839,8 +913,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -893,7 +967,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
       return failure();
 
@@ -904,8 +978,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -923,6 +997,62 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     return success();
   }
 };
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.extsi %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.shli %0, 4 : vector<4xi8>
+///        %2 = arith.shrsi %1, 4 : vector<4xi8>
+///        %3 = arith.shrsi %0, 4 : vector<4xi8>
+///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///             : vector<4xi8>, vector<4xi8>
+///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///
+///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.shli %0, 4 : vector<4xi8>
+///        %2 = arith.shrsi %1, 4 : vector<4xi8>
+///        %3 = arith.shrsi %0, 4 : vector<4xi8>
+///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///             : vector<4xi8>, vector<4xi8>
+///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -944,4 +1074,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
+
+  // Patterns for aligned cases. We set higher priority as they are expected to
+  // generate better performance for aligned cases.
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index a600fa955b1700..c4fbb4c219b917 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi(
+func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_base_case(
+func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK-NOT: arith.extsi
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: shuffle
+  // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the flang-commits mailing list