[Mlir-commits] [mlir] [mlir] Add narrow type emulation conversions (PR #72181)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 13 16:56:06 PST 2023


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/72181

>From 1b4a88bfcbaebe89259dffddaca34d40440da3b8 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 11:27:13 -0500
Subject: [PATCH] [mlir] Add narrow type emulation conversions

Adds narrow type emulation support for:
    - `memref.alloca`
    - `memref.store`
    - `memref.reinterpret_cast`

---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 272 +++++++++++++++---
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 228 +++++++++++++++
 2 files changed, 460 insertions(+), 40 deletions(-)
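
For reference, the emulation linearizes the memref and packs dstBits / srcBits
narrow elements into each wider element, so memref<5xi4> becomes memref<3xi8>
(or memref<1xi32> in the 32-bit test configuration below) and memref<3x125xi4>
becomes memref<188xi8>. A minimal standalone sketch of that size arithmetic,
for illustration only (the helper name is not from the patch):

#include <cassert>
#include <cstdint>

// Number of dstBits-wide containers needed to hold numElements values of
// srcBits each, assuming dstBits is a multiple of srcBits.
int64_t emulatedSize(int64_t numElements, int64_t srcBits, int64_t dstBits) {
  assert(dstBits % srcBits == 0 && "dstBits must be a multiple of srcBits");
  int64_t elementsPerContainer = dstBits / srcBits;
  return (numElements + elementsPerContainer - 1) / elementsPerContainer;
}

// emulatedSize(5, 4, 8)    == 3   -> memref<5xi4>     becomes memref<3xi8>
// emulatedSize(5, 4, 32)   == 1   -> memref<5xi4>     becomes memref<1xi32>
// emulatedSize(375, 4, 8)  == 188 -> memref<3x125xi4> becomes memref<188xi8>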

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..f23ace1924bbf51 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -35,36 +36,98 @@ using namespace mlir;
 /// Return the bit offset of the value at position `srcIdx`. For example, if
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
-/// element has 4 bits.
+/// element has 4 bits. If `rightOffset` is true, return the offset from the
+/// right side of the `dstBits` container instead of the left side.
 static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                   int sourceBits, int targetBits,
-                                  OpBuilder &builder) {
+                                  OpBuilder &builder,
+                                  bool rightOffset = false) {
   assert(targetBits % sourceBits == 0);
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr =
+      rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
+                  : (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte-sized value, the write must happen atomically in
+/// case another write to the same byte occurs at the same time. To perform
+/// the write, we must first clear `dstBits` at the `linearizedIndices` of the
+/// subbyte store. This function returns the mask for clearing those bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
 // ConvertMemRefAlloc
 //===----------------------------------------------------------------------===//
 
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto currentType = op.getMemref().getType().cast<MemRefType>();
-    auto newResultType =
-        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                      std::is_same<OpTy, memref::AllocaOp>(),
+                  "expected only memref::AllocOp or memref::AllocaOp");
+    auto currentType = cast<MemRefType>(op.getMemref().getType());
+    auto newResultType = dyn_cast<MemRefType>(
+        this->getTypeConverter()->convertType(op.getType()));
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
 
     // Special case zero-rank memrefs.
     if (currentType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<memref::AllocOp>(
-          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-          adaptor.getAlignmentAttr());
+      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                        adaptor.getSymbolOperands(),
+                                        adaptor.getAlignmentAttr());
       return success();
     }
 
@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
           rewriter, loc, linearizedMemRefInfo.linearizedSize));
     }
 
-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-        adaptor.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                      adaptor.getSymbolOperands(),
+                                      adaptor.getAlignmentAttr());
     return success();
   }
 };
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
@@ -211,6 +257,150 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Currently there is very limited support for memref::ReinterpretCastOp
+/// conversion; only the 0-dimensional case is supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    auto convertedElementType = newTy.getElementType();
+    auto oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only support the 0-D result case.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with rank > 0 is not supported");
+    }
+
+    int64_t offset = op.getStaticOffset(0);
+    // Only support static sizes and offsets.
+    if (offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with dynamic offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "subview with offset not multiple of elementsPerByte is not "
+          "supported");
+    }
+
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, *adaptor.getODSOperands(0).begin(), offset,
+        SmallVector<int64_t>{}, op.getStaticStrides());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+    auto convertedElementType = convertedType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Shift extended value to be left aligned
+      auto shiftValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal =
+          rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+              .getResult();
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+              .getResult();
+      // Create mask to clear destination bits
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+      Value writeMask =
+          rewriter
+              .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+              .getResult();
+
+      // Clear destination bits
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write src bits to destination
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           ValueRange{});
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter, true);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+            .getResult();
+
+    // Clear destination bits
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                         writeMask, adaptor.getMemref(),
+                                         storeIndices);
+    // Write src bits to destination
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                         alignedVal, adaptor.getMemref(),
+                                         storeIndices);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -291,8 +481,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+  patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+               ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemrefStore, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
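
To make the new store lowering above easier to follow, here is the scalar
arithmetic that the generated andi/ori pair performs for one i4 element packed
into an i8 container. This is a deliberately non-atomic, standalone sketch of
the semantics, not code from the pass; the pass emits the last two steps as
memref.atomic_rmw so that concurrent stores to the two nibbles of the same
byte do not race, and the trailing comments point at the helpers that compute
each piece:

#include <cassert>
#include <cstdint>

// Store a 4-bit value into a buffer that packs two i4 elements per i8,
// using the same right-aligned bit offset as the pass (srcBits = 4,
// dstBits = 8, scaleFactor = 2).
void storeI4(uint8_t *buffer, int64_t linearIndex, uint8_t value) {
  assert(value <= 0xF && "value must fit in 4 bits");
  int64_t byteIndex = linearIndex / 2;               // getIndicesForLoadOrStore
  int64_t bitOffset = (2 - 1 - linearIndex % 2) * 4; // getOffsetForBitwidth
  uint8_t clearMask =                                // getAtomicWriteMask
      static_cast<uint8_t>(~(0xF << bitOffset));
  buffer[byteIndex] &= clearMask;                    // memref.atomic_rmw andi
  buffer[byteIndex] |= static_cast<uint8_t>(value << bitOffset); // atomic_rmw ori
}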
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..05ec5761c8fe024 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,231 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 //       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 //       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 //       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+    %0 = memref.alloca() : memref<5xi4>
+    %1 = memref.load %0[%arg0] : memref<5xi4>
+    return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+    %0 = memref.alloc() : memref<5xi4>
+    memref.store %arg1, %0[%arg0] : memref<5xi4>
+    return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 2) * 8 + 4)>
+//      CHECK: func @memref_store_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 8) * 32 + 28)>
+//      CHECK32: func @memref_store_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+    %0 = memref.alloc() : memref<3x125xi4>
+    memref.assume_alignment %0, 64 : memref<3x125xi4>
+    memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+    return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * -500 - s1 * 4 + ((s0 * 125 + s1) floordiv 2) * 8 + 4)>
+//      CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+//  CHECK-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * -500 - s1 * 4 + ((s0 * 125 + s1) floordiv 8) * 32 + 28)>
+//      CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+//  CHECK32-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * -4 - s2 * 4 + ((s2 + s0 * s1) floordiv 2) * 8 + 4)>
+//      CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
+//  CHECK-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * -4 - s2 * 4 + ((s2 + s0 * s1) floordiv 8) * 32 + 28)>
+//      CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
+//  CHECK32-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @rank_zero_memref(%arg0: i4) -> () {
+  %0 = memref.alloc() : memref<i4>
+  memref.store %arg0, %0[] : memref<i4>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+//  CHECK-SAME:     %[[ARG0:.+]]: i4
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+//       CHECK:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+//       CHECK:   %[[BITOFFSET:.+]] = arith.constant 4 : i8
+//       CHECK:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET]] : i8
+//       CHECK:   %[[MASK:.+]] = arith.constant 15 : i8
+//       CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[] : (i8, memref<i8>) -> i8
+//       CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[] : (i8, memref<i8>) -> i8
+//       CHECK:   return
+
+// CHECK32-LABEL: func @rank_zero_memref
+//  CHECK32-SAME:     %[[ARG0:.+]]: i4
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+//       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
+//       CHECK32:   %[[BITOFFSET:.+]] = arith.constant 28 : i32
+//       CHECK32:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET]] : i32
+//       CHECK32:   %[[MASK:.+]] = arith.constant 268435455 : i32
+//       CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[] : (i32, memref<i32>) -> i32
+//       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[] : (i32, memref<i32>) -> i32
+//       CHECK32:   return
+
+// -----
+
+func.func @reinterpret_cast_memref_load() -> i4 {
+    %0 = memref.alloc() : memref<5xi4>
+    %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
+    %1 = memref.load %reinterpret_cast_0[] : memref<i4>
+    return %1 : i4
+}
+// CHECK-LABEL: func @reinterpret_cast_memref_load()
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//       CHECK:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<3xi8> to memref<i8>
+//       CHECK:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i8>
+//       CHECK:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
+//       CHECK:   return %[[TRUNC]]
+
+// CHECK32-LABEL: func @reinterpret_cast_memref_load()
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//       CHECK32:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<1xi32> to memref<i32>
+//       CHECK32:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i32>
+//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
+//       CHECK32:   return %[[TRUNC]]
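
In the rank-zero store case checked above, the shift amount and clear mask are
folded to constants, and the values in the CHECK lines follow directly from
dstBits - srcBits and (1 << (dstBits - srcBits)) - 1: 4 and 15 for the i8
container, 28 and 268435455 (0x0FFFFFFF) for the i32 container. A tiny
standalone check of that arithmetic, not part of the patch:

#include <cstdint>

// i4 value stored into an i8 container: shifted into the top nibble.
static_assert(8 - 4 == 4, "i8 shift");
static_assert((INT64_C(1) << (8 - 4)) - 1 == 15, "i8 clear mask");
// i4 value stored into an i32 container: shifted into the top 4 bits.
static_assert(32 - 4 == 28, "i32 shift");
static_assert((INT64_C(1) << (32 - 4)) - 1 == 268435455, "i32 clear mask");

int main() { return 0; }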


