[Mlir-commits] [mlir] 65a3f28 - [mlir] Add "mask" operand to vector.transfer_read/write.
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 7 05:33:34 PDT 2021
Author: Matthias Springer
Date: 2021-04-07T21:33:13+09:00
New Revision: 65a3f289397fd7d6cfcb4ddfdf324e37cf90cad7
URL: https://github.com/llvm/llvm-project/commit/65a3f289397fd7d6cfcb4ddfdf324e37cf90cad7
DIFF: https://github.com/llvm/llvm-project/commit/65a3f289397fd7d6cfcb4ddfdf324e37cf90cad7.diff
LOG: [mlir] Add "mask" operand to vector.transfer_read/write.
Also factors out out-of-bounds mask generation from vector.transfer_read/write into a new MaterializeTransferMask pattern.
Differential Revision: https://reviews.llvm.org/D100001
Added:
Modified:
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index efd26ff8808c..0ee3fd5eb4a0 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -68,7 +68,7 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions = false, bool enableIndexOptimizations = true);
+ bool reassociateFPReductions = false);
/// Create a pass to convert vector operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index a2ec152c3947..c11e8112b2e8 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -88,6 +88,10 @@ void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
/// `vector.store` and `vector.broadcast`.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+/// These patterns materialize masks for various vector ops such as transfers.
+void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
+ bool enableIndexOptimizations);
+
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 5ff118b05ff4..14afe9504806 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1135,10 +1135,12 @@ def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ AttrSizedOperandSegments
]>,
Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding,
+ Optional<VectorOf<[I1]>>:$mask,
OptionalAttr<BoolArrayAttr>:$in_bounds)>,
Results<(outs AnyVector:$vector)> {
@@ -1167,13 +1169,19 @@ def Vector_TransferReadOp :
return type.
An SSA value `padding` of the same elemental type as the MemRef/Tensor is
- provided to specify a fallback value in the case of out-of-bounds accesses.
+ provided to specify a fallback value in the case of out-of-bounds accesses
+ and/or masking.
+
+ An optional SSA value `mask` of the same shape as the vector type may be
+ specified to mask out elements. Such elements will be replaced with
+ `padding`. Elements whose corresponding mask element is `0` are masked out.
An optional boolean array attribute is provided to specify which dimensions
of the transfer are guaranteed to be within bounds. The absence of this
`in_bounds` attribute signifies that any dimension of the transfer may be
out-of-bounds. A `vector.transfer_read` can be lowered to a simple load if
- all dimensions are specified to be within bounds.
+ all dimensions are specified to be within bounds and no `mask` was
+ specified.
This operation is called 'read' by opposition to 'load' because the
super-vector granularity is generally not representable with a single
@@ -1299,6 +1307,14 @@ def Vector_TransferReadOp :
// 'getMinorIdentityMap' (resp. zero).
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
+ // Builder that does not set mask.
+ OpBuilder<(ins "Type":$vector, "Value":$source,
+ "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding,
+ "ArrayAttr":$inBounds)>,
+ // Builder that does not set mask.
+ OpBuilder<(ins "Type":$vector, "Value":$source,
+ "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding,
+ "ArrayAttr":$inBounds)>
];
let hasFolder = 1;
@@ -1308,11 +1324,13 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ AttrSizedOperandSegments
]>,
Arguments<(ins AnyVector:$vector, AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
+ Optional<VectorOf<[I1]>>:$mask,
OptionalAttr<BoolArrayAttr>:$in_bounds)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {
@@ -1341,11 +1359,16 @@ def Vector_TransferWriteOp :
The size of the slice is specified by the size of the vector.
+ An optional SSA value `mask` of the same shape as the vector type may be
+ specified to mask out elements. Elements whose corresponding mask element
+ is `0` are masked out.
+
An optional boolean array attribute is provided to specify which dimensions
of the transfer are guaranteed to be within bounds. The absence of this
`in_bounds` attribute signifies that any dimension of the transfer may be
out-of-bounds. A `vector.transfer_write` can be lowered to a simple store
- if all dimensions are specified to be within bounds.
+ if all dimensions are specified to be within bounds and no `mask` was
+ specified.
This operation is called 'write' by opposition to 'store' because the
super-vector granularity is generally not representable with a single
@@ -1391,6 +1414,8 @@ def Vector_TransferWriteOp :
"AffineMap":$permutationMap)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
+ OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+ "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
];
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 82e4bc2f4353..0c752c33ff16 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
return res;
}
-static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
- Location loc, Type targetType, Value value) {
- if (targetType == value.getType())
- return value;
-
- bool targetIsIndex = targetType.isIndex();
- bool valueIsIndex = value.getType().isIndex();
- if (targetIsIndex ^ valueIsIndex)
- return rewriter.create<IndexCastOp>(loc, targetType, value);
-
- auto targetIntegerType = targetType.dyn_cast<IntegerType>();
- auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
- assert(targetIntegerType && valueIntegerType &&
- "unexpected cast between types other than integers and index");
- assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-
- if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
- return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
- return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
-}
-
-// Helper that returns a vector comparison that constructs a mask:
-// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
-//
-// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
-// much more compact, IR for this operation, but LLVM eventually
-// generates more elaborate instructions for this intrinsic since it
-// is very conservative on the boundary conditions.
-static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
- Operation *op, bool enableIndexOptimizations,
- int64_t dim, Value b, Value *off = nullptr) {
- auto loc = op->getLoc();
- // If we can assume all indices fit in 32-bit, we perform the vector
- // comparison in 32-bit to get a higher degree of SIMD parallelism.
- // Otherwise we perform the vector comparison using 64-bit indices.
- Value indices;
- Type idxType;
- if (enableIndexOptimizations) {
- indices = rewriter.create<ConstantOp>(
- loc, rewriter.getI32VectorAttr(
- llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
- idxType = rewriter.getI32Type();
- } else {
- indices = rewriter.create<ConstantOp>(
- loc, rewriter.getI64VectorAttr(
- llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
- idxType = rewriter.getI64Type();
- }
- // Add in an offset if requested.
- if (off) {
- Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
- indices = rewriter.create<AddIOp>(loc, ov, indices);
- }
- // Construct the vector comparison.
- Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
- Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
- return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
-}
-
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
- auto adaptor = TransferWriteOpAdaptor(operands);
+ auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
align);
return success();
@@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
- auto adaptor = TransferWriteOpAdaptor(operands);
+ auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
xferOp, adaptor.vector(), dataPtr, mask,
rewriter.getI32IntegerAttr(align));
@@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
ArrayRef<Value> operands) {
- return TransferReadOpAdaptor(operands);
+ return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
}
static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
ArrayRef<Value> operands) {
- return TransferWriteOpAdaptor(operands);
+ return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
}
namespace {
@@ -618,33 +558,6 @@ class VectorReductionOpConversion
const bool reassociateFPReductions;
};
-/// Conversion pattern for a vector.create_mask (1-D only).
-class VectorCreateMaskOpConversion
- : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
-public:
- explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
- bool enableIndexOpt)
- : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
- enableIndexOptimizations(enableIndexOpt) {}
-
- LogicalResult
- matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto dstType = op.getType();
- int64_t rank = dstType.getRank();
- if (rank == 1) {
- rewriter.replaceOp(
- op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
- dstType.getDimSize(0), operands[0]));
- return success();
- }
- return failure();
- }
-
-private:
- const bool enableIndexOptimizations;
-};
-
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
@@ -1177,20 +1090,12 @@ class VectorTypeCastOpConversion
}
};
-/// Conversion pattern that converts a 1-D vector transfer read/write op in a
-/// sequence of:
-/// 1. Get the source/dst address as an LLVM vector pointer.
-/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 4. Create a mask where offsetVector is compared against memref upper bound.
-/// 5. Rewrite op as a masked read or write.
+/// Conversion pattern that converts a 1-D vector transfer read/write op into a
+/// masked or unmasked read/write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
- explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
- bool enableIndexOpt)
- : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
- enableIndexOptimizations(enableIndexOpt) {}
+ using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
@@ -1212,6 +1117,9 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
auto strides = computeContiguousStrides(memRefType);
if (!strides)
return failure();
+ // Out-of-bounds dims are handled by MaterializeTransferMask.
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
@@ -1241,40 +1149,24 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
#endif // ifndef NDEBUG
}
- // 1. Get the source/dst address as an LLVM vector pointer.
+ // Get the source/dst address as an LLVM vector pointer.
VectorType vtp = xferOp.getVectorType();
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
Value vectorDataPtr =
castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
- if (xferOp.isDimInBounds(0))
+ // Rewrite as an unmasked read / write.
+ if (!xferOp.mask())
return replaceTransferOpWithLoadOrStore(rewriter,
*this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr);
- // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
- // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
- // 4. Let dim the memref dimension, compute the vector comparison mask
- // (in-bounds mask):
- // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
- //
- // TODO: when the leaf transfer rank is k > 1, we need the last `k`
- // dimensions here.
- unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
- unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
- Value off = xferOp.indices()[lastIndex];
- Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
- Value mask = buildVectorComparison(
- rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
-
- // 5. Rewrite as a masked read / write.
+ // Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
- xferOp, operands, vectorDataPtr, mask);
+ xferOp, operands, vectorDataPtr,
+ xferOp.mask());
}
-
-private:
- const bool enableIndexOptimizations;
};
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@@ -1484,17 +1376,13 @@ class VectorExtractStridedSliceOpConversion
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions, bool enableIndexOptimizations) {
+ bool reassociateFPReductions) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern,
VectorInsertStridedSliceOpDifferentRankRewritePattern,
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
- patterns.add<VectorCreateMaskOpConversion,
- VectorTransferConversion<TransferReadOp>,
- VectorTransferConversion<TransferWriteOp>>(
- converter, enableIndexOptimizations);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
@@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
- VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
- converter);
+ VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+ VectorTransferConversion<TransferReadOp>,
+ VectorTransferConversion<TransferWriteOp>>(converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index abddcd73af1e..49ee670b2f06 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -71,9 +71,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
RewritePatternSet patterns(&getContext());
+ populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
- populateVectorToLLVMConversionPatterns(
- converter, patterns, reassociateFPReductions, enableIndexOptimizations);
+ populateVectorToLLVMConversionPatterns(converter, patterns,
+ reassociateFPReductions);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index b55c8bc263c8..2f033b18c8f1 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -42,7 +42,7 @@ static LogicalResult replaceTransferOpWithMubuf(
LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
Value &glc, Value &slc) {
- auto adaptor = TransferWriteOpAdaptor(operands);
+ auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
dwordConfig, vindex,
offsetSizeInBytes, glc, slc);
@@ -62,7 +62,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
LogicalResult
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- typename ConcreteOp::Adaptor adaptor(operands);
+ typename ConcreteOp::Adaptor adaptor(operands, xferOp->getAttrDictionary());
if (xferOp.getVectorType().getRank() > 1 ||
llvm::size(xferOp.indices()) == 0)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 6e963aeb8932..72d32d071e49 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -538,6 +538,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
using namespace mlir::edsc::op;
TransferReadOp transfer = cast<TransferReadOp>(op);
+ if (transfer.mask())
+ return failure();
auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
return failure();
@@ -624,6 +626,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
using namespace edsc::op;
TransferWriteOp transfer = cast<TransferWriteOp>(op);
+ if (transfer.mask())
+ return failure();
auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
if (!memRefType)
return failure();
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 7e4233db12dd..cff5fcb5649e 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2295,8 +2295,27 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vectorType, source, indices, permMap, inBounds);
}
+/// Builder that does not provide a mask.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ Type vectorType, Value source, ValueRange indices,
+ AffineMap permutationMap, Value padding,
+ ArrayAttr inBounds) {
+ build(builder, result, vectorType, source, indices, permutationMap, padding,
+ /*mask=*/Value(), inBounds);
+}
+
+/// Builder that does not provide a mask.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ Type vectorType, Value source, ValueRange indices,
+ AffineMapAttr permutationMap, Value padding,
+ ArrayAttr inBounds) {
+ build(builder, result, vectorType, source, indices, permutationMap, padding,
+ /*mask=*/Value(), inBounds);
+}
+
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
- SmallVector<StringRef, 2> elidedAttrs;
+ SmallVector<StringRef, 3> elidedAttrs;
+ elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
if (op.permutation_map().isMinorIdentity())
elidedAttrs.push_back(op.getPermutationMapAttrName());
bool elideInBounds = true;
@@ -2316,27 +2335,36 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.source() << "[" << op.indices()
<< "], " << op.padding();
+ if (op.mask())
+ p << ", " << op.mask();
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getShapedType() << ", " << op.getVectorType();
}
static ParseResult parseTransferReadOp(OpAsmParser &parser,
OperationState &result) {
+ auto &builder = parser.getBuilder();
llvm::SMLoc typesLoc;
OpAsmParser::OperandType sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
OpAsmParser::OperandType paddingInfo;
SmallVector<Type, 2> types;
+ OpAsmParser::OperandType maskInfo;
// Parsing with support for paddingValue.
if (parser.parseOperand(sourceInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseComma() || parser.parseOperand(paddingInfo) ||
- parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseComma() || parser.parseOperand(paddingInfo))
+ return failure();
+ ParseResult hasMask = parser.parseOptionalComma();
+ // Optional mask operand after the padding; propagate parse errors (as the write parser does).
+ if (hasMask.succeeded() && parser.parseOperand(maskInfo))
+ return failure();
+ if (parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
- auto indexType = parser.getBuilder().getIndexType();
+ auto indexType = builder.getIndexType();
auto shapedType = types[0].dyn_cast<ShapedType>();
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
@@ -2349,12 +2377,21 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
- return failure(
- parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+ if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
parser.resolveOperand(paddingInfo, shapedType.getElementType(),
- result.operands) ||
- parser.addTypeToList(vectorType, result.types));
+ result.operands))
+ return failure();
+ if (hasMask.succeeded()) {
+ auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+ if (parser.resolveOperand(maskInfo, maskType, result.operands))
+ return failure();
+ }
+ result.addAttribute(
+ TransferReadOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
+ static_cast<int32_t>(hasMask.succeeded())}));
+ return parser.addTypeToList(vectorType, result.types);
}
static LogicalResult verify(TransferReadOp op) {
@@ -2525,7 +2562,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
/*optional*/ ArrayAttr inBounds) {
Type resultType = source.getType().dyn_cast<RankedTensorType>();
build(builder, result, resultType, vector, source, indices, permutationMap,
- inBounds);
+ /*mask=*/Value(), inBounds);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
@@ -2534,24 +2571,39 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
/*optional*/ ArrayAttr inBounds) {
Type resultType = source.getType().dyn_cast<RankedTensorType>();
build(builder, result, resultType, vector, source, indices, permutationMap,
- inBounds);
+ /*mask=*/Value(), inBounds);
+}
+
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+ Value vector, Value source, ValueRange indices,
+ AffineMap permutationMap, /*optional*/ Value mask,
+ /*optional*/ ArrayAttr inBounds) {
+ Type resultType = source.getType().dyn_cast<RankedTensorType>();
+ build(builder, result, resultType, vector, source, indices, permutationMap,
+ mask, inBounds);
}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
+ auto &builder = parser.getBuilder();
llvm::SMLoc typesLoc;
OpAsmParser::OperandType vectorInfo, sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
SmallVector<Type, 2> types;
+ OpAsmParser::OperandType maskInfo;
if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
parser.parseOperand(sourceInfo) ||
- parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
+ return failure();
+ ParseResult hasMask = parser.parseOptionalComma();
+ if (hasMask.succeeded() && parser.parseOperand(maskInfo))
+ return failure();
+ if (parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
- auto indexType = parser.getBuilder().getIndexType();
+ auto indexType = builder.getIndexType();
VectorType vectorType = types[0].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
@@ -2564,17 +2616,28 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
- return failure(
- parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
+ if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
- parser.resolveOperands(indexInfo, indexType, result.operands) ||
- (shapedType.isa<RankedTensorType>() &&
- parser.addTypeToList(shapedType, result.types)));
+ parser.resolveOperands(indexInfo, indexType, result.operands))
+ return failure();
+ if (hasMask.succeeded()) {
+ auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+ if (parser.resolveOperand(maskInfo, maskType, result.operands))
+ return failure();
+ }
+ result.addAttribute(
+ TransferWriteOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
+ static_cast<int32_t>(hasMask.succeeded())}));
+ return failure(shapedType.isa<RankedTensorType>() &&
+ parser.addTypeToList(shapedType, result.types));
}
static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
<< op.indices() << "]";
+ if (op.mask())
+ p << ", " << op.mask();
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getVectorType() << ", " << op.getShapedType();
}
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index b48c8ace0b88..ba8ca26b336e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -596,6 +596,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
OpBuilder &builder) {
if (!isIdentitySuffix(readOp.permutation_map()))
return nullptr;
+ if (readOp.mask())
+ return nullptr;
auto sourceVectorType = readOp.getVectorType();
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
@@ -641,6 +643,8 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
auto writeOp = cast<vector::TransferWriteOp>(op);
if (!isIdentitySuffix(writeOp.permutation_map()))
return failure();
+ if (writeOp.mask())
+ return failure();
VectorType sourceVectorType = writeOp.getVectorType();
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
TupleType tupleType = generateExtractSlicesOpResultType(
@@ -722,6 +726,9 @@ class SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
if (ignoreFilter && ignoreFilter(readOp))
return failure();
+ if (readOp.mask())
+ return failure();
+
// TODO: Support splitting TransferReadOp with non-identity permutation
// maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(readOp.permutation_map()))
@@ -768,6 +775,9 @@ class SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
if (ignoreFilter && ignoreFilter(writeOp))
return failure();
+ if (writeOp.mask())
+ return failure();
+
// TODO: Support splitting TransferWriteOp with non-identity permutation
// maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(writeOp.permutation_map()))
@@ -2546,6 +2556,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
"Expected splitFullAndPartialTransferPrecondition to hold");
auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
+ // xferReadOp is null for transfer_write (rejected just below); test it before calling mask().
+ if (xferReadOp && xferReadOp.mask())
+ return failure();
// TODO: add support for write case.
if (!xferReadOp)
return failure();
@@ -2677,6 +2690,8 @@ struct TransferReadExtractPattern
dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
if (!extract)
return failure();
+ if (read.mask())
+ return failure();
edsc::ScopedContext scope(rewriter, read.getLoc());
using mlir::edsc::op::operator+;
using mlir::edsc::op::operator*;
@@ -2712,6 +2727,8 @@ struct TransferWriteInsertPattern
auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
if (!insert)
return failure();
+ if (write.mask())
+ return failure();
edsc::ScopedContext scope(rewriter, write.getLoc());
using mlir::edsc::op::operator+;
using mlir::edsc::op::operator*;
@@ -2742,6 +2759,7 @@ struct TransferWriteInsertPattern
/// - If the memref's element type is a vector type then it coincides with the
/// result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+/// - The op has no mask.
struct TransferReadToVectorLoadLowering
: public OpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLowering(MLIRContext *context)
@@ -2780,7 +2798,8 @@ struct TransferReadToVectorLoadLowering
// MaskedLoadOp.
if (read.hasOutOfBoundsDim())
return failure();
-
+ if (read.mask())
+ return failure();
Operation *loadOp;
if (!broadcastedDims.empty() &&
unbroadcastedVectorType.getNumElements() == 1) {
@@ -2815,6 +2834,7 @@ struct TransferReadToVectorLoadLowering
/// type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
/// broadcasting is allowed).
+/// - The op has no mask.
struct TransferWriteToVectorStoreLowering
: public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLowering(MLIRContext *context)
@@ -2840,6 +2860,8 @@ struct TransferWriteToVectorStoreLowering
// MaskedStoreOp.
if (write.hasOutOfBoundsDim())
return failure();
+ if (write.mask())
+ return failure();
rewriter.replaceOpWithNewOp<vector::StoreOp>(
write, write.vector(), write.source(), write.indices());
return success();
@@ -2880,6 +2902,8 @@ struct TransferReadPermutationLowering
map.getPermutationMap(permutation, op.getContext());
if (permutationMap.isIdentity())
return failure();
+ if (op.mask())
+ return failure();
// Caluclate the map of the new read by applying the inverse permutation.
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
@@ -2914,6 +2938,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
+ if (op.mask())
+ return failure();
AffineMap map = op.permutation_map();
unsigned numLeadingBroadcast = 0;
for (auto expr : map.getResults()) {
@@ -3062,6 +3088,9 @@ struct CastAwayTransferReadLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
+ if (read.mask())
+ return failure();
+
auto shapedType = read.source().getType().cast<ShapedType>();
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -3102,6 +3131,9 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
+ if (write.mask())
+ return failure();
+
auto shapedType = write.source().getType().dyn_cast<ShapedType>();
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -3371,6 +3403,151 @@ struct BubbleUpBitCastForStridedSliceInsert
}
};
+/// Converts `value` to `targetType`, where both types are `index` or integer
+/// types (integer-to-integer casts additionally require matching signedness):
+///  - identical types: `value` is returned unchanged;
+///  - exactly one side is `index`: an index_cast is emitted;
+///  - integer to wider integer: a sign extension is emitted;
+///  - integer to narrower integer: a truncation is emitted.
+static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
+                                   Type targetType, Value value) {
+  Type sourceType = value.getType();
+  if (sourceType == targetType)
+    return value;
+
+  // Crossing the index/integer boundary always goes through index_cast.
+  if (sourceType.isIndex() != targetType.isIndex())
+    return rewriter.create<IndexCastOp>(loc, targetType, value);
+
+  auto srcInt = sourceType.dyn_cast<IntegerType>();
+  auto dstInt = targetType.dyn_cast<IntegerType>();
+  assert(dstInt && srcInt &&
+         "unexpected cast between types other than integers and index");
+  assert(dstInt.getSignedness() == srcInt.getSignedness());
+
+  // Equal widths with equal signedness were already handled above, so
+  // "not wider" here means strictly narrower.
+  if (dstInt.getWidth() <= srcInt.getWidth())
+    return rewriter.create<TruncateIOp>(loc, dstInt, value);
+  return rewriter.create<SignExtendIOp>(loc, dstInt, value);
+}
+
+// Builds a 1-D i1 mask of `dim` lanes by comparing lane indices against a
+// splatted bound (signed `slt` comparison):
+//   mask = [0, 1, .., dim-1] + [off, off, .., off] < [b, b, .., b]
+// The offset term is only added when `off` is non-null.
+//
+// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
+//       much more compact, IR for this operation, but LLVM eventually
+//       generates more elaborate instructions for this intrinsic since it
+//       is very conservative on the boundary conditions.
+static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
+                                   bool enableIndexOptimizations, int64_t dim,
+                                   Value b, Value *off = nullptr) {
+  Location loc = op->getLoc();
+
+  // When all indices can be assumed to fit in 32 bits, compare in i32 for a
+  // higher degree of SIMD parallelism; otherwise use 64-bit indices.
+  Type idxType;
+  Value laneIndices;
+  if (!enableIndexOptimizations) {
+    laneIndices = rewriter.create<ConstantOp>(
+        loc, rewriter.getI64VectorAttr(
+                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
+    idxType = rewriter.getI64Type();
+  } else {
+    laneIndices = rewriter.create<ConstantOp>(
+        loc, rewriter.getI32VectorAttr(
+                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
+    idxType = rewriter.getI32Type();
+  }
+  Type vecType = laneIndices.getType();
+
+  // Shift every lane by the (optional) offset.
+  if (off != nullptr) {
+    Value offset = createCastToIndexLike(rewriter, loc, idxType, *off);
+    Value offsetSplat = rewriter.create<SplatOp>(loc, vecType, offset);
+    laneIndices = rewriter.create<AddIOp>(loc, offsetSplat, laneIndices);
+  }
+
+  // Lane i is set iff its (shifted) index is strictly below the bound.
+  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
+  Value boundSplat = rewriter.create<SplatOp>(loc, vecType, bound);
+  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, laneIndices,
+                                 boundSplat);
+}
+
+/// Rewrites a 1-D vector.transfer_read / vector.transfer_write that may access
+/// memory out of bounds into an in-bounds transfer with an explicit mask.
+///
+/// The in-bounds mask is
+///   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+/// where `offset` is the last transfer index and `dim` is the size of the
+/// source dimension addressed by that index. If the op already carries a mask,
+/// the two masks are intersected. On success the op is updated in place: its
+/// mask operand is set and `in_bounds` becomes [true].
+///
+/// Only transfers with a minor-identity permutation map are handled. Only then
+/// does the single vector dimension correspond to the last index and source
+/// dimension used to build the mask.
+template <typename ConcreteOp>
+struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
+public:
+  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
+      : mlir::OpRewritePattern<ConcreteOp>(context),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(ConcreteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to do when every dimension is already known to be in bounds.
+    if (!xferOp.hasOutOfBoundsDim())
+      return failure();
+
+    // Only 1-D vectors with at least one index are supported.
+    if (xferOp.getVectorType().getRank() > 1 ||
+        llvm::size(xferOp.indices()) == 0)
+      return failure();
+
+    // The mask is built from the last index and the last source dimension,
+    // which is only correct if the vector dimension maps to them. A permuted
+    // transfer (e.g. (d0, d1) -> (d0)) would otherwise get a wrong mask and be
+    // incorrectly marked in-bounds.
+    if (!xferOp.permutation_map().isMinorIdentity())
+      return failure();
+
+    Location loc = xferOp->getLoc();
+    VectorType vtp = xferOp.getVectorType();
+
+    // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
+    // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+    // * Let dim be the memref dimension and compute the vector comparison
+    //   mask (in-bounds mask):
+    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+    //
+    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
+    //       dimensions here.
+    unsigned vecWidth = vtp.getNumElements();
+    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
+    Value off = xferOp.indices()[lastIndex];
+    Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
+    Value mask = buildVectorComparison(
+        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
+
+    if (xferOp.mask()) {
+      // Intersect the in-bounds mask with the mask given as an op operand.
+      mask = rewriter.create<AndOp>(loc, mask, xferOp.mask());
+    }
+
+    rewriter.updateRootInPlace(xferOp, [&]() {
+      xferOp.maskMutable().assign(mask);
+      xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
+    });
+
+    return success();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
+/// Lowers a 1-D vector.create_mask to an explicit lane-index comparison built
+/// by buildVectorComparison. Masks of any other rank are not matched.
+class VectorCreateMaskOpConversion
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
+                                        bool enableIndexOpt)
+      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto maskType = op.getType();
+    if (maskType.getRank() != 1)
+      return failure();
+
+    Value lowered =
+        buildVectorComparison(rewriter, op, enableIndexOptimizations,
+                              maskType.getDimSize(0), op.getOperand(0));
+    rewriter.replaceOp(op, lowered);
+    return success();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
+/// Registers the create_mask lowering and the transfer_read/transfer_write
+/// mask materialization patterns, all sharing `enableIndexOptimizations`.
+void mlir::vector::populateVectorMaskMaterializationPatterns(
+    RewritePatternSet &patterns, bool enableIndexOptimizations) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<VectorCreateMaskOpConversion>(ctx, enableIndexOptimizations);
+  patterns.add<MaterializeTransferMask<vector::TransferReadOp>,
+               MaterializeTransferMask<vector::TransferWriteOp>>(
+      ctx, enableIndexOptimizations);
+}
+
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 249f8c09e599..c09b4ac2da96 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -3,20 +3,19 @@
// CMP32-LABEL: @genbool_var_1d(
// CMP32-SAME: %[[ARG:.*]]: index)
-// CMP32: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
// CMP32: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
-// CMP32: %[[T1:.*]] = trunci %[[A]] : i64 to i32
+// CMP32: %[[T1:.*]] = index_cast %[[ARG]] : index to i32
// CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32>
// CMP32: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi32>
// CMP32: return %[[T3]] : vector<11xi1>
// CMP64-LABEL: @genbool_var_1d(
// CMP64-SAME: %[[ARG:.*]]: index)
-// CMP64: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
// CMP64: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>
-// CMP64: %[[T1:.*]] = splat %[[A]] : vector<11xi64>
-// CMP64: %[[T2:.*]] = cmpi slt, %[[T0]], %[[T1]] : vector<11xi64>
-// CMP64: return %[[T2]] : vector<11xi1>
+// CMP64: %[[T1:.*]] = index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64>
+// CMP64: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi64>
+// CMP64: return %[[T3]] : vector<11xi1>
func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
%0 = vector.create_mask %arg0 : vector<11xi1>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a5161b64337f..9faf7caa3439 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1049,31 +1049,31 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-LABEL: func @transfer_read_1d
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
// CHECK: %[[c7:.*]] = constant 7.0
-//
-// 1. Bitcast to vector form.
-// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
-// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi32>
//
-// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[otrunc:.*]] = index_cast %[[BASE]] : index to i32
// CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32>
// CHECK: %[[offsetVec2:.*]] = addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
//
-// 4. Let dim the memref dimension, compute the vector comparison mask:
+// 3. Let dim be the memref dimension; compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: %[[dtrunc:.*]] = index_cast %[[DIM]] : index to i32
// CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
// CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
//
+// 4. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
+//
// 5. Rewrite as a masked read.
// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
@@ -1081,26 +1081,26 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
//
-// 1. Bitcast to vector form.
-// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
-// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
-// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
-//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex_b:.*]] = constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi32>
//
-// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: splat %{{.*}} : vector<17xi32>
// CHECK: addi
//
-// 4. Let dim the memref dimension, compute the vector comparison mask:
+// 3. Let dim be the memref dimension; compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: splat %{{.*}} : vector<17xi32>
// CHECK: %[[mask_b:.*]] = cmpi slt, {{.*}} : vector<17xi32>
//
+// 4. Bitcast to vector form.
+// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
+// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
+//
// 5. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
// CHECK-SAME: {alignment = 4 : i32} :
@@ -1182,6 +1182,21 @@ func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector<17xf
// -----
+// CHECK-LABEL: func @transfer_read_1d_mask
+// CHECK: %[[mask1:.*]] = constant dense<[false, false, true, false, true]>
+// CHECK: %[[cmpi:.*]] = cmpi slt
+// CHECK: %[[mask2:.*]] = and %[[cmpi]], %[[mask1]]
+// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
+// CHECK: return %[[r]]
+func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
+  // User-provided mask; the lowering intersects it with the in-bounds mask.
+  %user_mask = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
+  %pad = constant 7.0: f32
+  %res = vector.transfer_read %A[%base], %pad, %user_mask : memref<?xf32>, vector<5xf32>
+  return %res: vector<5xf32>
+}
+
+// -----
+
func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
%c0 = constant 0: i32
%v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} :
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 1e6f95aa2293..43bef97f799e 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -11,6 +11,7 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%c0 = constant 0 : i32
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>
+ %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
//
// CHECK: vector.transfer_read
@@ -27,7 +28,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
%6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
-
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
+ %7 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -39,7 +41,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
-
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : vector<5xf32>, memref<?x?xf32>
+ vector.transfer_write %7, %arg0[%c3, %c3], %m : vector<5xf32>, memref<?x?xf32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
index 5cd7d09e6c8b..bed94f02920a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir
@@ -12,6 +12,14 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) {
return
}
+func @transfer_read_mask_1d(%A : memref<?xf32>, %base: index) {
+  // Padding value returned in masked-off or out-of-bounds lanes.
+  %pad = constant -42.0: f32
+  %lanes = constant dense<[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]> : vector<13xi1>
+  %v = vector.transfer_read %A[%base], %pad, %lanes : memref<?xf32>, vector<13xf32>
+  vector.print %v: vector<13xf32>
+  return
+}
+
func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
%fm42 = constant -42.0: f32
%f = vector.transfer_read %A[%base], %fm42
@@ -21,6 +29,15 @@ func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
return
}
+func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
+  // Padding value returned in masked-off lanes.
+  %pad = constant -42.0: f32
+  %lanes = constant dense<[0, 1, 0, 1]> : vector<4xi1>
+  %v = vector.transfer_read %A[%base], %pad, %lanes {in_bounds = [true]}
+    : memref<?xf32>, vector<4xf32>
+  vector.print %v: vector<4xf32>
+  return
+}
+
func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4xf32>
@@ -47,6 +64,8 @@ func @entry() {
// Read shifted by 2 and pad with -42:
// ( 2, 3, 4, -42, ..., -42)
call @transfer_read_1d(%A, %c2) : (memref<?xf32>, index) -> ()
+ // Read with mask and out-of-bounds access.
+ call @transfer_read_mask_1d(%A, %c2) : (memref<?xf32>, index) -> ()
// Write into memory shifted by 3
// memory contains [[ 0, 1, 2, 0, 0, xxx garbage xxx ]]
call @transfer_write_1d(%A, %c3) : (memref<?xf32>, index) -> ()
@@ -56,9 +75,13 @@ func @entry() {
// Read in-bounds 4 @ 1, guaranteed to not overflow.
// Exercises proper alignment.
call @transfer_read_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
+ // Read in-bounds with mask.
+ call @transfer_read_mask_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
return
}
// CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
+// CHECK: ( -42, -42, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
// CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
// CHECK: ( 1, 2, 0, 0 )
+// CHECK: ( -42, 2, -42, 0 )
More information about the Mlir-commits
mailing list