[Mlir-commits] [mlir] [WIP] using splat shifts (PR #87121)

Finn Plummer llvmlistbot at llvm.org
Fri Apr 12 11:36:59 PDT 2024


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/87121

>From e29a6dafa80938e78a5e7d37bfa33d50a829a266 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Thu, 28 Mar 2024 09:11:04 -0700
Subject: [PATCH 1/2] [WIP] using splat shifts

---
 .../Vector/TransformOps/VectorTransformOps.td |   2 +
 .../Vector/Transforms/VectorRewritePatterns.h |   3 +-
 .../TransformOps/VectorTransformOps.cpp       |   3 +-
 .../Transforms/VectorEmulateNarrowType.cpp    | 184 +++++++++++++++++-
 .../Vector/vector-rewrite-narrow-types.mlir   |  76 +++++++-
 .../Vector/CPU/test-rewrite-narrow-types.mlir |  70 ++++++-
 6 files changed, 325 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c39444..81d66e7b6ab179 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -318,6 +318,12 @@ def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     Warning: these patterns currently only work for little endian targets.
+
+    `max_cycle_len` enables a variant of the rewrite that uses splat-constant
+    shift amounts when the precomputed shift amounts repeat with a cycle
+    length in (1, max_cycle_len]. The default of 0 disables this variant.
   }];
 
+  let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$max_cycle_len);
+
   let assemblyFormat = "attr-dict";
 }
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 453fa73429dd1a..9cd3d80c441d95 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -376,7 +376,11 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
 /// ops over wider types.
 /// Warning: these patterns currently only work for little endian targets.
+/// `maxCycleLen` enables the splat-constant shift variant of the bitcast
+/// rewrite when the precomputed shift amounts repeat with a cycle length in
+/// (1, maxCycleLen]. The default of 0 disables it.
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
-                                             PatternBenefit benefit = 1);
+                                             PatternBenefit benefit = 1,
+                                             unsigned maxCycleLen = 0);
 
 /// Appends patterns for emulating a sub-byte vector transpose.
 void populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f71..e8c9033da72686 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -166,7 +166,8 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
 
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  populateVectorNarrowTypeRewritePatterns(patterns);
+  populateVectorNarrowTypeRewritePatterns(patterns, /*benefit=*/1,
+                                          getMaxCycleLen());
   populateVectorTransposeNarrowTypeRewritePatterns(patterns);
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc6f126aae4c87..ecd3879e5f1f9e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
 ///   [0] = {0, [0, 10)}, {1, [0, 5)}
 ///   [1] = {1, [5, 10)}, {2, [0, 10)}
+/// and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
+///   [0] = {0, [0, 4)}, {1, [0, 4)}
+///   [1] = {2, [0, 4)}, {3, [0, 4)}
+/// and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
+///   [0] = {0, [0, 4)}
+///   [1] = {0, [4, 8)}
+///   [2] = {1, [0, 4)}
+///   [3] = {1, [4, 8)}
 struct BitCastBitsEnumerator {
   BitCastBitsEnumerator(VectorType sourceVectorType,
                         VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
+///
+///
+/// When we consider the above algorithm to rewrite our vector.bitcast, we rely
+/// on using dynamic shift amounts for the left and right shifts. This can be
+/// inefficient on certain targets (RDNA GPUs) in contrast to a splat constant
+/// value. So when possible we can rewrite this as a combination of shifts with
+/// a constant splat value and then regroup the selected terms.
+///
+/// E.g. instead of:
+///    res = arith.shrui x [0, 4, 8, 0, 4, 8]
+/// use:
+///    y = arith.shrui x [0, 0, 0, 0, 0, 0]  (can be folded away)
+///    y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
+///    y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
+///    y3 = vector.shuffle y y1 [0, 7, 3, 10]
+///    res = vector.shuffle y3 y2 [0, 1, 6, 2, 3, 9]
+///
+/// This is possible when the precomputed shift amounts follow a cyclic
+/// pattern [x, y, z, ..., x, y, z, ...] whose cycle length, cycleLen,
+/// satisfies 1 < cycleLen < size(shiftAmounts) and cycleLen <= max_cycle_len,
+/// and the shuffles are of the form [0, 0, 0, ..., 1, 1, 1, ...]. This is a
+/// common pattern in (de)quantization, e.g. i24 -> 3xi8 or 3xi8 -> i24. The
+/// modified algorithm follows the same 2 steps as above, then proceeds with:
+///
+///   2. for each element x of the cycle of rightShiftAmounts, create a shrui
+///   with a splat constant of x.
+///   3. repeat 2. with the respective leftShiftAmounts.
+///   4. construct a chain of vector.shuffles that reconstructs the result
+///   from the per-element shifts.
 struct BitCastRewriter {
   /// Helper metadata struct to hold the static quantities for the rewrite.
   struct Metadata {
@@ -656,10 +693,25 @@ struct BitCastRewriter {
                            Value initialValue, Value runningResult,
                            const BitCastRewriter::Metadata &metadata);
 
+  /// Rewrite one step of the sequence using splat-constant shift amounts.
+  Value splatRewriteStep(PatternRewriter &rewriter, Location loc,
+                         Value initialValue, Value runningResult,
+                         const BitCastRewriter::Metadata &metadata);
+
+  /// Whether precomputeMetadata found a shift cycle of length in
+  /// (1, maxCycleLen], i.e. whether splatRewriteStep may be used.
+  bool useSplatStep(unsigned maxCycleLen) const {
+    return 1 < cycleLen && cycleLen <= maxCycleLen;
+  }
+
 private:
-  /// Underlying enumerator that encodes the provenance of the bits in the each
+  /// Underlying enumerator that encodes the provenance of the bits in each
   /// element of the result vector.
   BitCastBitsEnumerator enumerator;
+
+  /// Shift-amount cycle length set by precomputeMetadata; splatRewriteStep is
+  /// only applicable when it is > 1 (it is 0 when no usable cycle was found).
+  int64_t cycleLen = 0;
 };
 
 } // namespace
@@ -775,8 +827,40 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
   return success();
 }
 
+/// Returns true if `xs` repeats its first `cycleLen` elements (cycleLen > 0).
+template <class T>
+static bool isCyclic(const SmallVectorImpl<T> &xs, int64_t cycleLen) {
+  assert(cycleLen > 0 && "expected a positive cycle length");
+  for (int64_t idx = cycleLen, n = xs.size(); idx < n; ++idx)
+    if (xs[idx] != xs[idx % cycleLen])
+      return false;
+  return true;
+}
+
+static SmallVector<int64_t> constructShuffles(int64_t inputSize,
+                                              int64_t numCycles,
+                                              int64_t cycleLen, int64_t idx) {
+  // If idx == 1, then the first operand of the shuffle will be the mask which
+  // will have the original size. So we need to step through the mask with a
+  // stride of cycleSize.
+  // When idx > 1, then the first operand will be the size of (idx * cycleSize)
+  // and so we take the first idx elements of the input and then append the
+  // strided mask value.
+  int64_t inputStride = idx == 1 ? cycleLen : idx;
+
+  SmallVector<int64_t> shuffles;
+  for (int64_t cycle = 0; cycle < numCycles; cycle++) {
+    for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++) {
+      shuffles.push_back(cycle * inputStride + inputIdx);
+    }
+    shuffles.push_back(inputSize + cycle * cycleLen + idx);
+  }
+  return shuffles;
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+  bool cyclicShifts = true;
   SmallVector<BitCastRewriter::Metadata> result;
   for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
        shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +895,85 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
           IntegerAttr::get(shuffledElementType, shiftLeft));
     }
 
+    // Compute a potential cycle length as the number of leading shuffle
+    // entries that read the same source element as the first one.
+    int64_t entryCycleLen = 1;
+    for (int64_t n = shuffles.size(); entryCycleLen < n; entryCycleLen++)
+      if (shuffles[entryCycleLen] != shuffles[0])
+        break;
+
+    // The splat rewrite uses the mask and shift amounts of cycle position
+    // `i % cycleLen` for result element `i`, which must read source element
+    // `i / cycleLen`. Require exactly that, with one cycle length shared by
+    // all entries (cycleLen holds the previous entry's length here).
+    bool blockShuffles = (int64_t)shuffles.size() % entryCycleLen == 0;
+    for (int64_t i = 0, n = shuffles.size(); blockShuffles && i < n; ++i)
+      blockShuffles = shuffles[i] == i / entryCycleLen;
+
+    cyclicShifts = cyclicShifts &&
+                   (shuffleIdx == 0 || entryCycleLen == cycleLen) &&
+                   entryCycleLen < (int64_t)shuffles.size() && blockShuffles &&
+                   isCyclic(masks, entryCycleLen) &&
+                   isCyclic(shiftRightAmounts, entryCycleLen) &&
+                   isCyclic(shiftLeftAmounts, entryCycleLen);
+    cycleLen = entryCycleLen;
+
     result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
   }
+
+  cycleLen = cyclicShifts ? cycleLen : 0;
+  return result;
+}
+
+/// Splat-constant variant of genericRewriteStep; only used when useSplatStep
+/// returns true for the metadata computed by precomputeMetadata.
+Value BitCastRewriter::splatRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
+
+  // Initial result will be the Shifted Mask which will have the shuffles size.
+  int64_t inputSize = metadata.shuffles.size();
+  int64_t numCycles = inputSize / cycleLen;
+
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, initialValue, initialValue, metadata.shuffles);
+
+  // Intersect with the mask.
+  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+  auto constOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+  Value result;
+  for (int64_t idx = 0; idx < cycleLen; idx++) {
+    auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(shuffledVectorType,
+                                    metadata.shiftRightAmounts[idx]));
+    Value shiftedRight =
+        rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+    auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(shuffledVectorType,
+                                    metadata.shiftLeftAmounts[idx]));
+    Value shiftedLeft =
+        rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+    if (result) {
+      SmallVector<int64_t> shuffles =
+          constructShuffles(inputSize, numCycles, cycleLen, idx);
+      result = rewriter.create<vector::ShuffleOp>(loc, result, shiftedLeft,
+                                                  shuffles);
+
+      // After the first shuffle in the chain, the size of the input result will
+      // grow as we append more shuffles together to reconstruct the
+      // shuffledVectorType size. Each iteration they will retain numCycles more
+      // elements.
+      inputSize = numCycles * (idx + 1);
+    } else {
+      result = shiftedLeft;
+    }
+  }
+
   return result;
 }
 
@@ -939,6 +1086,13 @@ namespace {
 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  /// `maxCycleLen` bounds the shift-cycle length for which the splat-constant
+  /// shift rewrite (BitCastRewriter::splatRewriteStep) is used; 0 disables it.
+  RewriteBitCastOfTruncI(MLIRContext *context, PatternBenefit benefit,
+                         unsigned maxCycleLen)
+      : OpRewritePattern<vector::BitCastOp>(context, benefit),
+        maxCycleLen{maxCycleLen} {}
+
   LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                 PatternRewriter &rewriter) const override {
     // The source must be a trunc op.
@@ -961,8 +1113,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.genericRewriteStep(
-          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
+      runningResult =
+          bcr.useSplatStep(maxCycleLen)
+              ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                     runningResult, metadata)
+              : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+                                       truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -986,6 +1142,9 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
 
     return success();
   }
+
+private:
+  unsigned maxCycleLen = 0; // Max splat-shift cycle length; 0 disables it.
 };
 } // namespace
 
@@ -1001,8 +1160,12 @@ template <typename ExtOpType>
 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
   using OpRewritePattern<ExtOpType>::OpRewritePattern;
 
-  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
-      : OpRewritePattern<ExtOpType>(context, benefit) {}
+  /// `maxCycleLen` bounds the shift-cycle length for which the splat-constant
+  /// shift rewrite (BitCastRewriter::splatRewriteStep) is used; 0 disables it.
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit,
+                      unsigned maxCycleLen)
+      : OpRewritePattern<ExtOpType>(context, benefit),
+        maxCycleLen{maxCycleLen} {}
 
   LogicalResult matchAndRewrite(ExtOpType extOp,
                                 PatternRewriter &rewriter) const override {
@@ -1026,8 +1187,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.genericRewriteStep(
-          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
+      runningResult =
+          bcr.useSplatStep(maxCycleLen)
+              ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), sourceValue,
+                                     runningResult, metadata)
+              : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+                                       sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -1044,6 +1209,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
     return success();
   }
+
+private:
+  unsigned maxCycleLen = 0; // Max splat-shift cycle length; 0 disables it.
 };
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
@@ -1222,10 +1390,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, PatternBenefit benefit, unsigned maxCycleLen) {
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
-                                                    benefit);
+                                                    benefit, maxCycleLen);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c9..8dd361b8a3573e 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -146,6 +146,42 @@ func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
   return %1 : vector<8xi6>
 }
 
+// CHECK-LABEL: func.func @ftrunc_splat1(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi16>) -> vector<1xi8> {
+func.func @ftrunc_splat1(%a: vector<2xi16>) -> vector<1xi8> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<1xi16>
+  // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<1xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0] : vector<2xi16>, vector<2xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<1xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1] : vector<2xi16>, vector<2xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<1xi16>
+  // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<1xi16>
+  // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<1xi16>
+  // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<1xi16> to vector<1xi8>
+  // CHECK: return %[[RES]] : vector<1xi8>
+  %0 = arith.trunci %a : vector<2xi16> to vector<2xi4>
+  %1 = vector.bitcast %0 : vector<2xi4> to vector<1xi8>
+  return %1 : vector<1xi8>
+}
+
+// CHECK-LABEL: func.func @ftrunc_splat2(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<4xi16>) -> vector<2xi8> {
+func.func @ftrunc_splat2(%a: vector<4xi16>) -> vector<2xi8> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<2xi16>
+  // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<2xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 2] : vector<4xi16>, vector<4xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<2xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 3] : vector<4xi16>, vector<4xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<2xi16>
+  // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<2xi16>
+  // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<2xi16>
+  // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<2xi16> to vector<2xi8>
+  // CHECK: return %[[RES]] : vector<2xi8>
+  %0 = arith.trunci %a : vector<4xi16> to vector<4xi4>
+  %1 = vector.bitcast %0 : vector<4xi4> to vector<2xi8>
+  return %1 : vector<2xi8>
+}
+
 // CHECK-LABEL: func.func @f1ext(
 //  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
 func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
@@ -193,6 +229,44 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @fext_splat1(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
+func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
+  // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<4xi8>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<4xi8>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST]] : vector<4xi8>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 5, 2, 7] : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
+  // return %[[RES]] : vector<4xi16>
+  %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+  %1 = arith.extui %0 : vector<4xi4> to vector<4xi16>
+  return %1 : vector<4xi16>
+}
+
+// CHECK-LABEL: func.func @fext_splat2(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
+func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, 240, 3840, -4096, 15, 240, 3840, -4096, 15, 240, 3840, -4096]> : vector<12xi16>
+  // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<12xi16>
+  // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<12xi16>
+  // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<12xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] : vector<3xi16>, vector<3xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<12xi16>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST0]] : vector<12xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 13, 4, 17, 8, 21] : vector<12xi16>, vector<12xi16>
+  // CHECK: %[[SHR1:.*]] = arith.shrui %[[A0]], %[[SHR_CST1]] : vector<12xi16>
+  // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 8, 2, 3, 12, 4, 5, 16] : vector<6xi16>, vector<12xi16>
+  // CHECK: %[[SHR2:.*]] = arith.shrui %[[A0]], %[[SHR_CST2]] : vector<12xi16>
+  // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 12, 3, 4, 5, 16, 6, 7, 8, 20] : vector<9xi16>, vector<12xi16>
+  // CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
+  // CHEKC: return %[[RES]] : vector<12xi32>
+  %0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>
+  %1 = arith.extui %0 : vector<12xi4> to vector<12xi32>
+  return %1 : vector<12xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_extsi(
 func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
@@ -330,7 +404,7 @@ module attributes {transform.with_named_sequence} {
         : (!transform.any_op) -> !transform.any_op
 
     transform.apply_patterns to %f {
-      transform.apply_patterns.vector.rewrite_narrow_types
+      transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
     } : !transform.any_op
     transform.yield
   }
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index a0b39a2b68f438..a7e13ea1a79c4f 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -124,6 +124,40 @@ func.func @f3(%v: vector<2xi48>) {
   return
 }
 
+// Prints the bits of %v, starting with element 0.
+func.func @print_as_i1_2xi8(%v : vector<2xi8>) {
+  %bitsi16 = vector.bitcast %v : vector<2xi8> to vector<16xi1>
+  vector.print %bitsi16 : vector<16xi1>
+  return
+}
+
+// Prints the bits of %v, starting with element 0.
+func.func @print_as_i1_4xi4(%v : vector<4xi4>) {
+  %bitsi16 = vector.bitcast %v : vector<4xi4> to vector<16xi1>
+  vector.print %bitsi16 : vector<16xi1>
+  return
+}
+
+// The bitcast splits each i8 into two i4 elements, so the shift amounts
+// repeat with a cycle of length 2.
+func.func @ftrunc_splat(%v: vector<2xi24>) {
+  %trunc = arith.trunci %v : vector<2xi24> to vector<2xi8>
+  func.call @print_as_i1_2xi8(%trunc) : (vector<2xi8>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1, 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0, 0, 0, 1, 1 )
+
+  %bitcast = vector.bitcast %trunc : vector<2xi8> to vector<4xi4>
+  func.call @print_as_i1_4xi4(%bitcast) : (vector<4xi4>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1 )
+
+  return
+}
+
 func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
   %bitsi40 = vector.bitcast %v : vector<8xi5> to vector<40xi1>
   vector.print %bitsi40 : vector<40xi1>
@@ -164,6 +194,35 @@ func.func @fext(%a: vector<5xi8>) {
   return
 }
 
+// Prints the bits of %v, starting with element 0.
+func.func @print_as_i1_4xi8(%v : vector<4xi8>) {
+  %bitsi32 = vector.bitcast %v : vector<4xi8> to vector<32xi1>
+  vector.print %bitsi32 : vector<32xi1>
+  return
+}
+
+// The bitcast splits each i8 into two i4 elements, so the shift amounts
+// repeat with a cycle of length 2.
+func.func @fext_splat(%a: vector<2xi8>) {
+  %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+  func.call @print_as_i1_4xi4(%0) : (vector<4xi4>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1 )
+
+  %1 = arith.extui %0 : vector<4xi4> to vector<4xi8>
+  func.call @print_as_i1_4xi8(%1) : (vector<4xi8>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1, 0, 0, 0, 0 )
+
+  return
+}
+
 func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
   %c0 = arith.constant 0: index
   %mask = vector.constant_mask [3] : vector<6xi1>
@@ -190,9 +246,19 @@ func.func @entry() {
   func.call @f3(%v3) : (vector<2xi48>) -> ()
 
   %v4 = arith.constant dense<[
+    0xafe, 0xbc3
+  ]> : vector<2xi24>
+  func.call @ftrunc_splat(%v4) : (vector<2xi24>) -> ()
+
+  %v5 = arith.constant dense<[
     0xef, 0xee, 0xed, 0xec, 0xeb
   ]> : vector<5xi8>
-  func.call @fext(%v4) : (vector<5xi8>) -> ()
+  func.call @fext(%v5) : (vector<5xi8>) -> ()
+
+  %v6 = arith.constant dense<[
+    0xfe, 0xc3
+  ]> : vector<2xi8>
+  func.call @fext_splat(%v6) : (vector<2xi8>) -> ()
 
   // Set up memory.
   %c0 = arith.constant 0: index
@@ -218,7 +284,7 @@ module attributes {transform.with_named_sequence} {
         : (!transform.any_op) -> !transform.any_op
 
     transform.apply_patterns to %f {
-      transform.apply_patterns.vector.rewrite_narrow_types
+      transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
     } : !transform.any_op
     transform.yield
   }

>From 461b96911b198c05afe675441d7e7a94fd39b876 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 12 Apr 2024 11:19:38 -0700
Subject: [PATCH 2/2] Split up the and operation over the cycles

  - allows each shift to be computed only on the elements of its cycle
    position instead of on the entire shuffled input
  - removes the initial vector.shuffle, as we can operate directly on the
    input and merge the per-position results afterwards
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 61 ++++++-------------
 .../Vector/vector-rewrite-narrow-types.mlir   | 46 ++++++++------
 2 files changed, 44 insertions(+), 63 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ecd3879e5f1f9e..425431b7fa3439 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -837,23 +837,20 @@ static bool isCyclic(SmallVector<T> xs, int64_t cycleLen) {
   return true;
 }
 
-static SmallVector<int64_t> constructShuffles(int64_t inputSize,
-                                              int64_t numCycles,
+/// Returns the vector.shuffle mask that merges a first operand holding, for
+/// each of the `numCycles` cycles in order, the elements of cycle positions
+/// [0, idx) with a second operand holding the `numCycles` elements of cycle
+/// position `idx`. E.g. numCycles = 2, idx = 2 yields [0, 1, 4, 2, 3, 5].
+/// `cycleLen` is currently unused.
+static SmallVector<int64_t> constructShuffles(int64_t numCycles,
                                               int64_t cycleLen, int64_t idx) {
-  // If idx == 1, then the first operand of the shuffle will be the mask which
-  // will have the original size. So we need to step through the mask with a
-  // stride of cycleSize.
-  // When idx > 1, then the first operand will be the size of (idx * cycleSize)
-  // and so we take the first idx elements of the input and then append the
-  // strided mask value.
-  int64_t inputStride = idx == 1 ? cycleLen : idx;
-
   SmallVector<int64_t> shuffles;
+  shuffles.reserve(numCycles * (idx + 1));
   for (int64_t cycle = 0; cycle < numCycles; cycle++) {
     for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++) {
-      shuffles.push_back(cycle * inputStride + inputIdx);
+      shuffles.push_back(cycle * idx + inputIdx);
     }
-    shuffles.push_back(inputSize + cycle * cycleLen + idx);
+    shuffles.push_back(numCycles * idx + cycle);
   }
   return shuffles;
 }
@@ -917,47 +908,46 @@ Value BitCastRewriter::splatRewriteStep(
     PatternRewriter &rewriter, Location loc, Value initialValue,
     Value runningResult, const BitCastRewriter::Metadata &metadata) {
 
-  // Initial result will be the Shifted Mask which will have the shuffles size.
-  int64_t inputSize = metadata.shuffles.size();
-  int64_t numCycles = inputSize / cycleLen;
-
-  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
-      loc, initialValue, initialValue, metadata.shuffles);
-
-  // Intersect with the mask.
-  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
-  auto constOp = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
-  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
-
+  // Operate directly on the source vector: this assumes result element `i`
+  // reads source element `i / cycleLen` (shuffles [0, .., 0, 1, .., 1, ...]),
+  // so the source holds exactly one element per cycle.
+  int64_t numCycles = metadata.shuffles.size() / cycleLen;
+  auto vectorType = cast<VectorType>(initialValue.getType());
+  assert(vectorType.getNumElements() == numCycles &&
+         "expected one source element per cycle");
   Value result;
   for (int64_t idx = 0; idx < cycleLen; idx++) {
+    // Keep only the bits of cycle position `idx` of every source element.
+    auto constOp = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(vectorType, metadata.masks[idx]));
+    Value andValue = rewriter.create<arith::AndIOp>(loc, initialValue, constOp);
+
     auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
-        loc, SplatElementsAttr::get(shuffledVectorType,
-                                    metadata.shiftRightAmounts[idx]));
+        loc,
+        SplatElementsAttr::get(vectorType, metadata.shiftRightAmounts[idx]));
     Value shiftedRight =
         rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
 
     auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
-        loc, SplatElementsAttr::get(shuffledVectorType,
-                                    metadata.shiftLeftAmounts[idx]));
+        loc,
+        SplatElementsAttr::get(vectorType, metadata.shiftLeftAmounts[idx]));
     Value shiftedLeft =
         rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
 
-    if (result) {
-      SmallVector<int64_t> shuffles =
-          constructShuffles(inputSize, numCycles, cycleLen, idx);
-      result = rewriter.create<vector::ShuffleOp>(loc, result, shiftedLeft,
-                                                  shuffles);
-
-      // After the first shuffle in the chain, the size of the input result will
-      // grow as we append more shuffles together to reconstruct the
-      // shuffledVectorType size. Each iteration they will retain numCycles more
-      // elements.
-      inputSize = numCycles * (idx + 1);
-    } else {
-      result = shiftedLeft;
-    }
+    // Interleave the elements of position `idx` into the partial result,
+    // which holds positions [0, idx) of every cycle.
+    if (!result) {
+      result = shiftedLeft;
+      continue;
+    }
+    SmallVector<int64_t> shuffles = constructShuffles(numCycles, cycleLen, idx);
+    result = rewriter.create<vector::ShuffleOp>(loc, result, shiftedLeft,
+                                                shuffles);
   }
 
+  // Merge with the bits produced by the previous steps (the final `or` of the
+  // generic sequence).
+  if (runningResult)
+    result = rewriter.create<arith::OrIOp>(loc, runningResult, result);
+
   return result;
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8dd361b8a3573e..396a9e9ee2cb5b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -230,14 +230,15 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
 }
 
 // CHECK-LABEL: func.func @fext_splat1(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
+//  CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
 func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
-  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
-  // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<4xi8>
-  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
-  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<4xi8>
-  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST]] : vector<4xi8>
-  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 5, 2, 7] : vector<4xi8>, vector<4xi8>
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<2xi8>
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<-16> : vector<2xi8>
+  // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<2xi8>
+  // CHECK-DAG: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<2xi8>
+  // CHECK-DAG: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<2xi8>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST]] : vector<2xi8>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 2, 1, 3] : vector<2xi8>, vector<2xi8>
   // CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
-  // return %[[RES]] : vector<4xi16>
+  // CHECK: return %[[RES]] : vector<4xi16>
   %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
@@ -246,20 +247,25 @@ func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
 }
 
 // CHECK-LABEL: func.func @fext_splat2(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
+//  CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
 func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
-  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, 240, 3840, -4096, 15, 240, 3840, -4096, 15, 240, 3840, -4096]> : vector<12xi16>
-  // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<12xi16>
-  // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<12xi16>
-  // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<12xi16>
-  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] : vector<3xi16>, vector<3xi16>
-  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<12xi16>
-  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST0]] : vector<12xi16>
-  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 13, 4, 17, 8, 21] : vector<12xi16>, vector<12xi16>
-  // CHECK: %[[SHR1:.*]] = arith.shrui %[[A0]], %[[SHR_CST1]] : vector<12xi16>
-  // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 8, 2, 3, 12, 4, 5, 16] : vector<6xi16>, vector<12xi16>
-  // CHECK: %[[SHR2:.*]] = arith.shrui %[[A0]], %[[SHR_CST2]] : vector<12xi16>
-  // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 12, 3, 4, 5, 16, 6, 7, 8, 20] : vector<9xi16>, vector<12xi16>
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<3xi16>
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<240> : vector<3xi16>
+  // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<3840> : vector<3xi16>
+  // CHECK-DAG: %[[MASK3:.*]] = arith.constant dense<-4096> : vector<3xi16>
+  // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<3xi16>
+  // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<3xi16>
+  // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<3xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<3xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<3xi16>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST0]] : vector<3xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 3, 1, 4, 2, 5] : vector<3xi16>, vector<3xi16>
+  // CHECK: %[[A2:.*]] = arith.andi %[[ARG]], %[[MASK2]] : vector<3xi16>
+  // CHECK: %[[SHR1:.*]] = arith.shrui %[[A2]], %[[SHR_CST1]] : vector<3xi16>
+  // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 6, 2, 3, 7, 4, 5, 8] : vector<6xi16>, vector<3xi16>
+  // CHECK: %[[A3:.*]] = arith.andi %[[ARG]], %[[MASK3]] : vector<3xi16>
+  // CHECK: %[[SHR2:.*]] = arith.shrui %[[A3]], %[[SHR_CST2]] : vector<3xi16>
+  // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 9, 3, 4, 5, 10, 6, 7, 8, 11] : vector<9xi16>, vector<3xi16>
   // CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
-  // CHEKC: return %[[RES]] : vector<12xi32>
+  // CHECK: return %[[RES]] : vector<12xi32>
   %0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>



More information about the Mlir-commits mailing list