[Mlir-commits] [mlir] [mlir][Vector] Add support for trunci to narrow type emulation (PR #82565)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 21 23:10:33 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

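WIP: Add a narrow-type emulation pattern that rewrites aligned 1-D `arith.trunci` ops producing i4 vectors into an i8 truncation followed by shuffles, bitwise ops and a `vector.bitcast`.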

---
Full diff: https://github.com/llvm/llvm-project/pull/82565.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+114-5) 
- (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+42) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fc11ae63e718a5..82c08cc5a54936 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
-  if (resultBitwidth % 8 != 0)
+  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+  if (bitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
@@ -876,6 +876,57 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
+/// that take advantage of high-level information to avoid leaving LLVM to
+/// scramble with peephole optimizations.
+///
+/// `srcValue` must be a 1-D vector of i8 elements with an even number of
+/// elements. Each pair of input elements (2k, 2k+1) is packed into one byte:
+/// element 2k goes into the low four bits and element 2k+1 into the high four
+/// bits. The packed bytes are then bitcast to an i4 vector with the same shape
+/// as `srcValue`.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+                                Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(8) &&
+         "Expected i8 type");
+  // The shuffle masks below are built from the innermost dimension size, which
+  // only matches the shuffled dimension for 1-D vectors.
+  assert(srcVecType.getRank() == 1 && "Expected 1-D vector");
+  int64_t vecDimSize = srcVecType.getShape().back();
+  // Every output byte consumes one even/odd pair; an odd size would produce an
+  // out-of-range index in the high mask.
+  assert(vecDimSize % 2 == 0 && "Expected an even number of elements");
+
+  // 1. De-interleave low (even) and high (odd) i8 elements.
+  SmallVector<int64_t> deinterleaveLowMaskValues;
+  SmallVector<int64_t> deinterleaveHighMaskValues;
+  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+  for (int64_t i = 0; i < vecDimSize; i += 2) {
+    deinterleaveLowMaskValues.push_back(i);
+    deinterleaveHighMaskValues.push_back(i + 1);
+  }
+
+  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+  // 2. Zero out the upper four bits of each low i8 element so they cannot
+  // clobber the high nibble when merged.
+  constexpr int8_t i8LowBitMask = 0x0F;
+  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+  Value zeroOutLow =
+      rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+
+  // 3. Move high i4 values to the upper side of the byte. The left shift
+  // discards the upper four bits of each high element, so no mask is needed.
+  constexpr int8_t bitsToShift = 4;
+  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+  Value shlHigh =
+      rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+  // 4. Merge high and low i4 values into one byte.
+  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+
+  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -1019,7 +1070,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
-    // Set up the BitCastRewriter and verify the preconditions.
+    // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1043,6 +1094,63 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.trunci %in : vector<8xi32> to vector<8xi4>
+///      is rewritten as
+///
+///        %cst = arith.constant dense<15> : vector<4xi8>
+///        %cst_0 = arith.constant dense<4> : vector<4xi8>
+///        %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+///        %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+///        %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+///        %3 = arith.andi %1, %cst : vector<4xi8>
+///        %4 = arith.shli %2, %cst_0 : vector<4xi8>
+///        %5 = arith.ori %3, %4 : vector<4xi8>
+///        %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+/// When the source already has i8 elements, the intermediate `arith.trunci`
+/// is omitted and the shuffles operate on the source directly.
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = truncOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+    // `arith.trunci` also operates on scalars; bail out before dereferencing a
+    // null vector type below.
+    if (!srcVecType || !dstVecType)
+      return failure();
+
+    // Only single dim vectors are supported until we have
+    // `vector.deinterleave`.
+    if (srcVecType.getRank() != 1)
+      return failure();
+
+    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+      return failure();
+
+    // Check general alignment preconditions. We invert the src/dst type order
+    // to reuse the existing precondition logic.
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+                                             truncOp)))
+      return failure();
+
+    // Create a new iX -> i8 truncation op unless the source is already i8:
+    // `arith.trunci` requires a strictly narrower result type, so an
+    // i8 -> i8 truncation would be invalid IR.
+    Location loc = truncOp.getLoc();
+    Value i8TruncVal = srcValue;
+    if (srcVecType.getElementTypeBitWidth() != 8) {
+      auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+      i8TruncVal = rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+    }
+
+    // Rewrite the i8 -> i4 truncation part.
+    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+    // Finalize the rewrite.
+    rewriter.replaceOp(truncOp, subByteTrunc);
+    return success();
+  }
+};
+
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
 /// For example:
@@ -1115,8 +1223,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
   patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+                                              benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 94e78ce40a3c19..8f0148119806c9 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,48 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
   return %0 : vector<8x32xf32>
 }
 
+// i32 source: truncate to i8 first, then pack pairs of i8 values into bytes
+// and bitcast the result to i4.
+// CHECK-LABEL: func.func @aligned_trunci(
+func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK:           %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+// CHECK:           return %[[I4]] : vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// i8 source: no intermediate truncation is needed, so the shuffles must
+// operate on the function argument directly.
+// CHECK-LABEL: func.func @aligned_trunci_base_case(
+func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK-NOT:       arith.trunci
+// CHECK:           %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+// CHECK:           return %[[I4]] : vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// Multi-dimensional vectors are not rewritten yet, so the truncation must be
+// left untouched and no packing ops may appear anywhere in the function.
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
+// CHECK-NOT:       vector.shuffle
+// CHECK-NOT:       arith.andi
+// CHECK-NOT:       arith.shli
+// CHECK-NOT:       arith.ori
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %{{.*}} : vector<8x32xi32> to vector<8x32xi4>
+// CHECK-NOT:       vector.shuffle
+// CHECK-NOT:       arith.andi
+// CHECK-NOT:       arith.shli
+// CHECK-NOT:       arith.ori
+// CHECK:           return %[[TRUNC]] : vector<8x32xi4>
+  %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+  return %0 : vector<8x32xi4>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/82565


More information about the Mlir-commits mailing list