[Mlir-commits] [mlir] [mlir][Vector] Add narrow type emulation pattern for vector.maskedload (PR #68443)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 26 16:56:04 PDT 2023


https://github.com/tyb0807 updated https://github.com/llvm/llvm-project/pull/68443

>From d5301d959238f29eb9c7b57317cb6e8b28a107c0 Mon Sep 17 00:00:00 2001
From: tyb0807 <vuson at google.com>
Date: Wed, 4 Oct 2023 23:50:19 +0000
Subject: [PATCH 1/4] [mlir][Vector] Add narrow type emulation pattern for
 vector.maskedload

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 189 ++++++++++++++++-
 .../Vector/vector-emulate-narrow-type.mlir    | 198 +++++++++++++++++-
 2 files changed, 383 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 94300291dcd7d23..ad08e9b14a100f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -103,6 +104,190 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedLoad final
+    : OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to load when emulating narrow types,
+    // and then cast back to the original type with vector.bitcast op.
+    // For example, to emulate i4 to i8, the following op:
+    //
+    //   %mask = vector.constant_mask [3] : vector<6xi1>
+    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
+    //   memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+    //
+    // can be replaced with
+    //
+    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
+    //   %new_pass_thru = vector.bitcast %pass_thru :
+    //   vector<6xi4> to vector<3xi8>
+    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
+    //   memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
+    //
+    // Since we are effectively loading 16 bits (2xi8) from the memref with the
+    // new mask, while originally we only wanted to effectively load 12 bits
+    // (3xi4) from the memref, we need to set the second half of the last i8
+    // that was effectively loaded (i.e. the second i8) to 0.
+    //
+    //   %unset_mask = arith.extsi %mask : vector<6xi1> to vector<6xi4>
+    //   %2 = vector.bitcast %unset_mask : vector<6xi4> to vector<3xi8>
+    //   %3 = arith.andi %1, %2 : vector<3xi8>
+    //
+    // Then if the second half of the second i8 from %pass_thru is not all 0s,
+    // we need to write their values back to the result.
+    //
+    //   %cst_1 = arith.constant dense<-1> : vector<6xi4>
+    //   %set_mask = arith.xori %unset_mask, %cst_1 : vector<6xi4>
+    //   %4 = vector.bitcast %set_mask : vector<6xi4> to vector<3xi8>
+    //   %5 = arith.andi %new_pass_thru, %4 : vector<3xi8>
+    //
+    //   %6 = arith.ori %3, %5 : vector<3xi8>
+    //   %7 = vector.bitcast %6 : vector<3xi8> to vector<6xi4>
+    //
+    // Given these input values:
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have:
+    //
+    //   expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+    //
+    //   %new_mask = [1, 1, 0]
+    //   %new_pass_thru = [0x78, 0x9A, 0xBC]
+    //   %1 = [0x12, 0x34, 0xBC]
+    //
+    //   %unset_mask = [0xF, 0xF, 0xF, 0, 0, 0]
+    //   %2 = [0xFF, 0xF0, 0]
+    //   %3 = [0x12, 0x30, 0]
+    //
+    //   %set_mask = [0, 0, 0, 0xF, 0xF, 0xF]
+    //   %4 = [0, 0x0F, 0xFF]
+    //   %5 = [0, 0x0A, 0xBC]
+    //
+    //   %6 = [0x12, 0x3A, 0xBC]
+    //   %7 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+    //
+    // TODO: Currently, the number of loaded elements must be a multiple of
+    // `scale` (dstBits / srcBits). To deal with other element counts, one has
+    // to extract the subvector at the proper offset after bit-casting.
+
+    auto origType = op.getVectorType();
+    auto origElements = origType.getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
+
+    auto createMaskOp = op.getMask().getDefiningOp<vector::CreateMaskOp>();
+    auto constantMaskOp = op.getMask().getDefiningOp<vector::ConstantMaskOp>();
+    // TODO: Handle extracted mask.
+    if (!createMaskOp && !constantMaskOp)
+      return failure();
+
+    // Computing the "compressed" mask. All the emulation logic (i.e. computing
+    // new mask index) only happens on the last dimension of the vectors.
+    Operation *newMask = nullptr;
+    auto newMaskType = VectorType::get(numElements, rewriter.getI1Type());
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      auto numMaskOperands = maskOperands.size();
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      s0 = s0 + scale - 1;
+      s0 = s0.floorDiv(scale);
+      OpFoldResult origIndex =
+          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+      OpFoldResult maskIndex =
+          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+      newMask = rewriter.create<vector::CreateMaskOp>(
+          loc, newMaskType,
+          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+    } else if (constantMaskOp) {
+      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+      auto numMaskOperands = maskDimSizes.size();
+      auto origIndex =
+          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+      auto maskIndex =
+          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+      newMask = rewriter.create<vector::ConstantMaskOp>(
+          loc, newMaskType, ArrayAttr::get(op.getContext(), maskIndex));
+    }
+
+    auto newPassThru =
+        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
+
+    // Generating the new masked load.
+    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+        loc, newType, adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+        newMask->getResult(0), newPassThru);
+
+    // Setting the part that originally was not effectively loaded from memory
+    // to 0.
+    auto andMask = rewriter.create<arith::ExtSIOp>(loc, origType, op.getMask());
+    auto bitCastedAndMask =
+        rewriter.create<vector::BitCastOp>(loc, newType, andMask);
+    auto loadedFromMem =
+        rewriter.create<arith::AndIOp>(loc, newLoad, bitCastedAndMask);
+
+    // Copying from pass through.
+    auto allOne = rewriter.create<arith::ConstantOp>(
+        loc, origType,
+        DenseIntElementsAttr::get(origType, {APInt::getAllOnes(srcBits)}));
+    auto passThruMask = rewriter.create<arith::XOrIOp>(loc, allOne.getResult(),
+                                                       andMask.getResult());
+    auto bitCastedPassThruMask =
+        rewriter.create<vector::BitCastOp>(loc, newType, passThruMask);
+    auto copiedFromPassThru =
+        rewriter.create<arith::AndIOp>(loc, newPassThru, bitCastedPassThruMask);
+
+    // Or-ing the first part loaded from memory and the second one copied from
+    // pass through to form the result.
+    auto result =
+        rewriter.create<arith::OrIOp>(loc, loadedFromMem, copiedFromPassThru);
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), result);
+
+    rewriter.replaceOp(op, bitCast->getResult(0));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
@@ -588,8 +773,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
-      typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
+               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6fcea33ddc952fe..28ca1c88e3eb3b7 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
 
 func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
@@ -108,3 +108,197 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
 //      CHECK32:   %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
 //      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+    %0 = memref.alloc() : memref<3x4xi8>
+    %mask = vector.create_mask %arg3 : vector<4xi1>
+    %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+      memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+    return %1 : vector<4xi8>
+}
+// Expect no conversions, i8 is supported.
+//      CHECK: func @vector_maskedload_i8(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
+// CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
+// CHECK-NEXT:   %[[MASK:.+]] = vector.create_mask %[[ARG2]] : vector<4xi1>
+// CHECK-NEXT:   [[L:%.+]] = vector.maskedload %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[ARG3]] :
+// CHECK-SAME:     memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-NEXT:   return
+
+//  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
+//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+//      CHECK32: func @vector_maskedload_i8(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<4xi1>
+//      CHECK32:   %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+//      CHECK32:   %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+//      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
+//      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<4xi1> to vector<4xi8>
+//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<4xi8> to vector<1xi32>
+//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
+//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<4xi8>
+//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<4xi8>
+//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<4xi8> to vector<1xi32>
+//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
+//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<4xi8>
+//      CHECK32:   return %[[VEC_I4]]
+
+// -----
+
+func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+    %0 = memref.alloc() : memref<3x8xi4>
+    %cst = arith.constant dense<0> : vector<3x8xi4>
+    %mask = vector.create_mask %arg3 : vector<8xi1>
+    %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+      memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+    %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
+    return %2 : vector<3x8xi4>
+}
+//  CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//  CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+//      CHECK: func @vector_maskedload_i4(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+//      CHECK:   %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+//      CHECK:   %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+//      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME:     memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
+//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<4xi8>
+//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
+//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
+//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<4xi8>
+//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<4xi8>
+//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+//  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+//      CHECK32: func @vector_maskedload_i4(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+//      CHECK32:   %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+//      CHECK32:   %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+//      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<1xi32>
+//      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
+//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<1xi32>
+//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
+//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
+//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
+//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<1xi32>
+//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
+//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+    %0 = memref.alloc() : memref<3x4xi8>
+    %mask = vector.constant_mask [2] : vector<4xi1>
+    %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+      memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+    return %1 : vector<4xi8>
+}
+// Expect no conversions, i8 is supported.
+//      CHECK: func @vector_cst_maskedload_i8(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
+// CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
+// CHECK-NEXT:   %[[MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+// CHECK-NEXT:   [[L:%.+]] = vector.maskedload %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[ARG2]] :
+// CHECK-SAME:     memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-NEXT:   return
+
+//  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
+//      CHECK32: func @vector_cst_maskedload_i8(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[ORIG_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+//      CHECK32:   %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+//      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
+//      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<4xi1> to vector<4xi8>
+//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<4xi8> to vector<1xi32>
+//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
+//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<4xi8>
+//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<4xi8>
+//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<4xi8> to vector<1xi32>
+//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
+//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<4xi8>
+//      CHECK32:   return %[[VEC_I4]]
+
+// -----
+
+func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+    %0 = memref.alloc() : memref<3x8xi4>
+    %cst = arith.constant dense<0> : vector<3x8xi4>
+    %mask = vector.constant_mask [4] : vector<8xi1>
+    %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+      memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+    %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
+    return %2 : vector<3x8xi4>
+}
+//  CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_cst_maskedload_i4(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+//      CHECK:   %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+//      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME:     memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
+//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<4xi8>
+//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
+//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
+//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<4xi8>
+//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<4xi8>
+//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+//  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_cst_maskedload_i4(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+//      CHECK32:   %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+//      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<1xi32>
+//      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
+//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<1xi32>
+//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
+//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
+//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
+//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<1xi32>
+//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
+//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>

>From 664f20c7d9c68676a19ab76bc49652e833c27322 Mon Sep 17 00:00:00 2001
From: tyb0807 <vuson at google.com>
Date: Mon, 9 Oct 2023 00:01:52 +0000
Subject: [PATCH 2/4] [mlir][Vector] Handle narrow type emulation of
 vector.maskedload when mask is an extraction

---
 .../Transforms/VectorEmulateNarrowType.cpp    |  36 ++++--
 .../Vector/vector-emulate-narrow-type.mlir    | 118 ++++++++++++++++++
 2 files changed, 147 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ad08e9b14a100f7..a74a9753fd02a7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -213,16 +213,28 @@ struct ConvertVectorMaskedLoad final
     auto numElements = (origElements + scale - 1) / scale;
     auto newType = VectorType::get(numElements, newElementType);
 
-    auto createMaskOp = op.getMask().getDefiningOp<vector::CreateMaskOp>();
-    auto constantMaskOp = op.getMask().getDefiningOp<vector::ConstantMaskOp>();
-    // TODO: Handle extracted mask.
+    auto maskOp = op.getMask().getDefiningOp();
+    SmallVector<vector::ExtractOp, 2> extractOps;
+    // Walk up through any chain of vector.extract ops to find the op that
+    // creates the mask. The walk stops at the first producer that is not a
+    // vector.extract (or when there is no defining op), so an unsupported
+    // producer reaches the failure() check below instead of looping forever.
+    while (auto extractOp = dyn_cast_or_null<vector::ExtractOp>(maskOp)) {
+      maskOp = extractOp.getVector().getDefiningOp();
+      extractOps.push_back(extractOp);
+    }
+    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
     if (!createMaskOp && !constantMaskOp)
       return failure();
 
     // Computing the "compressed" mask. All the emulation logic (i.e. computing
     // new mask index) only happens on the last dimension of the vectors.
     Operation *newMask = nullptr;
-    auto newMaskType = VectorType::get(numElements, rewriter.getI1Type());
+    auto shape = llvm::to_vector(
+        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
+    shape.push_back(numElements);
+    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
     if (createMaskOp) {
       auto maskOperands = createMaskOp.getOperands();
       auto numMaskOperands = maskOperands.size();
@@ -234,9 +246,11 @@ struct ConvertVectorMaskedLoad final
           getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
       OpFoldResult maskIndex =
           affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-      newMask = rewriter.create<vector::CreateMaskOp>(
-          loc, newMaskType,
+      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
+      newMaskOperands.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                      newMaskOperands);
     } else if (constantMaskOp) {
       auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
       auto numMaskOperands = maskDimSizes.size();
@@ -244,8 +258,16 @@ struct ConvertVectorMaskedLoad final
           cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
       auto maskIndex =
           rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
+      newMaskDimSizes.push_back(maskIndex);
       newMask = rewriter.create<vector::ConstantMaskOp>(
-          loc, newMaskType, ArrayAttr::get(op.getContext(), maskIndex));
+          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+    }
+
+    while (!extractOps.empty()) {
+      newMask = rewriter.create<vector::ExtractOp>(
+          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+      extractOps.pop_back();
     }
 
     auto newPassThru =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 28ca1c88e3eb3b7..b99f3013b701dd3 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -302,3 +302,121 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 //      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
 //      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
 //      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
+    %0 = memref.alloc() : memref<8x8x16xi4>
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c8 = arith.constant 8 : index
+    %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+    %cst_2 = arith.constant dense<0> : vector<16xi4>
+    %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
+    %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+    %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+    %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+    %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+    return %63 : vector<8x8x16xi4>
+}
+//      CHECK: func @vector_extract_maskedload_i4(
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<512xi8>
+//      CHECK:   %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+//      CHECK:   %[[ORIG_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x16xi1>
+//      CHECK:   %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+//      CHECK:   %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+//      CHECK:   %[[NEW_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x8xi1>
+//      CHECK:   %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x8xi1>
+//      CHECK:   %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
+//      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
+//      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME:     memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
+//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
+//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
+//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
+//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
+//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
+//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
+//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
+//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
+//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
+
+//      CHECK32: func @vector_extract_maskedload_i4(
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
+//      CHECK32:   %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+//      CHECK32:   %[[ORIG_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x16xi1>
+//      CHECK32:   %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+//      CHECK32:   %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+//      CHECK32:   %[[NEW_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x2xi1>
+//      CHECK32:   %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x2xi1>
+//      CHECK32:   %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
+//      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
+//      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME:     memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
+//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
+//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
+//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
+//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
+//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
+//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
+//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
+//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>
+
+// -----
+
+func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
+    %0 = memref.alloc() : memref<8x8x16xi4>
+    %c0 = arith.constant 0 : index
+    %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+    %cst_2 = arith.constant dense<0> : vector<16xi4>
+    %27 = vector.constant_mask [8, 4, 16] : vector<8x8x16xi1>
+    %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+    %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+    %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+    %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+    return %63 : vector<8x8x16xi4>
+}
+//      CHECK: func @vector_extract_cst_maskedload_i4(
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<512xi8>
+//      CHECK:   %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+//      CHECK:   %[[ORIG_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x16xi1>
+//      CHECK:   %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+//      CHECK:   %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+//      CHECK:   %[[NEW_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x8xi1>
+//      CHECK:   %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x8xi1>
+//      CHECK:   %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
+//      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
+//      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME:     memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
+//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
+//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
+//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
+//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
+//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
+//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
+//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
+//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
+//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
+
+//      CHECK32: func @vector_extract_cst_maskedload_i4(
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
+//      CHECK32:   %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+//      CHECK32:   %[[ORIG_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x16xi1>
+//      CHECK32:   %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+//      CHECK32:   %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+//      CHECK32:   %[[NEW_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x2xi1>
+//      CHECK32:   %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x2xi1>
+//      CHECK32:   %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
+//      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
+//      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME:     memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
+//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
+//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
+//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
+//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
+//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
+//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
+//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
+//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>

>From f0ad6fc63c88b31f7a3b5bef4ce53c11bccfb08f Mon Sep 17 00:00:00 2001
From: tyb0807 <vuson at google.com>
Date: Thu, 26 Oct 2023 23:30:44 +0000
Subject: [PATCH 3/4] [mlir][Vector] Use a simpler lowering when emulating
 narrow type for vector.maskedload

---
 .../Transforms/VectorEmulateNarrowType.cpp    |  99 +++++++--------
 .../Vector/vector-emulate-narrow-type.mlir    | 116 ++++--------------
 2 files changed, 65 insertions(+), 150 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index a74a9753fd02a7a..6e816c1099499ca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -135,35 +135,23 @@ struct ConvertVectorMaskedLoad final
     //
     //   %mask = vector.constant_mask [3] : vector<6xi1>
     //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
-    //   memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+    //        memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
     //
     // can be replaced with
     //
     //   %new_mask = vector.constant_mask [2] : vector<3xi1>
-    //   %new_pass_thru = vector.bitcast %pass_thru : vector<6xi4> to
-    //   vector<3xi8> %1 = vector.maskedload %0[%linear_index], %new_mask,
-    //   %new_pass_thru : memref<9xi8>, vector<3xi1>, vector<3xi8> into
-    //   vector<3xi8>
+    //   %new_pass_thru = vector.bitcast %pass_thru :
+    //        vector<6xi4> to vector<3xi8>
+    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
+    //        memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
+    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
     //
     // Since we are effectively loading 16 bits (2xi8) from the memref with the
     // new mask, while originally we only wanted to effectively load 12 bits
     // (3xi4) from the memref, we need to set the second half of the last i8
-    // that was effectively loaded (i.e. the second i8) to 0.
+    // that was effectively loaded (i.e. the second i8) to %pass_thru.
     //
-    //   %unset_mask = arith.extsi %mask : vector<6xi1> to vector<6xi4>
-    //   %2 = vector.bitcast %unset_mask : vector<6xi4> to vector<3xi8>
-    //   %3 = arith.andi %1, %2 : vector<3xi8>
-    //
-    // Then if the second half of the second i8 from %pass_thru is not all 0s,
-    // we need to write their values back to the result.
-    //
-    //   %cst_1 = arith.constant dense<-1> : vector<6xi4>
-    //   %set_mask = arith.xori %unset_mask, %cst_1 : vector<6xi4>
-    //   %4 = vector.bitcast %set_mask : vector<6xi4> to vector<3xi8>
-    //   %5 = arith.andi %new_pass_thru, %4 : vector<3xi8>
-    //
-    //   %6 = arith.ori %3, %5 : vector<3xi8>
-    //   %7 = vector.bitcast %6 : vector<3xi8> to vector<6xi4>
+    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
     //
     // Given these input values:
     //   %mask = [1, 1, 1, 0, 0, 0]
@@ -177,17 +165,8 @@ struct ConvertVectorMaskedLoad final
     //   %new_mask = [1, 1, 0]
     //   %new_pass_thru = [0x78, 0x9A, 0xBC]
     //   %1 = [0x12, 0x34, 0xBC]
-    //
-    //   %unset_mask = [0xF, 0xF, 0xF, 0, 0, 0]
-    //   %2 = [0xFF, 0xF0, 0]
-    //   %3 = [0x12, 0x30, 0]
-    //
-    //   %set_mask = [0, 0, 0, 0xF, 0xF, 0xF]
-    //   %4 = [0, 0x0F, 0xFF]
-    //   %5 = [0, 0x0A, 0xBC]
-    //
-    //   %6 = [0x12, 0x3A, 0xBC]
-    //   %7 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+    //   %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
+    //   %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
     //
     // TODO: Currently, only the even number of elements loading is supported.
     // To deal with the odd number of elements, one has to extract the
@@ -279,33 +258,39 @@ struct ConvertVectorMaskedLoad final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newMask->getResult(0), newPassThru);
 
-    // Setting the part that originally was not effectively loaded from memory
-    // to 0.
-    auto andMask = rewriter.create<arith::ExtSIOp>(loc, origType, op.getMask());
-    auto bitCastedAndMask =
-        rewriter.create<vector::BitCastOp>(loc, newType, andMask);
-    auto loadedFromMem =
-        rewriter.create<arith::AndIOp>(loc, newLoad, bitCastedAndMask);
-
-    // Copying from pass through.
-    auto allOne = rewriter.create<arith::ConstantOp>(
-        loc, origType,
-        DenseIntElementsAttr::get(origType, {APInt::getAllOnes(srcBits)}));
-    auto passThruMask = rewriter.create<arith::XOrIOp>(loc, allOne.getResult(),
-                                                       andMask.getResult());
-    auto bitCastedPassThruMask =
-        rewriter.create<vector::BitCastOp>(loc, newType, passThruMask);
-    auto copiedFromPassThru =
-        rewriter.create<arith::AndIOp>(loc, newPassThru, bitCastedPassThruMask);
-
-    // Or-ing the first part loaded from memory and the second one copied from
-    // pass through to form the result.
-    auto result =
-        rewriter.create<arith::OrIOp>(loc, loadedFromMem, copiedFromPassThru);
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), result);
-
-    rewriter.replaceOp(op, bitCast->getResult(0));
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+    auto select =
+        rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast, op.getPassThru());
+    rewriter.replaceOp(op, select->getResult(0));
+
+//  // Setting the part that originally was not effectively loaded from memory
+//  // to 0.
+//  auto andMask = rewriter.create<arith::ExtSIOp>(loc, origType, op.getMask());
+//  auto bitCastedAndMask =
+//      rewriter.create<vector::BitCastOp>(loc, newType, andMask);
+//  auto loadedFromMem =
+//      rewriter.create<arith::AndIOp>(loc, newLoad, bitCastedAndMask);
+
+//  // Copying from pass through.
+//  auto allOne = rewriter.create<arith::ConstantOp>(
+//      loc, origType,
+//      DenseIntElementsAttr::get(origType, {APInt::getAllOnes(srcBits)}));
+//  auto passThruMask = rewriter.create<arith::XOrIOp>(loc, allOne.getResult(),
+//                                                     andMask.getResult());
+//  auto bitCastedPassThruMask =
+//      rewriter.create<vector::BitCastOp>(loc, newType, passThruMask);
+//  auto copiedFromPassThru =
+//      rewriter.create<arith::AndIOp>(loc, newPassThru, bitCastedPassThruMask);
+
+//  // Or-ing the first part loaded from memory and the second one copied from
+//  // pass through to form the result.
+//  auto result =
+//      rewriter.create<arith::OrIOp>(loc, loadedFromMem, copiedFromPassThru);
+//  auto bitCast =
+//      rewriter.create<vector::BitCastOp>(loc, op.getType(), result);
+
+//  rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index b99f3013b701dd3..e1d6c3be494713e 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -141,16 +141,9 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
 //      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
 //      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<4xi1> to vector<4xi8>
-//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<4xi8> to vector<1xi32>
-//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<4xi8>
-//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<4xi8>
-//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<4xi8> to vector<1xi32>
-//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<4xi8>
-//      CHECK32:   return %[[VEC_I4]]
+//      CHECK32:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
+//      CHECK32:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
+//      CHECK32:   return %[[SELECT]]
 
 // -----
 
@@ -176,15 +169,8 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 //      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<4xi8>
 //      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME:     memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
-//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<4xi8>
-//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<4xi8>
-//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<4xi8>
-//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<4xi8>
-//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<4xi8>
-//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+//      CHECK:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+//      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 //  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
@@ -199,15 +185,8 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 //      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<1xi32>
 //      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<1xi32>
-//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<1xi32>
-//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+//      CHECK32:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+//      CHECK32:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
 
 // -----
 
@@ -239,16 +218,9 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
 //      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
 //      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<4xi1> to vector<4xi8>
-//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<4xi8> to vector<1xi32>
-//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<4xi8>
-//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<4xi8>
-//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<4xi8> to vector<1xi32>
-//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<4xi8>
-//      CHECK32:   return %[[VEC_I4]]
+//      CHECK32:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
+//      CHECK32:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
+//      CHECK32:   return %[[SELECT]]
 
 // -----
 
@@ -272,36 +244,22 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 //      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<4xi8>
 //      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME:     memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
-//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<4xi8>
-//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<4xi8>
-//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<4xi8>
-//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<4xi8>
-//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<4xi8>
-//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+//      CHECK:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+//      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 //      CHECK32: func @vector_cst_maskedload_i4(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
-//      CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+//      CHECK32:   %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
 //      CHECK32:   %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
 //      CHECK32:   %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
 //      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<1xi32>
 //      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME:     memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<1xi32>
-//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<1xi32>
-//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+//      CHECK32:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+//      CHECK32:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
 
 // -----
 
@@ -331,15 +289,8 @@ func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
 //      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
 //      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME:     memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
-//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
-//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
-//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
-//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
-//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
-//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
+//      CHECK:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
+//      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
 
 //      CHECK32: func @vector_extract_maskedload_i4(
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
@@ -353,15 +304,8 @@ func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
 //      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
 //      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME:     memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
-//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
-//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
-//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
-//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
-//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
-//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>
+//      CHECK32:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
+//      CHECK32:   %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
 
 // -----
 
@@ -389,15 +333,8 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 //      CHECK:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
 //      CHECK:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME:     memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
-//      CHECK:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-//      CHECK:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
-//      CHECK:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
-//      CHECK:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-//      CHECK:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-//      CHECK:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
-//      CHECK:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
-//      CHECK:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
-//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
+//      CHECK:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
+//      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
 
 //      CHECK32: func @vector_extract_cst_maskedload_i4(
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
@@ -411,12 +348,5 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 //      CHECK32:   %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
 //      CHECK32:   %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME:     memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
-//      CHECK32:   %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-//      CHECK32:   %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
-//      CHECK32:   %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
-//      CHECK32:   %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-//      CHECK32:   %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-//      CHECK32:   %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
-//      CHECK32:   %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
-//      CHECK32:   %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
-//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>
+//      CHECK32:   %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
+//      CHECK32:   %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>

>From 5172954c5e9a048f2429d72b1312a6e2c3ce4367 Mon Sep 17 00:00:00 2001
From: tyb0807 <vuson at google.com>
Date: Thu, 26 Oct 2023 23:51:53 +0000
Subject: [PATCH 4/4] [mlir][Vector] Add integration test for vector.maskedload
 narrow type emulation pattern

---
 .../Vector/CPU/test-rewrite-narrow-types.mlir | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index 6bdeb4523865fa9..9db70dd77df5dae 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -164,6 +164,13 @@ func.func @fext(%a: vector<5xi8>) {
   return
 }
 
+func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
+  %c0 = arith.constant 0: index
+  %mask = vector.constant_mask [3] : vector<6xi1>
+  %1 = vector.maskedload %A[%c0], %mask, %passthru :
+    memref<?xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+  return %1 : vector<6xi4>
+}
 
 func.func @entry() {
   %v = arith.constant dense<[
@@ -187,6 +194,21 @@ func.func @entry() {
   ]> : vector<5xi8>
   func.call @fext(%v4) : (vector<5xi8>) -> ()
 
+  // Set up memory.
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1: index
+  %c6 = arith.constant 6: index
+  %A = memref.alloc(%c6) : memref<?xi4>
+  scf.for %i = %c0 to %c6 step %c1 {
+    %i4 = arith.index_cast %i : index to i4
+    memref.store %i4, %A[%i] : memref<?xi4>
+  }
+  %passthru = arith.constant dense<[7, 8, 9, 10, 11, 12]> : vector<6xi4>
+  %load = call @fcst_maskedload(%A, %passthru) : (memref<?xi4>, vector<6xi4>) -> (vector<6xi4>)
+  vector.print %load : vector<6xi4>
+  // CHECK: ( 0, 1, 2, -6, -5, -4 )
+  memref.dealloc %A : memref<?xi4>
+
   return
 }
 



More information about the Mlir-commits mailing list