[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #65774)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Sep 13 01:39:53 PDT 2023


================
@@ -155,6 +170,256 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+/// Create a vector of bit masks, one per `step`-wide group of bits in
+/// `bitwidth` (mask bits `idx .. idx + step - 1` for idx = 0, step, ...), and
+/// replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x3, 0xc, 0x3, 0xc, 0x3, 0xc].
+static SmallVector<Attribute> computeExtOfBitCastMasks(MLIRContext *ctx,
+                                                       int64_t bitwidth,
+                                                       int64_t step,
+                                                       int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpMasks;
+  tmpMasks.reserve(bitwidth / step);
+  // Create a vector of bit masks: `idx .. idx + step - 1`.
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    LDBG("Mask bits " << idx << " .. " << idx + step - 1 << " out of "
+                      << bitwidth);
+    IntegerAttr mask = IntegerAttr::get(
+        interimIntType, llvm::APInt::getBitsSet(bitwidth, idx, idx + step));
+    tmpMasks.push_back(mask);
+  }
+  // Replicate the vector of bit masks to the desired size.
+  SmallVector<Attribute> masks;
+  masks.reserve(numOccurrences * tmpMasks.size());
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(masks, tmpMasks);
+  return masks;
+}
+
+/// Create a vector of bit shift amounts `0, step, 2 * step, ...` (all values
+/// smaller than `bitwidth`) and replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x2, 0x0, 0x2, 0x0, 0x2].
+static SmallVector<Attribute>
+computeExtOfBitCastShifts(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                          int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpShifts;
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    IntegerAttr shift = IntegerAttr::get(interimIntType, idx);
+    tmpShifts.push_back(shift);
+  }
+  SmallVector<Attribute> shifts;
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shifts, tmpShifts);
+  return shifts;
+}
+
+/// Create a vector of shuffle indices in which each index `idx` in
+/// `[0, numOccurrences)` is repeated `bitwidth / step` times consecutively.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x0, 0x1, 0x1, 0x2, 0x2].
+static SmallVector<int64_t>
+computeExtOfBitCastShuffles(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                            int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  SmallVector<int64_t> shuffles;
+  int64_t n = floorDiv(bitwidth, step);
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shuffles, SmallVector<int64_t>(n, idx));
+  return shuffles;
+}
+
+/// Compute the bitwidth of the integer element type of the intermediate
+/// vector. The returned bitwidth:
+///   1. is at most 64 (TODO: in the future we may want target-specific
+///   control).
+///   2. divides sourceBitWidth * mostMinorSourceDim
+/// Returns 0 if no such bitwidth is found.
+static int64_t computeExtOfBitCastBitWidth(int64_t sourceBitWidth,
+                                           int64_t mostMinorSourceDim,
+                                           int64_t targetBitWidth) {
+  for (int64_t mult : {32, 16, 8, 4, 2, 1}) {
+    int64_t interimBitWidth =
+        std::lcm(mult, std::lcm(sourceBitWidth, targetBitWidth));
+    if (interimBitWidth > 64)
+      continue;
+    if ((sourceBitWidth * mostMinorSourceDim) % interimBitWidth != 0)
+      continue;
+    return interimBitWidth;
+  }
+  return 0;
+}
+
+FailureOr<Value>
+mlir::vector::rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                  vector::BitCastOp bitCastOp,
+                                  vector::BroadcastOp maybeBroadcastOp) {
+  assert(
+      (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
+      "unsupported op");
+
+  // The bitcast op is the load-bearing part, capture the source and bitCast
+  // types as well as bitwidth and most minor dimension.
+  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  VectorType bitCastVectorType = bitCastOp.getResultVectorType();
+  int64_t targetBitWidth = bitCastVectorType.getElementTypeBitWidth();
+  LDBG("bitCastVectorType: " << bitCastVectorType);
+
+  int64_t interimBitWidth = computeExtOfBitCastBitWidth(
+      sourceBitWidth, mostMinorSourceDim, targetBitWidth);
+  LDBG("interimBitWidth: " << interimBitWidth);
+  if (!interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "heuristic could not find a reasonable interim bitwidth");
+  }
+  if (sourceBitWidth == interimBitWidth || targetBitWidth == interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "interim bitwidth is equal to source or target, nothing to do");
+  }
+
+  int64_t interimMostMinorDim =
+      sourceBitWidth * mostMinorSourceDim / interimBitWidth;
+  LDBG("interimMostMinorDim: " << interimMostMinorDim);
+
+  Location loc = extOp->getLoc();
+  MLIRContext *ctx = extOp->getContext();
+
+  VectorType interimVectorType =
+      VectorType::Builder(sourceVectorType)
+          .setDim(sourceVectorType.getRank() - 1, interimMostMinorDim)
+          .setElementType(IntegerType::get(ctx, interimBitWidth));
+  LDBG("interimVectorType: " << interimVectorType);
+
+  IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
----------------
ftynse wrote:

Nit: move this declaration above `interimVectorType` and use it in that type's construction.

https://github.com/llvm/llvm-project/pull/65774


More information about the Mlir-commits mailing list