[Mlir-commits] [mlir] 753a67b - [mlir][Linalg] Refactor and improve vectorization to add support for reduction into 0-d tensors.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Oct 12 05:47:41 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-12T12:47:36Z
New Revision: 753a67b5c98f86ddddd4326e73de600250ea3cbe
URL: https://github.com/llvm/llvm-project/commit/753a67b5c98f86ddddd4326e73de600250ea3cbe
DIFF: https://github.com/llvm/llvm-project/commit/753a67b5c98f86ddddd4326e73de600250ea3cbe.diff
LOG: [mlir][Linalg] Refactor and improve vectorization to add support for reduction into 0-d tensors.
This revision takes advantage of the recently added support for 0-d transfers and vector.multi_reduction that return a scalar.
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D111626
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f48ef35cdab07..cdd5fcdbc548a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1288,7 +1288,7 @@ def Vector_TransferReadOp :
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "AffineMap":$permutationMap,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
- // Builder that sets padding to 'getMinorIdentityMap'.
+ // Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "Value":$padding,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
@@ -1306,6 +1306,17 @@ def Vector_TransferReadOp :
"ArrayAttr":$inBounds)>
];
+ let extraClassDeclaration = [{
+ /// Temporary convenience builder to account for the fact that we do not
+ /// have 0-d vectors atm. Reads a `vector<1xt>` from `source` with the
+ /// constant permutation map `() -> (0)` and returns its extracted element.
+ // No padding value is passed, so padding is the AffineMap builder's
+ // default (presumably zero -- confirm against that builder).
+ static Value createScalarOp(OpBuilder &builder, Location loc, Value source,
+ ValueRange indices,
+ ArrayRef<bool> inBounds = ArrayRef<bool>{});
+ }];
+
let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -1416,11 +1427,12 @@ def Vector_TransferWriteOp :
}];
let builders = [
+ // Builder that sets an empty mask.
+ OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+ "AffineMap":$permutationMap, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
- OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
- "AffineMap":$permutationMap)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
@@ -1429,6 +1441,18 @@ def Vector_TransferWriteOp :
"AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
];
+ let extraClassDeclaration = [{
+ /// Temporary convenience builder to account for the fact that we do not
+ /// have 0-d vectors atm. Broadcasts a non-vector `value` to `vector<1xt>`
+ /// (vector-typed values are used as-is) and writes it into `dest`.
+ // Uses the constant permutation map `() -> (0)` and no mask; returns the
+ // created vector.transfer_write operation.
+ static Operation *createScalarOp(
+ OpBuilder &builder, Location loc, Value value,
+ Value dest, ValueRange indices,
+ ArrayRef<bool> inBounds = ArrayRef<bool>{});
+ }];
+
let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 60a9e67e476a6..b6408135b7b7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -40,6 +40,9 @@ using llvm::dbgs;
#define DEBUG_TYPE "linalg-vectorization"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") // Prefixes output with [linalg-vectorization].
+#define LDBG(X) LLVM_DEBUG(DBGS() << X) // X may be a << chain; no newline is appended.
+
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
@@ -106,7 +109,7 @@ struct VectorizationResult {
/// ShapedType of `v`.
static VectorType extractVectorTypeFromShapedValue(Value v) {
auto st = v.getType().cast<ShapedType>();
- if (st.isa<MemRefType>() && st.getShape().empty())
+ if (st.getShape().empty()) // Rank 0 (memref or tensor alike): no 0-d vectors yet.
return VectorType();
return VectorType::get(st.getShape(), st.getElementType());
}
@@ -163,16 +166,23 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
-/// If value of assumed VectorType has a shape different than `shape`, build and
-/// return a new vector.broadcast to `shape`.
-/// Otherwise, just return value.
-static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
- Value value, OpOperand *outputOperand) {
+/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
+/// whether a reduction is needed to produce a `targetType` and create that
+/// reduction if it is the case.
+static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
+ OpOperand *outputOperand) {
+ LDBG("Reduce " << value << " to type " << targetType);
+ LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
+ << *(outputOperand->getOwner()));
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
auto vecType = value.getType().dyn_cast<VectorType>();
- if (!vecType || vecType.getShape() == targetVectorType.getShape())
+ VectorType targetVectorType = targetType.dyn_cast<VectorType>();
+ if (!vecType)
+ return value;
+ if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
return value;
+ // At this point, we know we need to reduce. Detect the reduction operator.
unsigned pos = 0;
MLIRContext *ctx = b.getContext();
SmallVector<AffineExpr> exprs;
@@ -181,7 +191,6 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
exprs.push_back(getAffineDimExpr(pos++, ctx));
auto loc = value.getLoc();
- // At this point, we know we need to reduce. Detect the reduction operator.
auto maybeKind = matchLinalgReduction(outputOperand);
assert(maybeKind && "Failed precondition: could not get reduction kind");
unsigned idx = 0;
@@ -196,16 +205,18 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
}
/// Build a vector.transfer_read from `source` at indices set to all `0`.
-/// If source has rank zero, build an memref.load.
+/// If `readType` is not a VectorType (rank-0 source), read `vector<1xt>` + extract.
/// Return the produced value.
-static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
+static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
AffineMap map) {
Location loc = source.getLoc();
auto shapedType = source.getType().cast<ShapedType>();
SmallVector<Value> indices(shapedType.getRank(),
b.create<ConstantIndexOp>(loc, 0));
- return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
- map);
+ if (auto vectorType = readType.dyn_cast<VectorType>())
+ return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
+ map);
+ return vector::TransferReadOp::createScalarOp(b, loc, source, indices); // `map` is unused on this path.
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -216,13 +227,14 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
OpOperand *outputOperand) {
Operation *write;
Location loc = value.getLoc();
+ auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
if (VectorType vectorType =
extractVectorTypeFromShapedValue(outputOperand->get())) {
- auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap map =
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
SmallVector<int64_t> transposeShape =
applyPermutationMap(inversePermutation(map), vectorType.getShape());
+ assert(!transposeShape.empty() && "unexpected empty transpose shape");
vectorType = VectorType::get(transposeShape, vectorType.getElementType());
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<ConstantIndexOp>(loc, 0));
@@ -231,9 +243,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
- write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
+ value =
+ reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand);
+ write = vector::TransferWriteOp::createScalarOp(
+ b, loc, value, outputOperand->get(), ValueRange{});
}
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
+ LDBG("vectorized op: " << *write);
if (!write->getResults().empty())
return write->getResult(0);
return Value();
@@ -329,7 +344,7 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
static VectorizationResult
vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
+ LDBG("vectorize op " << *op);
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
@@ -466,33 +481,27 @@ LogicalResult vectorizeAsLinalgGeneric(
continue;
}
// TODO: 0-d vectors.
- if (linalgOp.getShape(opOperand).empty()) {
- Value loaded =
- b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
- << bbarg.getArgNumber() << "): " << loaded);
- bvm.map(bbarg, loaded);
- bvm.map(opOperand->get(), loaded);
- continue;
- }
+ Type readType;
AffineMap map;
- VectorType vectorType;
- if (broadcastToMaximalCommonShape) {
- map = inverseAndBroadcastProjectedPermuation(
- linalgOp.getTiedIndexingMap(opOperand));
- vectorType = VectorType::get(commonVectorShape,
- getElementTypeOrSelf(opOperand->get()));
+ if (linalgOp.getShape(opOperand).empty()) {
+ readType = bbarg.getType();
} else {
- map = inversePermutation(
- reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
- vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+ if (broadcastToMaximalCommonShape) {
+ map = inverseAndBroadcastProjectedPermuation(
+ linalgOp.getTiedIndexingMap(opOperand));
+ readType = VectorType::get(commonVectorShape,
+ getElementTypeOrSelf(opOperand->get()));
+ } else {
+ map = inversePermutation(
+ reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+ readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
getElementTypeOrSelf(opOperand->get()));
+ }
}
- Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
- << bbarg.getArgNumber() << "): " << vectorRead);
- bvm.map(bbarg, vectorRead);
- bvm.map(opOperand->get(), vectorRead);
+ Value readValue = buildVectorRead(b, opOperand->get(), readType, map);
+ LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
+ bvm.map(bbarg, readValue);
+ bvm.map(opOperand->get(), readValue);
}
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -516,12 +525,11 @@ LogicalResult vectorizeAsLinalgGeneric(
for (Operation &op : block.getOperations()) {
VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
+ LDBG("failed to vectorize: " << op);
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
- << *result.newOp;);
+ LDBG("new vector op: " << *result.newOp;);
bvm.map(op.getResults(), result.newOp->getResults());
}
}
@@ -536,9 +544,9 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
Location loc = linalgOp.getLoc();
// Vectorize other ops as vector contraction.
// TODO: interface.
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
- << "Rewrite linalg op as vector.contract: ";
- linalgOp.dump());
+ LDBG(""
+ << "Rewrite linalg op as vector.contract: ";
+ linalgOp.dump());
// Special function that describes how to vectorize the multiplication op in a
// linalg contraction.
CustomVectorizationHook vectorizeContraction =
@@ -592,11 +600,15 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
- if (llvm::none_of(op.iterator_types(), isReductionIterator))
+ if (llvm::none_of(op.iterator_types(), isReductionIterator)) { // Need at least one reduction loop.
+ LDBG("reduction precondition failed: no reduction iterator");
return failure();
+ }
for (OpOperand *opOperand : op.getOutputOperands()) {
- if (!matchLinalgReduction(opOperand))
+ if (!matchLinalgReduction(opOperand)) { // Every output needs a recognized reduction kind.
+ LDBG("reduction precondition failed: reduction detection failed");
return failure();
+ }
}
return success();
}
@@ -604,8 +616,10 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
- if (linalgOp.hasDynamicShape())
+ if (linalgOp.hasDynamicShape()) {
+ LDBG("precondition failed: dynamic shape");
return failure();
+ }
if (isElementwise(op))
return success();
if (isaContractionOpInterface(linalgOp))
@@ -613,10 +627,15 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
- if (allIndexingsAreProjectedPermutation(linalgOp) &&
- succeeded(reductionPreconditions(linalgOp)))
- return success();
- return failure();
+ if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+ LDBG("precondition failed: not projected permutations");
+ return failure();
+ }
+ if (failed(reductionPreconditions(linalgOp))) {
+ LDBG("precondition failed: reduction preconditions");
+ return failure();
+ }
+ return success();
}
LogicalResult
@@ -629,10 +648,10 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
if (isaContractionOpInterface(linalgOp))
return vectorizeContraction(b, linalgOp, newResults);
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
- << "Vectorize linalg op as a generic by broadcasting to "
- "maximal common shape: "
- << *op);
+ LDBG(""
+ << "Vectorize linalg op as a generic by broadcasting to "
+ "maximal common shape: "
+ << *op);
return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
/*broadcastToMaximalCommonShape=*/true);
}
@@ -1200,9 +1219,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
if (firstOp->getBlock() != secondOp->getBlock() ||
!firstOp->isBeforeInBlock(secondOp)) {
- LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
- << "interleavedUses precondition failed, firstOp: "
- << *firstOp << ", second op: " << *secondOp);
+ LDBG("interleavedUses precondition failed, firstOp: "
+ << *firstOp << ", second op: " << *secondOp);
return true;
}
for (auto v : values) {
@@ -1214,10 +1232,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
if (owner->getBlock() == firstOp->getBlock() &&
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
continue;
- LLVM_DEBUG(llvm::dbgs()
- << "\n[" DEBUG_TYPE "]: "
- << " found interleaved op " << *owner
- << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
+ LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
+ << ", second op: " << *secondOp);
return true;
}
}
@@ -1248,15 +1264,14 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();
- LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
+ LDBG(viewOrAlloc);
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();
- LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
- << "with subView " << subView);
+ LDBG("with subView " << subView);
// Find the copy into `subView` without interleaved uses.
CopyOp copyOp;
@@ -1265,8 +1280,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
assert(newCopyOp.output().getType().isa<MemRefType>());
if (newCopyOp.output() != subView)
continue;
- LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
- << "copy candidate " << *newCopyOp);
+ LDBG("copy candidate " << *newCopyOp);
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
continue;
copyOp = newCopyOp;
@@ -1275,8 +1289,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
}
if (!copyOp)
return failure();
- LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
- << "with copy " << *copyOp);
+ LDBG("with copy " << *copyOp);
// Find the fill into `viewOrAlloc` without interleaved uses before the copy.
FillOp maybeFillOp;
@@ -1285,8 +1298,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
assert(newFillOp.output().getType().isa<MemRefType>());
if (newFillOp.output() != viewOrAlloc)
continue;
- LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
- << "fill candidate " << *newFillOp);
+ LDBG("fill candidate " << *newFillOp);
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
continue;
maybeFillOp = newFillOp;
@@ -1297,8 +1309,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
return failure();
if (maybeFillOp)
- LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
- << "with maybeFillOp " << *maybeFillOp);
+ LDBG("with maybeFillOp " << *maybeFillOp);
// `in` is the subview that linalg.copy reads. Replace it.
Value in = copyOp.input();
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 09d1bdc5349c2..e696f3481ed0f 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2439,6 +2439,18 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
/*mask=*/Value(), inBounds);
}
+Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
+ Value source, ValueRange indices,
+ ArrayRef<bool> inBounds) { // Read a vector<1xt>, return its only element.
+ Type elemType = source.getType().cast<ShapedType>().getElementType();
+ auto vectorType = VectorType::get(ArrayRef<int64_t>{1}, elemType);
+ AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
+ getAffineConstantExpr(0, loc.getContext())); // () -> (0); 0 dims, so presumably rank-0 source only (TODO confirm).
+ Value read = builder.create<vector::TransferReadOp>(loc, vectorType, source,
+ indices, map, inBounds);
+ return builder.create<vector::ExtractOp>(loc, read, ArrayRef<int64_t>{0});
+}
+
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 3> elidedAttrs;
elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
@@ -2769,6 +2781,16 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
// TransferWriteOp
//===----------------------------------------------------------------------===//
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+ Value vector, Value dest, ValueRange indices,
+ AffineMap permutationMap, ArrayRef<bool> inBounds) { // Always sets an empty mask.
+ if (inBounds.empty()) // No in_bounds info: pass a null ArrayAttr.
+ return build(builder, result, vector, dest, indices, permutationMap,
+ /*mask=*/Value(), ArrayAttr());
+ build(builder, result, vector, dest, indices, permutationMap,
+ /*mask=*/Value(), builder.getBoolArrayAttr(inBounds));
+}
+
/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
@@ -2783,13 +2805,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
}
-void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value source, ValueRange indices,
- AffineMap permutationMap) {
- build(builder, result, vector, source, indices, permutationMap,
- /*inBounds=*/ArrayAttr());
-}
-
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMapAttr permutationMap,
@@ -2817,6 +2832,20 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
mask, inBounds);
}
+Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
+ Value value, Value dest,
+ ValueRange indices,
+ ArrayRef<bool> inBounds) { // Write `value` into `dest` through a vector<1xt>.
+ Value vectorOfAScalar = value;
+ if (!value.getType().isa<VectorType>()) // Vector-typed values are written as-is.
+ vectorOfAScalar = builder.create<vector::BroadcastOp>(
+ loc, VectorType::get({1}, value.getType()), value);
+ AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
+ getAffineConstantExpr(0, loc.getContext())); // () -> (0); 0 dims, so presumably rank-0 dest only (TODO confirm).
+ return builder.create<vector::TransferWriteOp>(loc, vectorOfAScalar, dest,
+ indices, map, inBounds);
+}
+
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d56214fb2fdf7..c3a4a05413deb 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -203,8 +203,9 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK-LABEL: func @test_vectorize_fill
func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) { // NOTE(review): the label above is a prefix of this name; consider spelling out @test_vectorize_fill_scalar.
- // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
- // CHECK: store %[[V]], %[[M]][] : memref<f32>
+ // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
+ // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
+ // CHECK: vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref<f32>
linalg.fill(%arg0, %A) : f32, memref<f32> // Rank-0 fill: broadcast to vector<1xf32>, then write.
return
}
@@ -223,8 +224,11 @@ func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
// CHECK-LABEL: func @test_vectorize_copy_scalar
func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
- // CHECK: %[[V:.*]] = memref.load {{.*}} : memref<f32>
- // CHECK: store %[[V]], {{.*}} : memref<f32>
+ // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
+ // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<1xf32>
+ // CHECK: %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32>
+ // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
+ // CHECK: vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref<f32>
linalg.copy(%A, %B) : memref<f32>, memref<f32> // Rank-0 copy: vector<1xf32> read + extract, then broadcast + write.
return
}
@@ -857,3 +861,42 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
return %red : tensor<4xf32>
}
+// -----
+
+// CHECK-LABEL: func @reduce_1d(
+// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
+func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
+ // CHECK-DAG: %[[F0_v1:.*]] = constant dense<0.000000e+00> : vector<1xf32>
+ // CHECK-DAG: %[[F0_v32:.*]] = constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ %f0 = constant 0.000000e+00 : f32 // Initial value of the sum.
+
+ // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
+ %0 = linalg.init_tensor [] : tensor<f32>
+
+ // CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][]
+ // CHECK-SAME: : vector<1xf32>, tensor<f32>
+ %1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32> // 0-d accumulator, expected as a vector<1xf32> write.
+
+ // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
+ // CHECK-SAME: : tensor<32xf32>, vector<32xf32>
+ // CHECK: %[[a:.*]] = addf %[[r]], %[[F0_v32]] : vector<32xf32>
+ // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[a]] [0]
+ // CHECK-SAME: : vector<32xf32> to f32
+ // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<1xf32>
+ // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
+ // CHECK-SAME: : vector<1xf32>, tensor<f32>
+ %2 = linalg.generic { // Sum %arg0 into the 0-d tensor %1.
+ indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]}
+ ins(%arg0 : tensor<32xf32>)
+ outs(%1 : tensor<f32>) {
+ ^bb0(%a: f32, %b: f32): // no predecessors
+ %3 = addf %a, %b : f32
+ linalg.yield %3 : f32
+ } -> tensor<f32>
+
+ return %2 : tensor<f32>
+}
+
More information about the Mlir-commits
mailing list