[Mlir-commits] [mlir] [mlir][vector] Support n-D vectors in i8 to i4 trunci emulation (PR #94946)

Mubashar Ahmad llvmlistbot at llvm.org
Mon Jun 24 01:26:50 PDT 2024


https://github.com/mub-at-arm updated https://github.com/llvm/llvm-project/pull/94946

>From 3b152d7e79ea9e5a79af114ecba6d33bb8e6bdb7 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Mon, 10 Jun 2024 09:20:30 +0000
Subject: [PATCH] [mlir][vector] VectorEmulateNarrowType uses deinterleave

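VectorEmulateNarrowType now uses the vector.deinterleave
operation, instead of a pair of vector.shuffle operations,
in its lowering for truncating i8 to i4 types. Because
vector.deinterleave supports n-D vectors, the rank-1-only
restriction on this rewrite is also lifted.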
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 54 ++++++-------------
 .../Vector/vector-rewrite-narrow-types.mlir   | 27 +++++++---
 2 files changed, 36 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 59b6cb3ae667a..ac2a4d3abcc68 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -912,8 +912,8 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
-/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
-/// that take advantage of high-level information to avoid leaving LLVM to
+/// Rewrite the i8 -> i4 truncation into a deinterleave and a series of bitwise
+/// ops that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                 Value srcValue) {
@@ -922,39 +922,22 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
          "Expected i8 type");
 
   // 1. De-interleave low and high i8 elements.
-  int64_t vecDimSize = srcVecType.getShape().back();
-  SmallVector<int64_t> deinterleaveLowMaskValues;
-  SmallVector<int64_t> deinterleaveHighMaskValues;
-  assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
-  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
-  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
-  for (int i = 0, end = vecDimSize; i < end; i += 2) {
-    deinterleaveLowMaskValues.push_back(i);
-    deinterleaveHighMaskValues.push_back(i + 1);
-  }
-
-  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
-      loc, srcValue, srcValue,
-      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
-  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
-      loc, srcValue, srcValue,
-      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
 
   // 2. Zero out the upper side of each low i8 element.
   constexpr int8_t i8LowBitMask = 0x0F;
+  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
   Value zeroOutMask = rewriter.create<arith::ConstantOp>(
-      loc,
-      DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
-  Value zeroOutLow =
-      rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
+  Value zeroOutLow = rewriter.create<arith::AndIOp>(
+      loc, deinterleaveOp.getRes1(), zeroOutMask);
 
   // 3. Move high i4 values to upper side of the byte.
   constexpr int8_t bitsToShift = 4;
-  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
   auto shiftValues = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
-  Value shlHigh =
-      rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
+                                                 shiftValues);
 
   // 4. Merge high and low i4 values.
   auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
@@ -1148,7 +1131,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
-/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
 /// bitwise ops that take advantage of high-level information to avoid leaving
 /// LLVM to scramble with peephole optimizations.
 ///
@@ -1158,13 +1141,12 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 ///        %cst = arith.constant dense<15> : vector<4xi8>
 ///        %cst_0 = arith.constant dense<4> : vector<4xi8>
 ///        %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
-///        %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
-///        %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+///        %1, %2 = vector.deinterleave %0 : vector<8xi8> -> vector<4xi8>
 ///        %3 = arith.andi %1, %cst : vector<4xi8>
 ///        %4 = arith.shli %2, %cst_0 : vector<4xi8>
 ///        %5 = arith.ori %3, %4 : vector<4xi8>
 ///        %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1178,11 +1160,6 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
     if (!srcVecType || !dstVecType)
       return failure();
 
-    // Only single dim vectors are supported until we have
-    // `vector.deinterleave`.
-    if (srcVecType.getRank() != 1)
-      return failure();
-
     if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 614b2d4945348..84aaa9c61200b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -268,8 +268,7 @@ func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
 // CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
 // CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
 // CHECK:           %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
-// CHECK:           %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
-// CHECK:           %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
 // CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
 // CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
@@ -283,8 +282,7 @@ func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
 // CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
 // CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
-// CHECK:           %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
 // CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
 // CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
@@ -299,17 +297,34 @@ func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
 // CHECK-NOT:       vector.andi
 // CHECK-NOT:       vector.shli
 // CHECK-NOT:       vector.ori
-// CHECK:           arith.trunci
+// CHECK:           arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8>
+// CHECK-NOT:       arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4>
+// CHECK:           vector.deinterleave
   %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
   return %0 : vector<8x32xi4>
 }
 
+// CHECK-LABEL: func.func @aligned_trunci_nd(
+// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
+func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
+  // CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8>
+  // CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8>
+  // CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8>
+  // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8>
+  // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
+  // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
+  // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
+  // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4> 
+  %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
+  return %0 : vector<3x8x32xi4>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
 // CHECK:           %[[EXT:.*]] = vector.interleave
 // CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK:           vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8>
   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
   return %0 : vector<16x8xi4>
 }



More information about the Mlir-commits mailing list