[Mlir-commits] [mlir] 0fcbbde - [mlir][Linalg] NFC - Refactor vectorization to be more composable
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Feb 5 04:03:57 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-05T12:03:14Z
New Revision: 0fcbbde2c7b02b89503d5d1b631229d64eab7104
URL: https://github.com/llvm/llvm-project/commit/0fcbbde2c7b02b89503d5d1b631229d64eab7104
DIFF: https://github.com/llvm/llvm-project/commit/0fcbbde2c7b02b89503d5d1b631229d64eab7104.diff
LOG: [mlir][Linalg] NFC - Refactor vectorization to be more composable
Differential Revision: https://reviews.llvm.org/D96116
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
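At a glance, the refactor turns the vectorization entry point from a void function that asserted on its preconditions and RAUW'ed the op internally into one that reports failure via an empty Optional and hands the vectorized tensor results back to the caller, which is what makes it composable. The sketch below only restates the signatures that appear in the diff that follows; it is not additional code from the commit.

// Before this commit: asserts that vectorizeLinalgOpPrecondition(op) holds and
// replaces all uses of `op` internally.
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);

// After this commit: checks the precondition itself, returns llvm::None on
// failure, and lets the caller decide how to replace or erase `op`.
struct VectorizedLinalgOp {
  SmallVector<Value> tensorResults;
  VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
};
Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
                                               Operation *op);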
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 16203e5459b9..942581b4bbaf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -31,13 +31,6 @@ struct LinalgTilingOptions;
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
-struct TiledLinalgOp {
- LinalgOp op;
- SmallVector<Operation *, 8> loops;
- SmallVector<Value, 4> tensorResults;
- TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
-};
-
/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -63,6 +56,12 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
+struct TiledLinalgOp {
+ LinalgOp op;
+ SmallVector<Operation *, 8> loops;
+ SmallVector<Value, 4> tensorResults;
+ TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
+};
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
const LinalgTilingOptions &options);
@@ -264,7 +263,12 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
OperationFolder *folder = nullptr);
/// Emit a suitable vector form for a Linalg op with fully static shape.
-void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
+struct VectorizedLinalgOp {
+ SmallVector<Value> tensorResults;
+ VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
+};
+Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
+ Operation *op);
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>
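Because the header now exposes the result struct, a caller can compose vectorization without going through the rewrite pattern at all. Below is a minimal sketch, not part of this commit, assuming the enclosing FuncOp is available as `funcOp`; the replace-or-erase logic mirrors what LinalgBaseVectorizationPattern does in the next file, and the helper name is made up.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Hypothetical driver: collect candidate Linalg ops, then vectorize each one
// in place with a plain OpBuilder.
static void vectorizeAllCandidates(FuncOp funcOp) {
  SmallVector<Operation *, 8> candidates;
  funcOp.walk(
      [&](linalg::LinalgOp op) { candidates.push_back(op.getOperation()); });
  for (Operation *op : candidates) {
    OpBuilder b(op);
    Optional<linalg::VectorizedLinalgOp> res =
        linalg::vectorizeLinalgOp(b, op);
    if (!res)
      continue; // Preconditions not met; leave the op untouched.
    // Tensor results replace the op's uses; the memref form is just erased.
    if (!res->tensorResults.empty())
      op->replaceAllUsesWith(res->tensorResults);
    op->erase();
  }
}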
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8dac82a57de5..b80b6fb090e7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -468,10 +468,13 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
- if (failed(vectorizeLinalgOpPrecondition(op)))
+ Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
+ if (!res)
return failure();
- vectorizeLinalgOp(rewriter, op);
- rewriter.eraseOp(op);
+ if (!res->tensorResults.empty())
+ rewriter.replaceOp(op, res->tensorResults);
+ else
+ rewriter.eraseOp(op);
return success();
}
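For completeness, here is a hedged sketch of exercising the pattern above through the greedy rewrite driver; the context-only LinalgVectorizationPattern constructor (filter and benefit defaulted) is an assumption here, so check Transforms.h for the exact signature.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical test-pass body: register the vectorization pattern and apply it
// greedily over a function.
static void runVectorizationPatterns(FuncOp funcOp) {
  OwningRewritePatternList patterns;
  patterns.insert<linalg::LinalgVectorizationPattern>(funcOp.getContext());
  // Each successful rewrite now goes through replaceOp (tensor results) or
  // eraseOp (memref form) inside the pattern, as shown in the hunk above.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}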
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6e5b49125845..a9a43e194d75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -248,8 +248,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
/// TODO: Reuse opportunities for RAR dependencies.
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
/// 5. Iteratively call vectorizeOneOp on the region operations.
-/// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
-static LogicalResult vectorizeAsLinalgGeneric(
+static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
OpBuilder &builder, LinalgOp linalgOp,
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
// 1. Certain Linalg ops do not have a region but only a region builder.
@@ -306,7 +305,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
- return failure();
+ return llvm::None;
}
if (result.status == VectorizationStatus::NewOp) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
@@ -315,10 +314,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
}
}
- // 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
- if (!results.empty())
- linalgOp->replaceAllUsesWith(results);
- return success();
+ return VectorizedLinalgOp{{results}};
}
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -357,7 +353,8 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
-static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
+static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
+ LinalgOp linalgOp) {
assert(isaContractionOpInterface(linalgOp) &&
"expected vectorizeContraction preconditions to be met");
Location loc = linalgOp.getLoc();
@@ -384,11 +381,7 @@ static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
linalgOp.indexing_maps(), linalgOp.iterator_types());
return VectorizationResult{VectorizationStatus::NewOp, contract};
};
- auto status =
- vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
- (void)status;
- assert(succeeded(status) &&
- "Unexpected vectorization failed despite preconditions");
+ return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -408,8 +401,10 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
return success(isaContractionOpInterface(linalgOp));
}
-void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
- assert(succeeded(vectorizeLinalgOpPrecondition(op)));
+Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
+ Operation *op) {
+ if (failed(vectorizeLinalgOpPrecondition(op)))
+ return llvm::None;
edsc::ScopedContext scope(builder, op->getLoc());
// In the case of 0-D memrefs, return null and special case to scalar load or
@@ -418,8 +413,10 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
- buildVectorWrite(builder, fillOp.value(), fillOp.output());
- return;
+ VectorizedLinalgOp res;
+ if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
+ res.tensorResults.push_back(v);
+ return res;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
@@ -428,21 +425,26 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
"vector.transfer_write: "
<< *op);
Value vector = buildVectorRead(builder, copyOp.input());
- buildVectorWrite(builder, vector, copyOp.output());
- return;
+ VectorizedLinalgOp res;
+ if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
+ res.tensorResults.push_back(v);
+ return res;
}
-
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
- << "Rewrite linalg op as vector.transfer_read + " << *op);
- auto status = vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
- (void)status;
- assert(succeeded(status) &&
- "Unexpected vectorization failed despite preconditions");
- return;
+ << "Vectorize linalg op as a generic: " << *op);
+ return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
}
- vectorizeContraction(builder, cast<LinalgOp>(op));
+ // TODO: as soon as Copy and FillOp. get a region builder, replace all the
+ // above by:
+ // if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
+ // LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
+ // << "Vectorize linalg op as a generic: " << *op);
+ // return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
+ // }
+
+ return vectorizeContraction(builder, cast<LinalgOp>(op));
}
//----------------------------------------------------------------------------//
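The FillOp and CopyOp special cases above both follow the same shape: build the vector form, then push the Value returned by buildVectorWrite into tensorResults. A hypothetical additional special case would look like the fragment below, placed inside vectorizeLinalgOp; linalg::SomeOp and its input()/output() accessors are invented for illustration, while buildVectorRead/buildVectorWrite are the file-local helpers used in the diff.

  // Hypothetical extra special case, mirroring the FillOp/CopyOp handling.
  if (auto someOp = dyn_cast<linalg::SomeOp>(op)) {
    Value vector = buildVectorRead(builder, someOp.input());
    VectorizedLinalgOp res;
    // buildVectorWrite only yields a Value when writing into a tensor; memref
    // writes are pure side effects, leaving tensorResults empty.
    if (Value v = buildVectorWrite(builder, vector, someOp.output()))
      res.tensorResults.push_back(v);
    return res;
  }

As the TODO in the diff notes, once FillOp and CopyOp get region builders these special cases collapse into the single vectorizeAsLinalgGeneric path.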
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3904353287c5..12841a4b6803 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file -debug-only=linalg-vectorization
+
+//| FileCheck %s
// -----