[Mlir-commits] [mlir] 1870e78 - [mlir][Vector] Add an optional "masked" boolean array attribute to vector transfer operations
Nicolas Vasilache
llvmlistbot at llvm.org
Mon May 18 08:55:54 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-18T11:52:08-04:00
New Revision: 1870e787af961d1b409e18a18ddf297f02333a78
URL: https://github.com/llvm/llvm-project/commit/1870e787af961d1b409e18a18ddf297f02333a78
DIFF: https://github.com/llvm/llvm-project/commit/1870e787af961d1b409e18a18ddf297f02333a78.diff
LOG: [mlir][Vector] Add an optional "masked" boolean array attribute to vector transfer operations
Summary:
The semantics of vector transfer ops are extended to allow specifying a per-dimension
`masked` attribute. When the attribute is false for a particular dimension, lowering to
LLVM emits unmasked load and store operations for that dimension.
Differential Revision: https://reviews.llvm.org/D80098
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index b8a47a27e41f..29e72857b291 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -865,7 +865,12 @@ def Vector_ExtractStridedSliceOp :
def Vector_TransferOpUtils {
code extraTransferDeclaration = [{
+ static StringRef getMaskedAttrName() { return "masked"; }
static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+ bool isMaskedDim(unsigned dim) {
+ return !masked() ||
+ masked()->cast<ArrayAttr>()[dim].cast<BoolAttr>().getValue();
+ }
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}
@@ -878,14 +883,15 @@ def Vector_TransferOpUtils {
def Vector_TransferReadOp :
Vector_Op<"transfer_read">,
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
- AffineMapAttr:$permutation_map, AnyType:$padding)>,
+ AffineMapAttr:$permutation_map, AnyType:$padding,
+ OptionalAttr<BoolArrayAttr>:$masked)>,
Results<(outs AnyVector:$vector)> {
let summary = "Reads a supervector from memory into an SSA vector value.";
let description = [{
- The `vector.transfer_read` op performs a blocking read from a slice within
- a [MemRef](../LangRef.md#memref-type) supplied as its first operand
+ The `vector.transfer_read` op performs a read from a slice within a
+ [MemRef](../LangRef.md#memref-type) supplied as its first operand
into a [vector](../LangRef.md#vector-type) of the same base elemental type.
A memref operand with vector element type, must have its vector element
@@ -893,8 +899,9 @@ def Vector_TransferReadOp :
memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
The slice is further defined by a full-rank index within the MemRef,
- supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
- [attribute](../LangRef.md#attributes) is an
+ supplied as the operands `2 .. 1 + rank(memref)`.
+
+ The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The permutation map may be implicit and
ommitted from parsing and printing if it is the canonical minor identity map
@@ -906,6 +913,12 @@ def Vector_TransferReadOp :
An `ssa-value` of the same elemental type as the MemRef is provided as the
last operand to specify padding in the case of out-of-bounds accesses.
+ An optional boolean array attribute is provided to specify which dimensions
+ of the transfer need masking. When a dimension is specified as not requiring
+ masking, the `vector.transfer_read` may be lowered to simple loads. The
+ absence of this `masked` attribute signifies that all dimensions of the
+ transfer need to be masked.
+
This operation is called 'read' by opposition to 'load' because the
super-vector granularity is generally not representable with a single
hardware register. A `vector.transfer_read` is thus a mid-level abstraction
@@ -1015,11 +1028,13 @@ def Vector_TransferReadOp :
let builders = [
// Builder that sets padding to zero.
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
- "Value memref, ValueRange indices, AffineMap permutationMap">,
+ "Value memref, ValueRange indices, AffineMap permutationMap, "
+ "ArrayRef<bool> maybeMasked = {}">,
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
- "Value memref, ValueRange indices">
+ "Value memref, ValueRange indices, "
+ "ArrayRef<bool> maybeMasked = {}">
];
let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #
@@ -1039,12 +1054,13 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write">,
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
Variadic<Index>:$indices,
- AffineMapAttr:$permutation_map)> {
+ AffineMapAttr:$permutation_map,
+ OptionalAttr<BoolArrayAttr>:$masked)> {
let summary = "The vector.transfer_write op writes a supervector to memory.";
let description = [{
- The `vector.transfer_write` performs a blocking write from a
+ The `vector.transfer_write` op performs a write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
slice within a [MemRef](../LangRef.md#memref-type) of the same base
elemental type, supplied as its second operand.
@@ -1055,6 +1071,7 @@ def Vector_TransferWriteOp :
The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `3 .. 2 + rank(memref)`.
+
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The permutation map may be implicit and
@@ -1063,6 +1080,12 @@ def Vector_TransferWriteOp :
The size of the slice is specified by the size of the vector.
+ An optional boolean array attribute is provided to specify which dimensions
+ of the transfer need masking. When a dimension is specified as not requiring
+ masking, the `vector.transfer_write` may be lowered to simple stores. The
+ absence of this `masked` attribute signifies that all dimensions of the
+ transfer need to be masked.
+
This operation is called 'write' by opposition to 'store' because the
super-vector granularity is generally not representable with a single
hardware register. A `vector.transfer_write` is thus a
@@ -1097,7 +1120,10 @@ def Vector_TransferWriteOp :
let builders = [
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
- "Value memref, ValueRange indices">
+ "Value memref, ValueRange indices, "
+ "ArrayRef<bool> maybeMasked = {}">,
+ OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
+ "Value memref, ValueRange indices, AffineMap permutationMap">,
];
let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index eb25bf3abf85..975807ca8671 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -746,12 +746,6 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
}
};
-template <typename ConcreteOp>
-LogicalResult replaceTransferOp(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter, Location loc,
- Operation *op, ArrayRef<Value> operands,
- Value dataPtr, Value mask);
-
LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
Type type, LLVM::LLVMType &llvmType,
unsigned &align) {
@@ -765,12 +759,25 @@ LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
return success();
}
-template <>
-LogicalResult replaceTransferOp<TransferReadOp>(
- ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
- Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
- Value mask) {
- auto xferOp = cast<TransferReadOp>(op);
+LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter, Location loc,
+ TransferReadOp xferOp,
+ ArrayRef<Value> operands, Value dataPtr) {
+ LLVM::LLVMType vecTy;
+ unsigned align;
+ if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
+ vecTy, align)))
+ return failure();
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
+ return success();
+}
+
+LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter,
+ Location loc, TransferReadOp xferOp,
+ ArrayRef<Value> operands,
+ Value dataPtr, Value mask) {
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
VectorType fillType = xferOp.getVectorType();
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
@@ -783,19 +790,32 @@ LogicalResult replaceTransferOp<TransferReadOp>(
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
- op, vecTy, dataPtr, mask, ValueRange{fill},
+ xferOp, vecTy, dataPtr, mask, ValueRange{fill},
rewriter.getI32IntegerAttr(align));
return success();
}
-template <>
-LogicalResult replaceTransferOp<TransferWriteOp>(
- ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
- Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
- Value mask) {
+LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter, Location loc,
+ TransferWriteOp xferOp,
+ ArrayRef<Value> operands, Value dataPtr) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
+ LLVM::LLVMType vecTy;
+ unsigned align;
+ if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
+ vecTy, align)))
+ return failure();
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
+ return success();
+}
- auto xferOp = cast<TransferWriteOp>(op);
+LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter,
+ Location loc, TransferWriteOp xferOp,
+ ArrayRef<Value> operands,
+ Value dataPtr, Value mask) {
+ auto adaptor = TransferWriteOpOperandAdaptor(operands);
LLVM::LLVMType vecTy;
unsigned align;
if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
@@ -803,7 +823,8 @@ LogicalResult replaceTransferOp<TransferWriteOp>(
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
- op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
+ xferOp, adaptor.vector(), dataPtr, mask,
+ rewriter.getI32IntegerAttr(align));
return success();
}
@@ -877,6 +898,10 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc, vecTy.getPointerTo(), dataPtr);
+ if (!xferOp.isMaskedDim(0))
+ return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
+ xferOp, operands, vectorDataPtr);
+
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
unsigned vecWidth = vecTy.getVectorNumElements();
VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
@@ -910,8 +935,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
mask);
// 5. Rewrite as a masked read / write.
- return replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op,
- operands, vectorDataPtr, mask);
+ return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
+ operands, vectorDataPtr, mask);
}
};
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index d3da7bff7b5b..03b78491fa12 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,25 +157,34 @@ void NDTransferOpHelper<ConcreteOp>::emitInBounds(
ValueRange majorIvs, ValueRange majorOffsets,
MemRefBoundsCapture &memrefBounds, LambdaThen thenBlockBuilder,
LambdaElse elseBlockBuilder) {
- Value inBounds = std_constant_int(/*value=*/1, /*width=*/1);
+ Value inBounds;
SmallVector<Value, 4> majorIvsPlusOffsets;
majorIvsPlusOffsets.reserve(majorIvs.size());
+ unsigned idx = 0;
for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) {
Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off);
- Value inBounds2 = majorIvsPlusOffsets.back() < ub;
- inBounds = inBounds && inBounds2;
+ if (xferOp.isMaskedDim(leadingRank + idx)) {
+ Value inBounds2 = majorIvsPlusOffsets.back() < ub;
+ inBounds = (inBounds) ? (inBounds && inBounds2) : inBounds2;
+ }
+ ++idx;
}
- auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
- ScopedContext::getLocation(), TypeRange{}, inBounds,
- /*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
- BlockBuilder(&ifOp.thenRegion().front(),
- Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
- if (std::is_same<ConcreteOp, TransferReadOp>())
- BlockBuilder(&ifOp.elseRegion().front(),
- Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
+ if (inBounds) {
+ auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
+ ScopedContext::getLocation(), TypeRange{}, inBounds,
+ /*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
+ BlockBuilder(&ifOp.thenRegion().front(),
+ Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
+ if (std::is_same<ConcreteOp, TransferReadOp>())
+ BlockBuilder(&ifOp.elseRegion().front(),
+ Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
+ } else {
+ // Just build the body of the then block right here.
+ thenBlockBuilder(majorIvsPlusOffsets);
+ }
}
template <>
@@ -192,13 +201,18 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
- // Lower to 1-D vector_transfer_read and let recursion handle it.
+
Value memref = xferOp.memref();
auto map = TransferReadOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
- auto loaded1D =
- vector_transfer_read(minorVectorType, memref, indexing,
- AffineMapAttr::get(map), xferOp.padding());
+ ArrayAttr masked;
+ if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
+ OpBuilder &b = ScopedContext::getBuilderRef();
+ masked = b.getBoolArrayAttr({true});
+ }
+ auto loaded1D = vector_transfer_read(minorVectorType, memref, indexing,
+ AffineMapAttr::get(map),
+ xferOp.padding(), masked);
// Store the 1-D vector.
std_store(loaded1D, alloc, majorIvs);
};
@@ -229,7 +243,6 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
ValueRange majorOffsets, ValueRange minorOffsets,
MemRefBoundsCapture &memrefBounds) {
auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
- // Lower to 1-D vector_transfer_write and let recursion handle it.
SmallVector<Value, 8> indexing;
indexing.reserve(leadingRank + majorRank + minorRank);
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
@@ -239,8 +252,13 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
Value loaded1D = std_load(alloc, majorIvs);
auto map = TransferWriteOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
+ ArrayAttr masked;
+ if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
+ OpBuilder &b = ScopedContext::getBuilderRef();
+ masked = b.getBoolArrayAttr({true});
+ }
vector_transfer_write(loaded1D, xferOp.memref(), indexing,
- AffineMapAttr::get(map));
+ AffineMapAttr::get(map), masked);
};
// Don't write anything when out of bounds.
auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {};
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index c72b835fc51a..f5b98f9bf065 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1017,8 +1017,7 @@ static Operation *vectorizeOneOperation(Operation *opInst,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<vector::TransferWriteOp>(
- opInst->getLoc(), vectorValue, memRef, indices,
- AffineMapAttr::get(permutationMap));
+ opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = transfer.getOperation();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminals" (i.e. AffineStoreOps) are erased on the spot.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 94695b6473de..f347a564f446 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1202,6 +1202,23 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//
+
+/// Build the default minor identity map suitable for a vector transfer. This
+/// also handles the case memref<... x vector<...>> -> vector<...> in which the
+/// rank of the identity map must take the vector element type into account.
+AffineMap
+mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
+ VectorType vectorType) {
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ memRefType.getElementType().dyn_cast<VectorType>();
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
+ return AffineMap::getMinorIdentityMap(
+ memRefType.getRank(), vectorType.getRank() - elementVectorRank,
+ memRefType.getContext());
+}
+
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
@@ -1233,7 +1250,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
VectorType vectorType,
- AffineMap permutationMap) {
+ AffineMap permutationMap,
+ ArrayAttr optionalMasked) {
auto memrefElementType = memrefType.getElementType();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
@@ -1282,52 +1300,60 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the memref type");
- return success();
-}
+ if (optionalMasked) {
+ if (permutationMap.getNumResults() !=
+ static_cast<int64_t>(optionalMasked.size()))
+ return op->emitOpError("expects the optional masked attr of same rank as "
+ "permutation_map results: ")
+ << AffineMapAttr::get(permutationMap);
+ }
-/// Build the default minor identity map suitable for a vector transfer. This
-/// also handles the case memref<... x vector<...>> -> vector<...> in which the
-/// rank of the identity map must take the vector element type into account.
-AffineMap
-mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
- VectorType vectorType) {
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- memRefType.getElementType().dyn_cast<VectorType>();
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
- return AffineMap::getMinorIdentityMap(
- memRefType.getRank(), vectorType.getRank() - elementVectorRank,
- memRefType.getContext());
+ return success();
}
-/// Builder that sets permutation map and padding to 'getMinorIdentityMap' and
-/// zero, respectively, by default.
+/// Builder that sets padding to zero.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vector, Value memref, ValueRange indices,
- AffineMap permutationMap) {
+ AffineMap permutationMap,
+ ArrayRef<bool> maybeMasked) {
Type elemType = vector.cast<VectorType>().getElementType();
Value padding = builder.create<ConstantOp>(result.location, elemType,
builder.getZeroAttr(elemType));
- build(builder, result, vector, memref, indices, permutationMap, padding);
+ if (maybeMasked.empty())
+ return build(builder, result, vector, memref, indices, permutationMap,
+ padding, ArrayAttr());
+ ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+ build(builder, result, vector, memref, indices, permutationMap, padding,
+ maskedArrayAttr);
}
/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
/// (resp. zero).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value memref,
- ValueRange indices) {
- build(builder, result, vectorType, memref, indices,
- getTransferMinorIdentityMap(memref.getType().cast<MemRefType>(),
- vectorType));
+ ValueRange indices, ArrayRef<bool> maybeMasked) {
+ auto permMap = getTransferMinorIdentityMap(
+ memref.getType().cast<MemRefType>(), vectorType);
+ build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
}
template <typename TransferOp>
void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
- SmallVector<StringRef, 1> elidedAttrs;
+ SmallVector<StringRef, 2> elidedAttrs;
if (op.permutation_map() == TransferOp::getTransferMinorIdentityMap(
op.getMemRefType(), op.getVectorType()))
elidedAttrs.push_back(op.getPermutationMapAttrName());
+ bool elideMasked = true;
+ if (auto maybeMasked = op.masked()) {
+ for (auto attr : *maybeMasked) {
+ if (!attr.template cast<BoolAttr>().getValue()) {
+ elideMasked = false;
+ break;
+ }
+ }
+ }
+ if (elideMasked)
+ elidedAttrs.push_back(op.getMaskedAttrName());
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
}
@@ -1388,7 +1414,8 @@ static LogicalResult verify(TransferReadOp op) {
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
- permutationMap)))
+ permutationMap,
+ op.masked() ? *op.masked() : ArrayAttr())))
return failure();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
@@ -1419,11 +1446,24 @@ static LogicalResult verify(TransferReadOp op) {
/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value memref, ValueRange indices) {
+ Value vector, Value memref, ValueRange indices,
+ ArrayRef<bool> maybeMasked) {
auto vectorType = vector.getType().cast<VectorType>();
auto permMap = getTransferMinorIdentityMap(
memref.getType().cast<MemRefType>(), vectorType);
- build(builder, result, vector, memref, indices, permMap);
+ if (maybeMasked.empty())
+ return build(builder, result, vector, memref, indices, permMap,
+ ArrayAttr());
+ ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+ build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
+}
+
+/// Builder that forwards the given `permutationMap`; `masked` is left unset.
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+ Value vector, Value memref, ValueRange indices,
+ AffineMap permutationMap) {
+ build(builder, result, vector, memref, indices, permutationMap,
+ /*masked=*/ArrayAttr());
+}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
@@ -1477,7 +1517,8 @@ static LogicalResult verify(TransferWriteOp op) {
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
- permutationMap)))
+ permutationMap,
+ op.masked() ? *op.masked() : ArrayAttr())))
return failure();
return verifyPermutationMap(permutationMap,
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index af7e5ad86af8..cf1bdede9027 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -564,9 +564,12 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
// Get VectorType for slice 'i'.
auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
+ // `masked` attribute propagates conservatively: if the coarse op didn't
+ // need masking, the fine op doesn't either.
vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
- xferReadOp.permutation_map(), xferReadOp.padding());
+ xferReadOp.permutation_map(), xferReadOp.padding(),
+ xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, sourceVectorType,
resultTupleType, sizes, strides, indices, rewriter,
@@ -620,9 +623,12 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
xferWriteOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
+ // `masked` attribute propagates conservatively: if the coarse op didn't
+ // need masking, the fine op doesn't either.
rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
- xferWriteOp.permutation_map());
+ xferWriteOp.permutation_map(),
+ xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, resultVectorType,
sourceTupleType, sizes, strides, indices, rewriter,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1c23072b6109..26e3e9dbe2b1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -918,6 +918,24 @@ func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -
// CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
// CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">
+func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+ %f7 = constant 7.0: f32
+ %f = vector.transfer_read %A[%base], %f7 {masked = [false]} :
+ memref<?xf32>, vector<17xf32>
+ return %f: vector<17xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d_not_masked
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
+//
+// 2. Rewrite as a load.
+// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">
+
func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>
diff --git a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
index 5c1e6361adb9..c0bc5542e21d 100644
--- a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
@@ -220,14 +220,12 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17
// CHECK: %[[cst:.*]] = constant 7.000000e+00 : f32
%f7 = constant 7.0: f32
- // CHECK-DAG: %[[cond0:.*]] = constant 1 : i1
// CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
// CHECK-DAG: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK-DAG: %[[dim:.*]] = dim %[[A]], 0 : memref<?x?xf32>
// CHECK: affine.for %[[I:.*]] = 0 to 17 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
- // CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
- // CHECK: %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
+ // CHECK: %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: scf.if %[[cond1]] {
// CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %[[cst]] : memref<?x?xf32>, vector<15xf32>
// CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>>
@@ -253,7 +251,6 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
- // CHECK: %[[cond0:.*]] = constant 1 : i1
// CHECK: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
// CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
@@ -261,8 +258,7 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
// CHECK: affine.for %[[I:.*]] = 0 to 17 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
- // CHECK: %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
- // CHECK: scf.if %[[cond1]] {
+ // CHECK: scf.if %[[cmp]] {
// CHECK: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
// CHECK: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
// CHECK: }
@@ -271,3 +267,26 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
vector<17x15xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: transfer_write_progressive_not_masked(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
+func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
+ // CHECK-NOT: scf.if
+ // CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
+ // CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
+ // CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
+ // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 17 {
+ // CHECK-NEXT: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
+ // CHECK-NEXT: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
+ // CHECK-NEXT: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+ vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} :
+ vector<17x15xf32>, memref<?x?xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1b0b0e38c4d5..c18cf38edfc9 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -348,6 +348,16 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
// -----
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
+ %c3 = constant 3 : index
+ %f0 = constant 0.0 : f32
+ %vf0 = splat %f0 : vector<2x3xf32>
+ // expected-error@+1 {{ expects the optional masked attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {masked = [false], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
+}
+
+// -----
+
func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant 3.0 : f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index aacfdf75d028..c194cbe23811 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -22,6 +22,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : memref<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -29,6 +31,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+ vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
return
}
More information about the Mlir-commits
mailing list