[Mlir-commits] [mlir] [mlir] Add narrow type emulation conversions (PR #72181)

Han-Chung Wang llvmlistbot at llvm.org
Wed Nov 15 14:54:56 PST 2023


================
@@ -35,36 +36,98 @@ using namespace mlir;
 /// Return the bit offset of the value at position `srcIdx`. For example, if
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
-/// element has 4 bits.
+/// element has 4 bits. If `rightOffset` is true, return the offset from the
+/// right side of the `targetBits` container instead of the left side, i.e.
+/// (targetBits / sourceBits - 1 - x % (targetBits / sourceBits)) * sourceBits.
+/// The offset is returned as an `index_cast` to a `targetBits`-wide integer.
 static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                   int sourceBits, int targetBits,
-                                  OpBuilder &builder) {
+                                  OpBuilder &builder,
+                                  bool rightOffset = false) {
   assert(targetBits % sourceBits == 0);
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  // With `rightOffset`, mirror the slot index within the container so that
+  // slot 0 gets the largest offset, (scaleFactor - 1) * sourceBits.
+  AffineExpr offsetExpr =
+      rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
+                  : (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// Returns the mask used to clear the `srcBits`-wide slot that a sub-byte
+/// store writes into a `dstBits`-wide container element. The write has to
+/// happen atomically in case another write targets the same container element
+/// at the same time, so the slot is first cleared with this mask.
+///
+/// The result is `~(((1 << srcBits) - 1) << bitwidthOffset)` computed in a
+/// `dstBits`-wide integer type: all ones except for the `srcBits` bits that
+/// start at `bitwidthOffset`.
+///
+/// NOTE(review): `linearizedIndices` is currently unused; the slot position is
+/// determined entirely by `bitwidthOffset`.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  assert(srcBits > 0 && srcBits <= dstBits &&
+         "sub-element must fit in the container element");
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  // Build the low-bits mask directly at the container width. The previous
+  // `(1 << srcBits) - 1` was evaluated on a plain `int`, which overflows
+  // (undefined behavior) for srcBits >= 31.
+  auto maskRightAlignedAttr = builder.getIntegerAttr(
+      dstIntegerType, APInt::getLowBitsSet(dstBits, srcBits));
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  // Move the `srcBits` ones into the slot being written.
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  // XOR with all ones inverts the mask: every bit outside the slot is kept.
+  auto flipValAttr =
+      builder.getIntegerAttr(dstIntegerType, APInt::getAllOnes(dstBits));
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`, i.e. it is
+/// `floordiv(linearizedIndex, dstBits / srcBits)` materialized as an index.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  assert(dstBits % srcBits == 0 && "dstBits must be a multiple of srcBits");
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  // Number of `srcBits`-wide elements packed into one `dstBits` element.
+  int64_t scaleFactor = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndex = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaleFactor), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndex);
+}
+
+/// Returns the linearized index of `indices` into `memref`, expressed in units
+/// of `srcBits`-wide elements (not yet scaled to the emulated container
+/// width). The offset, sizes and strides are taken from the memref's strided
+/// metadata.
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        ArrayRef<OpFoldResult> indices, Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  // Passing `srcBits` as both the source and destination width keeps the
+  // linearized index at the granularity of the original element type.
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
 // ConvertMemRefAlloc
 //===----------------------------------------------------------------------===//
 
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
----------------
hanhanW wrote:

Perhaps we can rename it to `ConvertMemRefAllocation`. It is used not only by memref.alloc, but also by memref.alloca.

https://github.com/llvm/llvm-project/pull/72181


More information about the Mlir-commits mailing list