[Mlir-commits] [mlir] [mlir][Vector] Add narrow type emulation pattern for vector.maskedload (PR #68443)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 18 05:02:36 PDT 2023


================
@@ -103,6 +104,212 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedLoad final
+    : OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to load when emulating narrow types,
+    // and then cast back to the original type with vector.bitcast op.
+    // For example, to emulate i4 to i8, the following op:
+    //
+    //   %mask = vector.constant_mask [3] : vector<6xi1>
+    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
+    //   memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+    //
+    // can be replaced with
+    //
+    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
+    //   %new_pass_thru = vector.bitcast %pass_thru :
+    //   vector<6xi4> to vector<3xi8>
+    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
+    //   memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
+    //
+    // Since we are effectively loading 16 bits (2xi8) from the memref with the
+    // new mask, while originally we only wanted to effectively load 12 bits
+    // (3xi4) from the memref, we need to set the second half of the last i8
+    // that was effectively loaded (i.e. the second i8) to 0.
----------------
tyb0807 wrote:

I'm not sure I see how to do this. Do you mean having a loop of `arith.select` based on the index?

https://github.com/llvm/llvm-project/pull/68443


More information about the Mlir-commits mailing list