[Mlir-commits] [mlir] [mlir][Vector] Add patterns for efficient unsigned i4 -> i8 conversion emulation (PR #89131)

Kojo Acquah llvmlistbot at llvm.org
Fri Apr 26 11:52:20 PDT 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/89131

>From f879baba6aef58d0a22074fa7d97336f3f7d87ef Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Wed, 17 Apr 2024 19:51:15 +0000
Subject: [PATCH 1/2] Add emulation patterns for unsigned i4 -> i8 extension

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 78 +++++++++++++++++++
 .../Vector/vector-rewrite-narrow-types.mlir   | 45 ++++++++++-
 2 files changed, 122 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d24721f3defa65..d9ecd01aa35e25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2 Extend the i4 elements using shifts & masking. Low i4 elemens of each
+  //  byte are place in one vector and the high i4 elements in another vector.
+  constexpr unsigned char lowBitsMask = 15; // Equivalent to [0000IIII] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
@@ -1099,6 +1131,50 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
+/// shuffles and bitwise ops that take advantage of high-level information to
+/// avoid leaving LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.extui %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.andi %0, 15 : vector<4xi8>
+///        %2 = arith.shrsi %0, 4 : vector<4xi8>
+///        %3 = vector.interleave %1, %2 : vector<4xi8>
+///        %4 = arith.extsi %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntUnsignedExt
+    : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 /// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
 /// LLVM to scramble with peephole optimizations.
@@ -1233,6 +1309,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c9..6d2b49889a3392 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -324,6 +324,50 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
   return %0 : vector<16x8xi7>
 }
 
+// CHECK-LABEL: func.func @aligned_extui(
+func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME:                             %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
+// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
+// CHECK:           %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extui_2d(
+func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<8x16xi8>
+// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<8x16xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<8x16xi8>
+// CHECK:           %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8x32xi8> to vector<8x32xi32>
+// CHECK:           return %[[VAL_7]] : vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extui_base_case(
+func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME:                                       %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
+// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
@@ -335,4 +379,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-

>From 1b83c81fc2ca013ec0ff9d4cbda5753af3b49434 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Fri, 26 Apr 2024 18:30:39 +0000
Subject: [PATCH 2/2] Address review comments: use arith.shrui, unify ext patterns

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 80 ++++++-------------
 .../Vector/vector-rewrite-narrow-types.mlir   | 47 +++++------
 2 files changed, 48 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d9ecd01aa35e25..a301b919dc5232 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -896,17 +896,17 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
   Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
 
-  // 2 Extend the i4 elements using shifts & masking. Low i4 elemens of each
-  //  byte are place in one vector and the high i4 elements in another vector.
-  constexpr unsigned char lowBitsMask = 15; // Equivalent to [0000IIII] bit mask
+  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
   auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
-  Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
+  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
                                              lowBitsMaskValues);
   constexpr int8_t highBitsToShift = 4;
   auto highShiftValues = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
-  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, highShiftValues);
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
 
   // 3. Interleave low and high i8 elements.
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
@@ -1080,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// LLVM to scramble with peephole optimizations. Templated to choose between
+/// signed and unsigned conversions.
 ///
-/// For example:
+/// For example (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
 ///      is rewriten as
 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1101,60 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
-  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
-                                PatternRewriter &rewriter) const override {
-    // Verify the preconditions.
-    Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
-    if (failed(
-            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
-      return failure();
-
-    // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
-                                             conversionOp)))
-      return failure();
-
-    // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
-
-    // Finalize the rewrite.
-    rewriter.replaceOpWithNewOp<ConversionOpType>(
-        conversionOp, conversionOp.getType(), subByteExt);
-    return success();
-  }
-};
-
-/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
-/// shuffles and bitwise ops that take advantage of high-level information to
-/// avoid leaving LLVM to scramble with peephole optimizations.
-///
-/// For example:
+/// For example (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
 ///      is rewritten as
 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
 ///        %1 = arith.andi %0, 15 : vector<4xi8>
-///        %2 = arith.shrsi %0, 4 : vector<4xi8>
+///        %2 = arith.shrui %0, 4 : vector<4xi8>
 ///        %3 = vector.interleave %1, %2 : vector<4xi8>
-///        %4 = arith.extsi %3 : vector<8xi8> to vector<8xi32>
+///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntUnsignedExt
-    : OpRewritePattern<ConversionOpType> {
+template <typename ConversionOpType, bool isSigned>
+struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
@@ -1165,8 +1131,14 @@ struct RewriteAlignedSubByteIntUnsignedExt
       return failure();
 
     // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    Value subByteExt;
+    if constexpr (isSigned) {
+      subByteExt =
+          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    } else {
+      subByteExt =
+          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    }
 
     // Finalize the rewrite.
     rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1305,11 +1277,11 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
+               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
-  patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
       patterns.getContext(), benefit.getBenefit() + 1);
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 6d2b49889a3392..614b2d4945348b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -326,44 +326,43 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
 
 // CHECK-LABEL: func.func @aligned_extui(
 func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME:                             %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
-// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
-// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
-// CHECK:           %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8xi8> to vector<8xi32>
+// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+// CHECK:           return %[[I32]] : vector<8xi32>
   %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
-
 // CHECK-LABEL: func.func @aligned_extui_2d(
 func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<8x16xi8>
-// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<8x16xi8>
-// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<8x16xi8>
-// CHECK:           %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8x32xi8> to vector<8x32xi32>
-// CHECK:           return %[[VAL_7]] : vector<8x32xi32>
+// CHECK-SAME:                                %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+// CHECK:           return %[[I32]] : vector<8x32xi32>
   %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
   return %0 : vector<8x32xi32>
 }
 
-
 // CHECK-LABEL: func.func @aligned_extui_base_case(
 func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
-// CHECK-SAME:                                       %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi8> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
-// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
-// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
+// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
   %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
   return %0 : vector<8xi8>
 }



More information about the Mlir-commits mailing list