[Mlir-commits] [mlir] 912ebf6 - [mlir][linalg] Cleanup LinalgOp usage in vectorization (NFC).
Tobias Gysi
llvmlistbot at llvm.org
Tue Jun 1 01:09:29 PDT 2021
Author: Tobias Gysi
Date: 2021-06-01T08:08:40Z
New Revision: 912ebf60b15123827299df73a7c9136f6693b487
URL: https://github.com/llvm/llvm-project/commit/912ebf60b15123827299df73a7c9136f6693b487
DIFF: https://github.com/llvm/llvm-project/commit/912ebf60b15123827299df73a7c9136f6693b487.diff
LOG: [mlir][linalg] Cleanup LinalgOp usage in vectorization (NFC).
Replace the uses of deprecated Structured Op Interface methods in Vectorization.cpp. This patch is based on https://reviews.llvm.org/D103394.
Differential Revision: https://reviews.llvm.org/D103410
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7ee5d5f4dd744..12a8d80c72fcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -116,14 +116,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
/// Linalg. This limitation is motivated by the fact that e.g.
/// min(max(X)) != max(min(X))
// TODO: use in LinalgOp verification, there is a circular dependency atm.
-static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
- auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
+static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
+ auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
unsigned yieldNum =
- outputOperand.getOperandNumber() - linalgOp.getNumInputs();
+ outputOperand->getOperandNumber() - linalgOp.getNumInputs();
llvm::SetVector<Operation *> backwardSlice, forwardSlice;
BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
- outputOperand.getOperandNumber());
+ outputOperand->getOperandNumber());
Value yieldVal = yieldOp->getOperand(yieldNum);
getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
return op->getParentOp() == linalgOp;
@@ -186,16 +186,15 @@ getKindForOp(Operation *reductionOp) {
/// return a new vector.broadcast to `shape`.
/// Otherwise, just return value.
static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
- Value value, OpOperand &outputOperand) {
- assert(targetVectorType.getShape() ==
- outputOperand.get().getType().cast<ShapedType>().getShape());
+ Value value, OpOperand *outputOperand) {
+ auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
+ assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
auto vecType = value.getType().dyn_cast<VectorType>();
if (!vecType || vecType.getShape() == targetVectorType.getShape())
return value;
// At this point, we know we need to reduce. Detect the reduction operator.
// TODO: Use the generic reduction detection util.
Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
- auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
unsigned pos = 0;
MLIRContext *ctx = b.getContext();
SmallVector<AffineExpr> exprs;
@@ -235,23 +234,22 @@ static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
/// currently being vectorized. If `dest` has null rank, build an memref.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &b, Value value,
- OpOperand &outputOperand) {
+ OpOperand *outputOperand) {
Operation *write;
Location loc = value.getLoc();
- auto shapedType = outputOperand.get().getType().cast<ShapedType>();
if (VectorType vectorType =
- extractVectorTypeFromShapedValue(outputOperand.get())) {
- auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
- AffineMap map = reindexIndexingMap(
- linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
- SmallVector<Value> indices(shapedType.getRank(),
+ extractVectorTypeFromShapedValue(outputOperand->get())) {
+ auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
+ AffineMap map =
+ reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
+ SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(b, value, vectorType.getShape());
value = reduceIfNeeded(b, vectorType, value, outputOperand);
- write = b.create<vector::TransferWriteOp>(loc, value, outputOperand.get(),
+ write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
- write = b.create<memref::StoreOp>(loc, value, outputOperand.get());
+ write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
if (!write->getResults().empty())
@@ -284,7 +282,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
Value newResult = buildVectorWrite(
- b, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
+ b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
if (newResult)
newResults.push_back(newResult);
}
@@ -422,8 +420,8 @@ static bool isElementwise(Operation *op) {
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
return false;
// TODO: relax the restrictions on indexing map.
- for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
- if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+ for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
+ if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
return false;
}
if (linalgOp->getNumRegions() != 1)
@@ -479,36 +477,37 @@ LogicalResult vectorizeAsLinalgGeneric(
// 3. Turn all BBArgs into vector.transfer_read / load.
SmallVector<AffineMap> indexings;
- for (auto bbarg : block.getArguments()) {
- Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
- ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
+ for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+ BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
// TODO: 0-d vectors.
- if (shapedType.getShape().empty()) {
- Value loaded = b.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
+ if (linalgOp.getShape(opOperand).empty()) {
+ Value loaded =
+ b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << loaded);
bvm.map(bbarg, loaded);
- bvm.map(shapedArg, loaded);
+ bvm.map(opOperand->get(), loaded);
continue;
}
AffineMap map;
VectorType vectorType;
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
- linalgOp.getIndexingMap(bbarg.getArgNumber()));
- vectorType =
- VectorType::get(commonVectorShape, shapedType.getElementType());
+ linalgOp.getTiedIndexingMap(opOperand));
+ vectorType = VectorType::get(
+ commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
} else {
map = inversePermutation(
- reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
- vectorType = VectorType::get(map.compose(shapedType.getShape()),
- shapedType.getElementType());
+ reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+ vectorType =
+ VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+ getElementTypeOrSelf(opOperand->get().getType()));
}
- Value vectorRead = buildVectorRead(b, shapedArg, vectorType, map);
+ Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << vectorRead);
bvm.map(bbarg, vectorRead);
- bvm.map(shapedArg, vectorRead);
+ bvm.map(opOperand->get(), vectorRead);
}
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -562,7 +561,8 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
const BlockAndValueMapping &bvm) -> VectorizationResult {
if (!isa<MulIOp, MulFOp>(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
- auto outShape = linalgOp.getOutputShapedType(0).getShape();
+ ArrayRef<int64_t> outShape =
+ linalgOp.getShape(linalgOp.getOutputOperand(0));
auto vType = outShape.empty()
? op->getResult(0).getType()
: VectorType::get(outShape, op->getResult(0).getType());
@@ -574,13 +574,14 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
// TODO: consider dropping contraction special casing altogether, this will
// require more advanced canonicalizations involving vector.multi_reduction
// that are not yet available.
- SmallVector<AffineMap> indexingMaps{
- inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
- .compose(linalgOp.getIndexingMap(0)),
- inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
- .compose(linalgOp.getIndexingMap(1)),
- inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
- .compose(linalgOp.getIndexingMap(2))};
+ SmallVector<AffineMap> indexingMaps;
+ indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
+ llvm::transform(linalgOp.getIndexingMaps(),
+ std::back_inserter(indexingMaps),
+ [](AffineMap indexingMap) {
+ return inversePermutation(reindexIndexingMap(indexingMap))
+ .compose(indexingMap);
+ });
Operation *contract = b.create<vector::ContractionOp>(
loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
@@ -601,8 +602,8 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
return failure();
- for (auto &operand : op.getOutputOpOperands()) {
- Operation *reductionOp = getSingleBinaryOpAssumedReduction(operand);
+ for (OpOperand *opOperand : op.getOutputOperands()) {
+ Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
if (!getKindForOp(reductionOp))
return failure();
}
@@ -612,12 +613,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
- for (Value operand : linalgOp.getShapedOperands())
- if (!operand.getType().cast<ShapedType>().hasStaticShape())
- return failure();
- for (Type outputTensorType : linalgOp.getOutputTensorTypes())
- if (!outputTensorType.cast<ShapedType>().hasStaticShape())
- return failure();
+ if (linalgOp.hasDynamicShape())
+ return failure();
if (isElementwise(op))
return success();
if (isaContractionOpInterface(linalgOp))
@@ -722,13 +719,14 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
- ShapedType inShapeType = op.getInputShapedType(0);
- ShapedType kShapeType = op.getInputShapedType(1);
-
- ArrayRef<int64_t> inShape = inShapeType.getShape();
- ArrayRef<int64_t> kShape = kShapeType.getShape();
+ OpOperand *input = op.getInputOperand(0);
+ OpOperand *kernel = op.getInputOperand(1);
+ OpOperand *output = op.getOutputOperand(0);
+ ArrayRef<int64_t> inShape = op.getShape(input);
+ ArrayRef<int64_t> kShape = op.getShape(kernel);
- if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
+ if (llvm::any_of(inShape, ShapedType::isDynamic) ||
+ llvm::any_of(kShape, ShapedType::isDynamic))
return failure();
SmallVector<AffineExpr, 4> mapping;
@@ -747,22 +745,18 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
}
}
- Value input = op.getInput(0);
- Value kernel = op.getInput(1);
- Value output = op.getOutputBuffer(0);
-
- unsigned rank = inShapeType.getRank();
- unsigned numDims = mapping.size();
- Type elemType = inShapeType.getElementType();
+ int64_t rank = op.getRank(input);
+ int64_t numDims = mapping.size();
+ Type elemType = getElementTypeOrSelf(input->get().getType());
auto map = AffineMap::get(rank, 0, mapping, context);
SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
auto vecType = VectorType::get(vectorDims, elemType);
- auto inputVec =
- rewriter.create<vector::TransferReadOp>(loc, vecType, input, zeros, map);
- auto kernelVec =
- rewriter.create<vector::TransferReadOp>(loc, vecType, kernel, zeros, map);
+ auto inputVec = rewriter.create<vector::TransferReadOp>(
+ loc, vecType, input->get(), zeros, map);
+ auto kernelVec = rewriter.create<vector::TransferReadOp>(
+ loc, vecType, kernel->get(), zeros, map);
auto acc = rewriter.create<ConstantOp>(loc, elemType,
rewriter.getZeroAttr(elemType));
@@ -779,7 +773,8 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
rewriter.getAffineMapArrayAttr(indexingMaps),
rewriter.getStrArrayAttr(iteratorTypes));
- rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
+ rewriter.create<memref::StoreOp>(loc, result, output->get(),
+ ValueRange(zeros));
rewriter.eraseOp(op);
return success();
}
@@ -939,7 +934,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
CopyOp copyOp;
for (auto &u : subView.getUses()) {
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
- if (newCopyOp.getOutputBuffer(0) != subView)
+ assert(newCopyOp.output().getType().isa<MemRefType>());
+ if (newCopyOp.output() != subView)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "copy candidate " << *newCopyOp);
@@ -958,7 +954,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
- if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
+ assert(newFillOp.output().getType().isa<MemRefType>());
+ if (newFillOp.output() != viewOrAlloc)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "fill candidate " << *newFillOp);
@@ -976,7 +973,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
<< "with maybeFillOp " << *maybeFillOp);
// `in` is the subview that linalg.copy reads. Replace it.
- Value in = copyOp.getInput(0);
+ Value in = copyOp.input();
// linalg.copy + linalg.fill can be used to create a padded local buffer.
// The `masked` attribute is only valid on this padded buffer.
@@ -1014,7 +1011,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
CopyOp copyOp;
for (auto &u : subViewOp.getResult().getUses()) {
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
- if (newCopyOp.getInput(0) != subView)
+ if (newCopyOp.getInputOperand(0)->get() != subView)
continue;
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
continue;
@@ -1026,7 +1023,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
return failure();
// `out` is the subview copied into that we replace.
- Value out = copyOp.getOutputBuffer(0);
+ assert(copyOp.output().getType().isa<MemRefType>());
+ Value out = copyOp.output();
// Forward vector.transfer into copy.
// linalg.copy + linalg.fill can be used to create a padded local buffer.
More information about the Mlir-commits mailing list