[Mlir-commits] [mlir] [WIP] using splat shifts (PR #87121)

Finn Plummer llvmlistbot at llvm.org
Fri Mar 29 15:50:19 PDT 2024


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/87121

None

>From f1a21fee566097588e9d137b932eff97fb4263db Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Thu, 28 Mar 2024 09:11:04 -0700
Subject: [PATCH] [WIP] using splat shifts

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 161 +++++++++++++++++-
 .../Vector/vector-rewrite-narrow-types.mlir   |  74 ++++++++
 .../Vector/CPU/test-rewrite-narrow-types.mlir |  68 +++++++-
 3 files changed, 298 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc6f126aae4c87..a8b6d4eb2d21c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
 ///   [0] = {0, [0, 10)}, {1, [0, 5)}
 ///   [1] = {1, [5, 10)}, {2, [0, 10)}
+/// and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
+///   [0] = {0, [0, 4)}, {1, [0, 4)}
+///   [1] = {2, [0, 4)}, {3, [0, 4)}
+/// and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
+///   [0] = {0, [0, 4)}
+///   [1] = {0, [4, 8)}
+///   [2] = {1, [0, 4)}
+///   [3] = {1, [4, 8)}
 struct BitCastBitsEnumerator {
   BitCastBitsEnumerator(VectorType sourceVectorType,
                         VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
+///
+/// The algorithm above relies on dynamic shift amounts for the left and right
+/// shifts. On some targets (e.g. RDNA GPUs) this is less efficient than
+/// shifting by a splat constant. When possible, we therefore rewrite the
+/// dynamic shifts as a combination of shifts by splat constants and then
+/// regroup the selected terms with shuffles.
+///
+/// E.g. instead of:
+///    res = arith.shrui x [0, 4, 8, 0, 4, 8]
+/// use:
+///    y  = arith.shrui x [0, 0, 0, 0, 0, 0]  (can be folded away)
+///    y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
+///    y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
+///    y3 = vector.shuffle y y1 [0, 7, 3, 10]
+///    res = vector.shuffle y3 y2 [0, 1, 6, 2, 3, 9]
+///
+/// This is possible when the precomputed shift amounts follow a cyclic
+/// pattern [x, y, z, ..., x, y, z, ...] whose cycle length, cycleLen,
+/// satisfies 1 < cycleLen < size(shiftAmounts), and the shuffles are of the
+/// form [0, 0, 0, ..., 1, 1, 1, ...]. This pattern is common in
+/// (de)quantization, e.g. i24 -> 3xi8 or 3xi8 -> i24. The modified algorithm
+/// performs the same shuffle and mask steps as above and then proceeds as
+/// follows:
+///
+///   1. for each element x in the cycle of rightShiftAmounts, create a shrui
+///      with a splat constant of x.
+///   2. repeat step 1 with the respective leftShiftAmounts.
+///   3. construct a chain of vector.shuffles that reconstructs the result
+///      from the chained shifts.
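+///
+/// As a concrete sketch (value names abbreviated, constants written inline;
+/// cf. the fext_splat1 test), rewriting
+///    %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+///    %1 = arith.extui %0 : vector<4xi4> to vector<4xi16>
+/// is expected to produce roughly:
+///    %s = vector.shuffle %a, %a [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
+///    %m = arith.andi %s, dense<[15, -16, 15, -16]> : vector<4xi8>
+///    %r = arith.shrui %m, dense<4> : vector<4xi8>
+///    %v = vector.shuffle %m, %r [0, 5, 2, 7] : vector<4xi8>, vector<4xi8>
+///    %1 = arith.extui %v : vector<4xi8> to vector<4xi16>
+/// where the shrui-by-0 / shli-by-0 steps have already been folded away.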
 struct BitCastRewriter {
   /// Helper metadata struct to hold the static quantities for the rewrite.
   struct Metadata {
@@ -656,10 +693,23 @@ struct BitCastRewriter {
                            Value initialValue, Value runningResult,
                            const BitCastRewriter::Metadata &metadata);
 
+  /// Rewrite one step of the sequence when a splat constant can be used for
+  /// the right and left shifts.
+  Value splatRewriteStep(PatternRewriter &rewriter, Location loc,
+                         Value initialValue, Value runningResult,
+                         const BitCastRewriter::Metadata &metadata);
+
+  /// Whether the splat-based rewrite step can be used, i.e. whether a cycle
+  /// was detected in the precomputed shift amounts.
+  bool useSplatStep() const { return cycleLen > 1; }
+
 private:
   /// Underlying enumerator that encodes the provenance of the bits in the each
   /// element of the result vector.
   BitCastBitsEnumerator enumerator;
+
+  /// Cycle length computed during precomputeMetadata. A cycleLen > 1 denotes
+  /// that the precomputed shift amounts are cyclic and splatRewriteStep can
+  /// be used.
+  int64_t cycleLen = 0;
 };
 
 } // namespace
@@ -775,8 +825,40 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
   return success();
 }
 
+// Check whether the vector repeats its first cycleLen elements throughout.
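+// For example, isCyclic([0, 4, 8, 0, 4, 8], 3) is true, while
+// isCyclic([0, 4, 8, 0, 4, 9], 3) is false.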
+template <class T>
+static bool isCyclic(const SmallVector<T> &xs, int64_t cycleLen) {
+  for (int64_t idx = cycleLen, n = xs.size(); idx < n; idx++) {
+    if (xs[idx] != xs[idx % cycleLen])
+      return false;
+  }
+  return true;
+}
+
+static SmallVector<int64_t> constructShuffles(int64_t inputSize,
+                                              int64_t numCycles,
+                                              int64_t cycleLen, int64_t idx) {
+  // When idx == 1, the first shuffle operand is the initial (idx == 0)
+  // shifted value, which still has the original shuffled size, so we step
+  // through it with a stride of cycleLen.
+  // When idx > 1, the first operand is the previously shuffled result of size
+  // (idx * numCycles), so we take the first idx elements of each cycle and
+  // then append the corresponding element of the newly shifted value.
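+  // For example, constructShuffles(6, 3, 4, 2) yields
+  // [0, 1, 8, 2, 3, 12, 4, 5, 16].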
+  int64_t inputStride = idx == 1 ? cycleLen : idx;
+
+  SmallVector<int64_t> shuffles;
+  for (int64_t cycle = 0; cycle < numCycles; cycle++) {
+    for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++) {
+      shuffles.push_back(cycle * inputStride + inputIdx);
+    }
+    shuffles.push_back(inputSize + cycle * cycleLen + idx);
+  }
+  return shuffles;
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+  bool cyclicShifts = true;
   SmallVector<BitCastRewriter::Metadata> result;
   for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
        shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +893,71 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
           IntegerAttr::get(shuffledElementType, shiftLeft));
     }
 
+    // Compute a potential cycle length by counting how many leading shuffle
+    // entries refer to the same source element.
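+    // E.g. for `vector.bitcast ... : vector<2xi8> to vector<4xi4>` the
+    // shuffles are [0, 0, 1, 1], so cycleLen = 2.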
+    cycleLen = 1;
+    for (int64_t n = shuffles.size(); cycleLen < n; cycleLen++)
+      if (shuffles[cycleLen] != shuffles[0])
+        break;
+
+    cyclicShifts = cyclicShifts && (cycleLen < (int64_t)shuffles.size()) &&
+                   isCyclic(shiftRightAmounts, cycleLen) &&
+                   isCyclic(shiftLeftAmounts, cycleLen);
+
     result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
   }
+
+  cycleLen = cyclicShifts ? cycleLen : 0;
+  return result;
+}
+
+Value BitCastRewriter::splatRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
+
+  // The initial result is the shuffled-and-masked value, which has
+  // metadata.shuffles.size() elements.
+  int64_t inputSize = metadata.shuffles.size();
+  int64_t numCycles = inputSize / cycleLen;
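+  // E.g. for a vector<3xi16> to vector<12xi4> bitcast, inputSize = 12 and
+  // cycleLen = 4, so numCycles = 3.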
+
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, initialValue, initialValue, metadata.shuffles);
+
+  // Intersect with the mask.
+  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+  auto constOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+  Value result;
+  for (int64_t idx = 0; idx < cycleLen; idx++) {
+    auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(shuffledVectorType,
+                                    metadata.shiftRightAmounts[idx]));
+    Value shiftedRight =
+        rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+    auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(shuffledVectorType,
+                                    metadata.shiftLeftAmounts[idx]));
+    Value shiftedLeft =
+        rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+    if (result) {
+      SmallVector<int64_t> shuffles =
+          constructShuffles(inputSize, numCycles, cycleLen, idx);
+      result = rewriter.create<vector::ShuffleOp>(loc, result, shiftedLeft,
+                                                  shuffles);
+
+      // After the first shuffle in the chain, the running result grows by
+      // numCycles elements per iteration until it reaches the size of
+      // shuffledVectorType.
+      inputSize = numCycles * (idx + 1);
+    } else {
+      result = shiftedLeft;
+    }
+  }
+
   return result;
 }
 
@@ -961,8 +1106,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.genericRewriteStep(
-          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
+      runningResult =
+          bcr.useSplatStep()
+              ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                     runningResult, metadata)
+              : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+                                       truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -1026,8 +1175,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.genericRewriteStep(
-          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
+      runningResult =
+          bcr.useSplatStep()
+              ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), sourceValue,
+                                     runningResult, metadata)
+              : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+                                       sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c9..897b1c703f22f3 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -146,6 +146,42 @@ func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
   return %1 : vector<8xi6>
 }
 
+// CHECK-LABEL: func.func @ftrunc_splat1(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi16>) -> vector<1xi8> {
+func.func @ftrunc_splat1(%a: vector<2xi16>) -> vector<1xi8> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<1xi16>
+  // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<1xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0] : vector<2xi16>, vector<2xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<1xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1] : vector<2xi16>, vector<2xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<1xi16>
+  // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<1xi16>
+  // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<1xi16>
+  // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<1xi16> to vector<1xi8>
+  // CHECK: return %[[RES]] : vector<1xi8>
+  %0 = arith.trunci %a : vector<2xi16> to vector<2xi4>
+  %1 = vector.bitcast %0 : vector<2xi4> to vector<1xi8>
+  return %1 : vector<1xi8>
+}
+
+// CHECK-LABEL: func.func @ftrunc_splat2(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<4xi16>) -> vector<2xi8> {
+func.func @ftrunc_splat2(%a: vector<4xi16>) -> vector<2xi8> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<2xi16>
+  // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<2xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 2] : vector<4xi16>, vector<4xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<2xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 3] : vector<4xi16>, vector<4xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<2xi16>
+  // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<2xi16>
+  // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<2xi16>
+  // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<2xi16> to vector<2xi8>
+  // CHECK: return %[[RES]] : vector<2xi8>
+  %0 = arith.trunci %a : vector<4xi16> to vector<4xi4>
+  %1 = vector.bitcast %0 : vector<4xi4> to vector<2xi8>
+  return %1 : vector<2xi8>
+}
+
 // CHECK-LABEL: func.func @f1ext(
 //  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
 func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
@@ -193,6 +229,44 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @fext_splat1(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
+func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
+  // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<4xi8>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<4xi8>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST]] : vector<4xi8>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 5, 2, 7] : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
+  // CHECK: return %[[RES]] : vector<4xi16>
+  %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+  %1 = arith.extui %0 : vector<4xi4> to vector<4xi16>
+  return %1 : vector<4xi16>
+}
+
+// CHECK-LABEL: func.func @fext_splat2(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
+func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, 240, 3840, -4096, 15, 240, 3840, -4096, 15, 240, 3840, -4096]> : vector<12xi16>
+  // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<12xi16>
+  // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<12xi16>
+  // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<12xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] : vector<3xi16>, vector<3xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<12xi16>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST0]] : vector<12xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 13, 4, 17, 8, 21] : vector<12xi16>, vector<12xi16>
+  // CHECK: %[[SHR1:.*]] = arith.shrui %[[A0]], %[[SHR_CST1]] : vector<12xi16>
+  // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 8, 2, 3, 12, 4, 5, 16] : vector<6xi16>, vector<12xi16>
+  // CHECK: %[[SHR2:.*]] = arith.shrui %[[A0]], %[[SHR_CST2]] : vector<12xi16>
+  // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 12, 3, 4, 5, 16, 6, 7, 8, 20] : vector<9xi16>, vector<12xi16>
+  // CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
+  // CHECK: return %[[RES]] : vector<12xi32>
+  %0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>
+  %1 = arith.extui %0 : vector<12xi4> to vector<12xi32>
+  return %1 : vector<12xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_extsi(
 func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index a0b39a2b68f438..dd1d6887ada274 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -124,6 +124,36 @@ func.func @f3(%v: vector<2xi48>) {
   return
 }
 
+func.func @print_as_i1_2xi8(%v : vector<2xi8>) {
+  %bitsi16 = vector.bitcast %v : vector<2xi8> to vector<16xi1>
+  vector.print %bitsi16 : vector<16xi1>
+  return
+}
+
+func.func @print_as_i1_4xi4(%v : vector<4xi4>) {
+  %bitsi16 = vector.bitcast %v : vector<4xi4> to vector<16xi1>
+  vector.print %bitsi16 : vector<16xi1>
+  return
+}
+
+func.func @ftrunc_splat(%v: vector<2xi24>) {
+  %trunc = arith.trunci %v : vector<2xi24> to vector<2xi8>
+  func.call @print_as_i1_2xi8(%trunc) : (vector<2xi8>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1, 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0, 0, 0, 1, 1 )
+
+  %bitcast = vector.bitcast %trunc : vector<2xi8> to vector<4xi4>
+  func.call @print_as_i1_4xi4(%bitcast) : (vector<4xi4>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1 )
+
+  return
+}
+
 func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
   %bitsi40 = vector.bitcast %v : vector<8xi5> to vector<40xi1>
   vector.print %bitsi40 : vector<40xi1>
@@ -164,6 +194,32 @@ func.func @fext(%a: vector<5xi8>) {
   return
 }
 
+func.func @print_as_i1_4xi8(%v : vector<4xi8>) {
+  %bitsi32 = vector.bitcast %v : vector<4xi8> to vector<32xi1>
+  vector.print %bitsi32 : vector<32xi1>
+  return
+}
+
+func.func @fext_splat(%a: vector<2xi8>) {
+  %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+  func.call @print_as_i1_4xi4(%0) : (vector<4xi4>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1 )
+
+  %1 = arith.extui %0 : vector<4xi4> to vector<4xi8>
+  func.call @print_as_i1_4xi8(%1) : (vector<4xi8>) -> ()
+  //      CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1, 0, 0, 0, 0 )
+
+  return
+}
+
 func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
   %c0 = arith.constant 0: index
   %mask = vector.constant_mask [3] : vector<6xi1>
@@ -190,9 +246,19 @@ func.func @entry() {
   func.call @f3(%v3) : (vector<2xi48>) -> ()
 
   %v4 = arith.constant dense<[
+    0xafe, 0xbc3
+  ]> : vector<2xi24>
+  func.call @ftrunc_splat(%v4) : (vector<2xi24>) -> ()
+
+  %v5 = arith.constant dense<[
     0xef, 0xee, 0xed, 0xec, 0xeb
   ]> : vector<5xi8>
-  func.call @fext(%v4) : (vector<5xi8>) -> ()
+  func.call @fext(%v5) : (vector<5xi8>) -> ()
+
+  %v6 = arith.constant dense<[
+    0xfe, 0xc3
+  ]> : vector<2xi8>
+  func.call @fext_splat(%v6) : (vector<2xi8>) -> ()
 
   // Set up memory.
   %c0 = arith.constant 0: index


