[Mlir-commits] [mlir] 674261b - [mlir][Vector] Add narrow type emulation pattern for vector.maskedload (#68443)
llvmlistbot at llvm.org
Fri Oct 27 01:50:03 PDT 2023
Author: tyb0807
Date: 2023-10-27T10:49:58+02:00
New Revision: 674261b20373b9a769d596eb9ab3f44e2305bc55
URL: https://github.com/llvm/llvm-project/commit/674261b20373b9a769d596eb9ab3f44e2305bc55
DIFF: https://github.com/llvm/llvm-project/commit/674261b20373b9a769d596eb9ab3f44e2305bc55.diff
LOG: [mlir][Vector] Add narrow type emulation pattern for vector.maskedload (#68443)
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 94300291dcd7d23..3d65123373109b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -103,6 +104,172 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedLoad final
+ : OpConversionPattern<vector::MaskedLoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto loc = op.getLoc();
+ auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+ Type oldElementType = op.getType().getElementType();
+ Type newElementType = convertedType.getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = newElementType.getIntOrFloatBitWidth();
+
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+ int scale = dstBits / srcBits;
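+ // E.g. emulating i4 over i8 storage gives scale = 2; over i32 storage,
+ // scale = 8.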
+
+ // Adjust the number of elements to load when emulating narrow types,
+ // and then cast back to the original type with a vector.bitcast op.
+ // For example, to emulate i4 to i8, the following op:
+ //
+ // %mask = vector.constant_mask [3] : vector<6xi1>
+ // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
+ // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+ //
+ // can be replaced with
+ //
+ // %new_mask = vector.constant_mask [2] : vector<3xi1>
+ // %new_pass_thru = vector.bitcast %pass_thru :
+ // vector<6xi4> to vector<3xi8>
+ // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
+ // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
+ // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
+ //
+ // Since we are effectively loading 16 bits (2xi8) from the memref with the
+ // new mask, while originally we only wanted to effectively load 12 bits
+ // (3xi4) from the memref, we need to set the second half of the last i8
+ // that was effectively loaded (i.e. the second i8) to %pass_thru.
+ //
+ // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
+ //
+ // Given these input values:
+ // %mask = [1, 1, 1, 0, 0, 0]
+ // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+ // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+ //
+ // we'll have:
+ //
+ // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+ //
+ // %new_mask = [1, 1, 0]
+ // %new_pass_thru = [0x78, 0x9A, 0xBC]
+ // %1 = [0x12, 0x34, 0xBC]
+ // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
+ // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+ //
+ // TODO: Currently, only loading a number of elements that is a multiple of
+ // `scale` is supported. To deal with a remainder, one has to extract the
+ // subvector at the proper offset after bit-casting.
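+ // A possible sketch for that case: load the ceil(origElements / scale)
+ // containers that cover the requested elements, bit-cast, and then take the
+ // leading sub-vector, e.g. with vector.extract_strided_slice.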
+
+ auto origType = op.getVectorType();
+ auto origElements = origType.getNumElements();
+ if (origElements % scale != 0)
+ return failure();
+
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(),
+ getAsOpFoldResult(adaptor.getIndices()));
+
+ auto numElements = (origElements + scale - 1) / scale;
+ auto newType = VectorType::get(numElements, newElementType);
+
+ auto maskOp = op.getMask().getDefiningOp();
+ SmallVector<vector::ExtractOp, 2> extractOps;
+ // Finding the mask creation operation, looking through any vector.extract
+ // ops that slice the mask.
+ while (maskOp &&
+ !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+ if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+ maskOp = extractOp.getVector().getDefiningOp();
+ extractOps.push_back(extractOp);
+ } else {
+ // Bail out on mask producers that are not recognized here instead of
+ // looping forever; the checks below then reject the op.
+ break;
+ }
+ }
+ auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+ auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+ if (!createMaskOp && !constantMaskOp)
+ return failure();
+
+ // Computing the "compressed" mask. All the emulation logic (i.e. computing
+ // the new mask index) only applies to the last dimension of the vectors.
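+ // For example, with scale = 2, vector.constant_mask [3] : vector<6xi1>
+ // becomes vector.constant_mask [2] : vector<3xi1> (2 == ceil(3 / 2)).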
+ Operation *newMask = nullptr;
+ auto shape = llvm::to_vector(
+ maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
+ shape.push_back(numElements);
+ auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+ if (createMaskOp) {
+ auto maskOperands = createMaskOp.getOperands();
+ auto numMaskOperands = maskOperands.size();
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ s0 = s0 + scale - 1;
+ s0 = s0.floorDiv(scale);
+ OpFoldResult origIndex =
+ getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+ OpFoldResult maskIndex =
+ affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+ auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
+ newMaskOperands.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+ newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+ newMaskOperands);
+ } else if (constantMaskOp) {
+ auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+ auto numMaskOperands = maskDimSizes.size();
+ auto origIndex =
+ cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+ auto maskIndex =
+ rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+ auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
+ newMask = rewriter.create<vector::ConstantMaskOp>(
+ loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+ }
+
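+ // Replay the vector.extract ops recorded above on the compressed mask so the
+ // result matches the shape expected by the new masked load, i.e. the original
+ // mask's shape with its last dimension compressed.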
+ while (!extractOps.empty()) {
+ newMask = rewriter.create<vector::ExtractOp>(
+ loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+ extractOps.pop_back();
+ }
+
+ auto newPassThru =
+ rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
+
+ // Generating the new masked load.
+ auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+ loc, newType, adaptor.getBase(),
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+ newMask->getResult(0), newPassThru);
+
+ // Setting the parts of the result that were not effectively loaded from
+ // memory to the pass-through values.
+ auto bitCast =
+ rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+ auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
+ op.getPassThru());
+ rewriter.replaceOp(op, select->getResult(0));
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//
@@ -588,8 +755,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `vector.*` conversion patterns.
- patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
- typeConverter, patterns.getContext());
+ patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
+ ConvertVectorTransferRead>(typeConverter, patterns.getContext());
}
void vector::populateVectorNarrowTypeRewritePatterns(
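For context, here is a minimal sketch of how a conversion driver could pull in the new pattern. This driver is not part of the commit; the --test-emulate-narrow-int pass exercised in the tests below does essentially this, and the header paths and the converter constructor argument are assumptions.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Rewrites vector.load/maskedload/transfer_read on narrow-element memrefs so
// they operate on wider storage (here 8-bit containers).
static LogicalResult emulateVectorNarrowTypes(Operation *rootOp,
                                              MLIRContext *ctx) {
  // Assumption: the converter is constructed with the target load/store
  // bitwidth (8 matches the memref-load-bitwidth=8 RUN line below).
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  RewritePatternSet patterns(ctx);
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
  ConversionTarget target(*ctx);
  // An op is legal once its operand and result types pass the converter.
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });
  return applyPartialConversion(rootOp, target, std::move(patterns));
}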
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6fcea33ddc952fe..e1d6c3be494713e 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
@@ -108,3 +108,245 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+ %0 = memref.alloc() : memref<3x4xi8>
+ %mask = vector.create_mask %arg3 : vector<4xi1>
+ %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+ memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+ return %1 : vector<4xi8>
+}
+// Expect no conversions; i8 is supported.
+// CHECK: func @vector_maskedload_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.create_mask %[[ARG2]] : vector<4xi1>
+// CHECK-NEXT: [[L:%.+]] = vector.maskedload %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[ARG3]] :
+// CHECK-SAME: memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32: func @vector_maskedload_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<4xi1>
+// CHECK32: %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
+// CHECK32: return %[[SELECT]]
+
+// -----
+
+func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+ %0 = memref.alloc() : memref<3x8xi4>
+ %cst = arith.constant dense<0> : vector<3x8xi4>
+ %mask = vector.create_mask %arg3 : vector<8xi1>
+ %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+ memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+ %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
+ return %2 : vector<3x8xi4>
+}
+// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+// CHECK: func @vector_maskedload_i4(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK: %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+// CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+// CHECK32: func @vector_maskedload_i4(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK32: %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
+
+// -----
+
+func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+ %0 = memref.alloc() : memref<3x4xi8>
+ %mask = vector.constant_mask [2] : vector<4xi1>
+ %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+ memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+ return %1 : vector<4xi8>
+}
+// Expect no conversions; i8 is supported.
+// CHECK: func @vector_cst_maskedload_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+// CHECK-NEXT: [[L:%.+]] = vector.maskedload %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[ARG2]] :
+// CHECK-SAME: memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
+// CHECK32: func @vector_cst_maskedload_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+// CHECK32: %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
+// CHECK32: return %[[SELECT]]
+
+// -----
+
+func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+ %0 = memref.alloc() : memref<3x8xi4>
+ %cst = arith.constant dense<0> : vector<3x8xi4>
+ %mask = vector.constant_mask [4] : vector<8xi1>
+ %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+ memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+ %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
+ return %2 : vector<3x8xi4>
+}
+// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_cst_maskedload_i4(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK: %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+// CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_cst_maskedload_i4(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32: %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
+
+// -----
+
+func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
+ %0 = memref.alloc() : memref<8x8x16xi4>
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c8 = arith.constant 8 : index
+ %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+ %cst_2 = arith.constant dense<0> : vector<16xi4>
+ %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
+ %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+ %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+ %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+ %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+ return %63 : vector<8x8x16xi4>
+}
+// CHECK: func @vector_extract_maskedload_i4(
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<512xi8>
+// CHECK: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x16xi1>
+// CHECK: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+// CHECK: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+// CHECK: %[[NEW_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x8xi1>
+// CHECK: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x8xi1>
+// CHECK: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
+// CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
+
+// CHECK32: func @vector_extract_maskedload_i4(
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
+// CHECK32: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x16xi1>
+// CHECK32: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+// CHECK32: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x2xi1>
+// CHECK32: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x2xi1>
+// CHECK32: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
+// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
+
+// -----
+
+func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
+ %0 = memref.alloc() : memref<8x8x16xi4>
+ %c0 = arith.constant 0 : index
+ %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+ %cst_2 = arith.constant dense<0> : vector<16xi4>
+ %27 = vector.constant_mask [8, 4, 16] : vector<8x8x16xi1>
+ %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+ %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+ %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+ %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+ return %63 : vector<8x8x16xi4>
+}
+// CHECK: func @vector_extract_cst_maskedload_i4(
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<512xi8>
+// CHECK: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x16xi1>
+// CHECK: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+// CHECK: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x8xi1>
+// CHECK: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x8xi1>
+// CHECK: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
+// CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
+
+// CHECK32: func @vector_extract_cst_maskedload_i4(
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
+// CHECK32: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x16xi1>
+// CHECK32: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
+// CHECK32: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x2xi1>
+// CHECK32: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x2xi1>
+// CHECK32: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
+// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
+// CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index 6bdeb4523865fa9..3c8c8d45013dc82 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -164,6 +164,13 @@ func.func @fext(%a: vector<5xi8>) {
return
}
+func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
+ %c0 = arith.constant 0: index
+ %mask = vector.constant_mask [3] : vector<6xi1>
+ %1 = vector.maskedload %A[%c0], %mask, %passthru :
+ memref<?xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+ return %1 : vector<6xi4>
+}
func.func @entry() {
%v = arith.constant dense<[
@@ -187,6 +194,21 @@ func.func @entry() {
]> : vector<5xi8>
func.call @fext(%v4) : (vector<5xi8>) -> ()
+ // Set up memory.
+ %c0 = arith.constant 0: index
+ %c1 = arith.constant 1: index
+ %c6 = arith.constant 6: index
+ %A = memref.alloc(%c6) : memref<?xi4>
+ scf.for %i = %c0 to %c6 step %c1 {
+ %i4 = arith.index_cast %i : index to i4
+ memref.store %i4, %A[%i] : memref<?xi4>
+ }
+ %passthru = arith.constant dense<[7, 8, 9, 10, 11, 12]> : vector<6xi4>
+ %load = call @fcst_maskedload(%A, %passthru) : (memref<?xi4>, vector<6xi4>) -> (vector<6xi4>)
+ vector.print %load : vector<6xi4>
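+ // The first three lanes come from memory (0, 1, 2); the remaining lanes come
+ // from %passthru (10, 11, 12), which print as the signed i4 values -6, -5, -4.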
+ // CHECK: ( 0, 1, 2, -6, -5, -4 )
+ memref.dealloc %A : memref<?xi4>
+
return
}