[Mlir-commits] [mlir] c1a4cd5 - [mlir][linalg] refactor the result handling during vectorization.
Tobias Gysi
llvmlistbot at llvm.org
Mon Mar 8 23:27:07 PST 2021
Author: Tobias Gysi
Date: 2021-03-09T07:11:57Z
New Revision: c1a4cd551f1c577008c33d78972929ba6593efcc
URL: https://github.com/llvm/llvm-project/commit/c1a4cd551f1c577008c33d78972929ba6593efcc
DIFF: https://github.com/llvm/llvm-project/commit/c1a4cd551f1c577008c33d78972929ba6593efcc.diff
LOG: [mlir][linalg] refactor the result handling during vectorization.
Return the vectorization results using a vector passed by reference instead of returning them embedded in a structure.
Differential Revision: https://reviews.llvm.org/D98182
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5c0d1dc3a2fa..8f422d284df6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -263,12 +263,8 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
OperationFolder *folder = nullptr);
/// Emit a suitable vector form for a Linalg op with fully static shape.
-struct VectorizedLinalgOp {
- SmallVector<Value> tensorResults;
- VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
-};
-Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
- Operation *op);
+LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
+ SmallVectorImpl<Value> &newResults);
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dd92ccd838cd..7f604807030d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -468,11 +468,11 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
- Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
- if (!res)
+ SmallVector<Value> newResults;
+ if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
return failure();
- if (!res->tensorResults.empty())
- rewriter.replaceOp(op, res->tensorResults);
+ if (!newResults.empty())
+ rewriter.replaceOp(op, newResults);
else
rewriter.eraseOp(op);
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f471ab0ebd75..48b6165d7b68 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -139,16 +139,16 @@ using CustomVectorizationHook = std::function<VectorizationResult(
Operation *, const BlockAndValueMapping &)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
-/// vector values are appended to `results`.
-/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
-/// that it should not try to map produced operations: this is the purpose of
-/// the `results` argument to capture such values and make them available for
-/// RAUW to the vectorization algorithm.
-/// This function is meant to be used as a CustomVectorizationHook.
+/// vector values are appended to `newResults`. Return
+/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
+/// should not try to map produced operations and instead return the results
+/// using the `newResults` vector making them available to the
+/// vectorization algorithm for RAUW. This function is meant to be used as a
+/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
const BlockAndValueMapping &bvm, LinalgOp linalgOp,
- SmallVectorImpl<Value> &results) {
+ SmallVectorImpl<Value> &newResults) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
@@ -156,10 +156,10 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
- Value result = buildVectorWrite(builder, vectorValue,
- linalgOp.getOutput(outputs.index()));
- if (result)
- results.push_back(result);
+ Value newResult = buildVectorWrite(builder, vectorValue,
+ linalgOp.getOutput(outputs.index()));
+ if (newResult)
+ newResults.push_back(newResult);
}
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
@@ -248,8 +248,8 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
/// TODO: Reuse opportunities for RAR dependencies.
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
/// 5. Iteratively call vectorizeOneOp on the region operations.
-static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
- OpBuilder &builder, LinalgOp linalgOp,
+LogicalResult vectorizeAsLinalgGeneric(
+ OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
// 1. Certain Linalg ops do not have a region but only a region builder.
// If so, build the region so we can vectorize.
@@ -290,11 +290,10 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
}
// 4. Register CustomVectorizationHook for yieldOp.
- SmallVector<Value> results;
CustomVectorizationHook vectorizeYield =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
- return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
+ return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
};
// Append the vectorizeYield hook.
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -305,7 +304,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
- return llvm::None;
+ return failure();
}
if (result.status == VectorizationStatus::NewOp) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
@@ -314,7 +313,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
}
}
- return VectorizedLinalgOp{{results}};
+ return success();
}
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -355,8 +354,8 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
}
-static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
- LinalgOp linalgOp) {
+static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
+ SmallVectorImpl<Value> &newResults) {
assert(isaContractionOpInterface(linalgOp) &&
"expected vectorizeContraction preconditions to be met");
Location loc = linalgOp.getLoc();
@@ -383,7 +382,8 @@ static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
linalgOp.indexing_maps(), linalgOp.iterator_types());
return VectorizationResult{VectorizationStatus::NewOp, contract};
};
- return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
+ return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
+ {vectorizeContraction});
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -400,19 +400,20 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
return success(isaContractionOpInterface(linalgOp));
}
-Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
- Operation *op) {
+LogicalResult
+mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
+ SmallVectorImpl<Value> &newResults) {
if (failed(vectorizeLinalgOpPrecondition(op)))
- return llvm::None;
+ return failure();
edsc::ScopedContext scope(builder, op->getLoc());
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Vectorize linalg op as a generic: " << *op);
- return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
+ return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
}
- return vectorizeContraction(builder, cast<LinalgOp>(op));
+ return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
}
//----------------------------------------------------------------------------//
More information about the Mlir-commits mailing list