[Mlir-commits] [mlir] 0a2a260 - [mlir][Linalg] Refactor Linalg vectorization for better reuse and extensibility.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Feb 2 03:34:48 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-02T11:31:09Z
New Revision: 0a2a260aab177bfbdef5829ea16e39323ce50916
URL: https://github.com/llvm/llvm-project/commit/0a2a260aab177bfbdef5829ea16e39323ce50916
DIFF: https://github.com/llvm/llvm-project/commit/0a2a260aab177bfbdef5829ea16e39323ce50916.diff
LOG: [mlir][Linalg] Refactor Linalg vectorization for better reuse and extensibility.
This revision unifies Linalg vectorization and paves the way for vectorization of Linalg ops with mixed-precision operations.
The new algorithm traverses the ops in the linalg block in order and avoids recursion.
It uses a BlockAndValueMapping to keep track of vectorized operations.
The revision makes the following modifications but is otherwise NFC:
1. vector.transfer_read are created eagerly and may appear in a different order than the original order.
2. a more progressive vectorization to vector.contract results in only the multiply operation being converted to `vector.contract %a, %b, %zero`, where `%zero` is a
constant of the proper type. Later vector canonicalizations are assumed to rewrite vector.contract %a, %b, %zero + add to a proper accumulate form.
Differential revision: https://reviews.llvm.org/D95797
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index fa1aba8fd157..047e5ee045df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -23,6 +23,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
@@ -36,6 +38,275 @@ using llvm::dbgs;
#define DEBUG_TYPE "linalg-vectorization"
+/// Helper data structure to represent the result of vectorization.
+/// In certain specific cases, like terminators, we do not want to propagate
+/// replacements.
+enum VectorizationStatus {
+ /// Op failed to vectorize.
+ Failure = 0,
+ /// Op vectorized and custom function took care of replacement logic
+ NoReplace,
+ /// Op vectorized into a new Op whose results will replace original Op's
+ /// results.
+ NewOp
+ // TODO: support values if Op vectorized to Many-Ops whose results we need to
+ // aggregate for replacement.
+};
+struct VectorizationResult {
+ /// Return status from vectorizing the current op.
+ enum VectorizationStatus status = VectorizationStatus::Failure;
+ /// New vectorized operation to replace the current op.
+ /// Replacement behavior is specified by `status`.
+ Operation *newOp;
+};
+
+/// Return a vector type of the same shape and element type as the (assumed)
+/// ShapedType of `v`.
+static VectorType extractVectorTypeFromShapedValue(Value v) {
+ auto st = v.getType().cast<ShapedType>();
+ if (st.isa<MemRefType>() && st.getShape().empty())
+ return VectorType();
+ return VectorType::get(st.getShape(), st.getElementType());
+}
+
+/// Build a vector.transfer_read from `source` at indices set to all `0`.
+/// If source has rank zero, build an std.load.
+/// Return the produced value.
+static Value buildVectorRead(OpBuilder &builder, Value source) {
+ edsc::ScopedContext scope(builder);
+ auto shapedType = source.getType().cast<ShapedType>();
+ if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+ SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+ return vector_transfer_read(vectorType, source, indices);
+ }
+ return std_load(source);
+}
+
+/// Build a vector.transfer_write of `value` into `dest` at indices set to all
+/// `0`. If `dest` has null rank, build an std.store.
+/// Return the produced value or null if no value is produced.
+static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
+ edsc::ScopedContext scope(builder);
+ Operation *write;
+ auto shapedType = dest.getType().cast<ShapedType>();
+ if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+ SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+ if (vectorType != value.getType())
+ value = vector_broadcast(vectorType, value);
+ write = vector_transfer_write(value, dest, indices);
+ } else {
+ write = std_store(value, dest);
+ }
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
+ if (!write->getResults().empty())
+ return write->getResult(0);
+ return Value();
+}
+
+/// If value of assumed VectorType has a shape different than `shape`, build and
+/// return a new vector.broadcast to `shape`.
+/// Otherwise, just return value.
+static Value broadcastIfNeeded(OpBuilder &builder, Value value,
+ ArrayRef<int64_t> shape) {
+ auto vecType = value.getType().dyn_cast<VectorType>();
+ if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
+ return value;
+ auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
+ : value.getType());
+ return builder.create<vector::BroadcastOp>(
+ builder.getInsertionPoint()->getLoc(), newVecType, value);
+}
+
+// Custom vectorization function type. Produce a vector form of Operation*
+// assuming all its vectorized operands are already in the BlockAndValueMapping.
+// Return nullptr if the Operation cannot be vectorized.
+using CustomVectorizationHook = std::function<VectorizationResult(
+ Operation *, const BlockAndValueMapping &)>;
+
+/// Helper function to vectorize the terminator of a `linalgOp`. New result
+/// vector values are appended to `results`.
+/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
+/// that it should not try to map produced operations: this is the purpose of
+/// the `results` argument to capture such values and make them available for
+/// RAUW to the vectorization algorithm.
+/// This function is meant to be used as a CustomVectorizationHook.
+static VectorizationResult
+vectorizeLinalgYield(OpBuilder &builder, Operation *op,
+ const BlockAndValueMapping &bvm, LinalgOp linalgOp,
+ SmallVectorImpl<Value> &results) {
+ auto yieldOp = dyn_cast<linalg::YieldOp>(op);
+ if (!yieldOp)
+ return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ for (auto outputs : llvm::enumerate(yieldOp.values())) {
+ // TODO: Scan for an opportunity for reuse.
+ // TODO: use a map.
+ Value vectorValue = bvm.lookup(outputs.value());
+ Value result = buildVectorWrite(builder, vectorValue,
+ linalgOp.getOutput(outputs.index()));
+ if (result)
+ results.push_back(result);
+ }
+ return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+};
+
+/// Generic vectorization for a single operation `op`, given already vectorized
+/// operands carried by `bvm`. Vectorization occurs as follows:
+/// 1. Try to apply any of the `customVectorizationHooks` and return its
+/// result on success.
+/// 2. Clone any constant in the current scope without vectorization: each
+/// consumer of the constant will later determine the shape to which the
+/// constant needs to be broadcast to.
+/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
+/// of the `customVectorizationHooks` to cover such cases.
+/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
+/// operand of maximal rank. Other operands have smaller rank and are
+/// broadcast accordingly. It is assumed this broadcast is always legal,
+/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
+///
+/// This function assumes all operands of `op` have been vectorized and are in
+/// the `bvm` mapping. As a consequence, this function is meant to be called on
+/// a topologically-sorted list of ops.
+/// This function does not update `bvm` but returns a VectorizationStatus that
+/// instructs the caller what `bvm` update needs to occur.
+static VectorizationResult
+vectorizeOneOp(OpBuilder &builder, Operation *op,
+ const BlockAndValueMapping &bvm,
+ ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
+
+ // 1. Try to apply any CustomVectorizationHook.
+ if (!customVectorizationHooks.empty()) {
+ for (auto &customFunc : customVectorizationHooks) {
+ VectorizationResult result = customFunc(op, bvm);
+ if (result.status == VectorizationStatus::Failure)
+ continue;
+ return result;
+ }
+ }
+
+ // 2. Constant ops don't get vectorized but rather broadcasted at their users.
+ // Clone so that the constant is not confined to the linalgOp block .
+ if (isa<ConstantOp>(op))
+ return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
+
+ // 3. Only ElementwiseMappable are allowed in the generic vectorization.
+ if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+ return VectorizationResult{VectorizationStatus::Failure, nullptr};
+
+ // 4. Generic vectorization path for ElementwiseMappable ops.
+ // a. first get the first max ranked shape.
+ SmallVector<int64_t, 4> firstMaxRankedShape;
+ for (Value operand : op->getOperands()) {
+ auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
+ if (vt && firstMaxRankedShape.size() < vt.getShape().size())
+ firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
+ }
+ // b. broadcast each op if needed.
+ auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
+ return firstMaxRankedShape.empty()
+ ? bvm.lookup(v)
+ : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
+ });
+ // c. for elementwise, the result is the vector with the firstMaxRankedShape
+ auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
+ return firstMaxRankedShape.empty()
+ ? t
+ : VectorType::get(firstMaxRankedShape, t);
+ });
+
+ // Build and return the new op.
+ OperationState state(op->getLoc(), op->getName());
+ state.addAttributes(op->getAttrs());
+ state.addOperands(llvm::to_vector<4>(vectorizedOperands));
+ state.addTypes(llvm::to_vector<4>(returnTypes));
+ return VectorizationResult{VectorizationStatus::NewOp,
+ builder.createOperation(state)};
+}
+
+/// Generic vectorization function that rewrites the body of a `linalgOp` into
+/// vector form. Generic vectorization proceeds as follows:
+/// 1. The region for the linalg op is created if necessary.
+/// 2. Values defined above the region are mapped to themselves and will be
+/// broadcasted on a per-need basis by their consumers.
+/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
+/// load).
+/// TODO: Reuse opportunities for RAR dependencies.
+/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
+/// 5. Iteratively call vectorizeOneOp on the region operations.
+/// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
+static LogicalResult vectorizeAsLinalgGeneric(
+ OpBuilder &builder, LinalgOp linalgOp,
+ ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
+ // 1. Certain Linalg ops do not have a region but only a region builder.
+ // If so, build the region so we can vectorize.
+ std::unique_ptr<Region> owningRegion;
+ Region *region;
+ if (linalgOp->getNumRegions() > 0) {
+ region = &linalgOp->getRegion(0);
+ } else {
+ // RAII avoid remaining in block.
+ OpBuilder::InsertionGuard g(builder);
+ owningRegion = std::make_unique<Region>();
+ region = owningRegion.get();
+ Block *block = builder.createBlock(region);
+ auto elementTypes = llvm::to_vector<4>(
+ llvm::map_range(linalgOp.getShapedOperandTypes(),
+ [](ShapedType t) { return t.getElementType(); }));
+ block->addArguments(elementTypes);
+ linalgOp.getRegionBuilder()(*block);
+ }
+  Block *block = &region->front();
+
+ BlockAndValueMapping bvm;
+ // 2. Values defined above the region can only be broadcast for now. Make them
+ // map to themselves.
+ llvm::SetVector<Value> valuesSet;
+ mlir::getUsedValuesDefinedAbove(*region, valuesSet);
+ bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
+
+ // 3. Turn all BBArgs into vector.transfer_read / load.
+ SmallVector<AffineMap> indexings;
+ for (auto bbarg : block->getArguments()) {
+ Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
+ Value vectorRead = buildVectorRead(builder, vectorArg);
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
+ << bbarg.getArgNumber() << "): " << vectorRead);
+ bvm.map(bbarg, vectorRead);
+ bvm.map(vectorArg, vectorRead);
+ }
+
+ // 4. Register CustomVectorizationHook for yieldOp.
+ SmallVector<Value> results;
+ CustomVectorizationHook vectorizeYield =
+ [&](Operation *op,
+ const BlockAndValueMapping &bvm) -> VectorizationResult {
+ return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
+ };
+ // Append the vectorizeYield hook.
+ auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+ hooks.push_back(vectorizeYield);
+
+ // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
+ for (Operation &op : block->getOperations()) {
+ VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
+ if (result.status == VectorizationStatus::Failure) {
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
+ return failure();
+ }
+ if (result.status == VectorizationStatus::NewOp) {
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
+ << *result.newOp;);
+ bvm.map(op.getResults(), result.newOp->getResults());
+ }
+ }
+
+ // 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
+ if (!results.empty())
+ linalgOp->replaceAllUsesWith(results);
+ return success();
+}
+
+/// Detect whether `r` exactly computes a floating-point or integer
+/// multiply-accumulate.
static bool hasMultiplyAddBody(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
@@ -65,6 +336,7 @@ static bool hasMultiplyAddBody(Region &r) {
pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}
+/// Detect whether the LinalgOp `op` is a contraction.
// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
// TODO: interface for named ops.
@@ -84,6 +356,7 @@ static LogicalResult isContraction(Operation *op) {
hasMultiplyAddBody(genericOp.region()));
}
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
static bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
@@ -119,171 +392,6 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
-static VectorType extractVectorTypeFromShapedValue(Value v) {
- auto st = v.getType().cast<ShapedType>();
- if (st.isa<MemRefType>() && st.getShape().empty())
- return VectorType();
- return VectorType::get(st.getShape(), st.getElementType());
-}
-
-static Value transferReadVector(OpBuilder &builder, Value source) {
- edsc::ScopedContext scope(builder);
- auto shapedType = source.getType().cast<ShapedType>();
- if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
- SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
- return vector_transfer_read(vectorType, source, indices);
- }
- return std_load(source);
-}
-
-static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
- edsc::ScopedContext scope(builder);
- Operation *write;
- auto shapedType = dest.getType().cast<ShapedType>();
- if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
- SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
- if (vectorType != value.getType())
- value = vector_broadcast(vectorType, value);
- write = vector_transfer_write(value, dest, indices);
- } else {
- write = std_store(value, dest);
- }
- if (!write->getResults().empty())
- return write->getResult(0);
- return Value();
-}
-
-namespace {
-// Transforms scalar operations into their vectorized counterparts,
-// while using the provided generic op to map:
-// * Its arguments to transfer reads from the views of the generic op.
-// * linalg.yield ops to transfer writes to the views of the generic op.
-class GenericVectorizer {
-public:
- GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
- : builder(builder), generic(generic) {}
-
- // Takes a scalar operation and builds its vectorized counterpart or
- // counterparts using the underlying builder.
- // If operands of the scalar operation are referring to previously vectorized
- // operations, then in their vectorized form these operands will be referring
- // to previous vectorization results.
- void vectorize(Operation &scalarOp) {
- auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
- if (yieldOp) {
- for (auto outputs : llvm::enumerate(yieldOp.values())) {
- Value vectorValue = vectorize(outputs.value());
- Value result = transferWriteVector(builder, vectorValue,
- generic.getOutput(outputs.index()));
- if (result)
- results.push_back(result);
- }
- return;
- }
- Operation *vectorOp = uncachedVectorize(scalarOp);
- assert(scalarOp.getNumResults() == vectorOp->getNumResults());
- for (auto result :
- llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
- valueCache[std::get<0>(result)] = std::get<1>(result);
- }
- }
-
- llvm::ArrayRef<Value> getResults() { return results; }
-
-private:
- // Transforms a scalar value into its vectorized counterpart, recursively
- // vectorizing operations as necessary using the underlying builder.
- // Keeps track of previously vectorized values and reuses vectorization
- // results if these values come up again.
- Value vectorize(Value scalarValue) {
- // Don't vectorize values coming from outside the region.
- if (scalarValue.getParentRegion() != &generic.region())
- return scalarValue;
- auto vectorValueIt = valueCache.find(scalarValue);
- if (vectorValueIt != valueCache.end())
- return vectorValueIt->second;
-
- // If the value is from the region but not in the cache it means it is a
- // block argument.
- auto scalarArg = scalarValue.cast<BlockArgument>();
- assert(scalarArg.getOwner() == &generic.region().front());
- Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber());
- Value vectorResult = transferReadVector(builder, vectorArg);
- valueCache[scalarArg] = vectorResult;
- return vectorResult;
- }
-
- // Return the largest shape of all the given values. Return an empty
- // SmallVector if there are no vector value.
- static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
- SmallVector<int64_t, 4> largestShape;
- int64_t maxSize = 1;
- for (Value value : values) {
- auto vecType = value.getType().dyn_cast<VectorType>();
- if (!vecType)
- continue;
- if (maxSize < vecType.getNumElements()) {
- maxSize = vecType.getNumElements();
- largestShape.assign(vecType.getShape().begin(),
- vecType.getShape().end());
- }
- }
- return largestShape;
- }
-
- // If the value's type doesn't have the given shape broadcast it.
- Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
- auto vecType = value.getType().dyn_cast<VectorType>();
- if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
- return value;
- auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
- : value.getType());
- return builder.create<vector::BroadcastOp>(
- builder.getInsertionPoint()->getLoc(), newVecType, value);
- }
-
- // Takes a scalar operation and builds its vectorized counterpart or
- // counterparts using underlying builder without involving any caches.
- Operation *uncachedVectorize(Operation &base_scalarOp) {
- SmallVector<Value, 4> vectorizedOperands;
- for (Value operand : base_scalarOp.getOperands()) {
- vectorizedOperands.push_back(vectorize(operand));
- }
- SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
- for (Value &operand : vectorizedOperands)
- operand = broadcastIfNeeded(operand, shape);
- OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
- state.addAttributes(base_scalarOp.getAttrs());
- state.addOperands(vectorizedOperands);
- if (shape.empty()) {
- state.addTypes(base_scalarOp.getResultTypes());
- } else {
- SmallVector<VectorType, 4> vectorizedTypes;
- for (auto Type : base_scalarOp.getResultTypes())
- vectorizedTypes.push_back(VectorType::get(shape, Type));
- state.addTypes(vectorizedTypes);
- }
- return builder.createOperation(state);
- }
-
- OpBuilder &builder;
- linalg::GenericOp generic;
- llvm::DenseMap<Value, Value> valueCache;
- SmallVector<Value, 8> results;
-};
-} // namespace
-
-// Replaces elementwise linalg.generic ops with their bodies with scalar
-// operations from these bodies promoted to vector operations.
-static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
- GenericVectorizer vectorizer(builder, op);
- for (Operation &scalarOp : op.region().front()) {
- vectorizer.vectorize(scalarOp);
- }
- if (!op->getResults().empty())
- op->replaceAllUsesWith(vectorizer.getResults());
-}
-
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
@@ -313,7 +421,7 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
- transferWriteVector(builder, fillOp.value(), fillOp.output());
+ buildVectorWrite(builder, fillOp.value(), fillOp.output());
return;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
@@ -322,17 +430,21 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
<< "Rewrite linalg.copy as vector.transfer_read + "
"vector.transfer_write: "
<< *op);
- Value vector = transferReadVector(builder, copyOp.input());
- transferWriteVector(builder, vector, copyOp.output());
+ Value vector = buildVectorRead(builder, copyOp.input());
+ buildVectorWrite(builder, vector, copyOp.output());
return;
}
+ auto linalgOp = cast<linalg::LinalgOp>(op);
+ Location loc = linalgOp.getLoc();
+
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << dbgPref
- << "Rewrite linalg op as vector.transfer_read + "
- "vector_op + vector.transfer_write: "
- << *op);
- return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
+ << "Rewrite linalg op as vector.transfer_read + " << *op);
+ auto status = vectorizeAsLinalgGeneric(builder, linalgOp);
+ assert(succeeded(status) &&
+ "Unexpected vectorization failed despite preconditions");
+ return;
}
assert(succeeded(isContraction(op)) && "Expected contraction");
@@ -341,15 +453,28 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
// TODO: interface.
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
- auto linalgOp = cast<linalg::LinalgOp>(op);
- Value a = transferReadVector(builder, linalgOp.getInput(0));
- Value b = transferReadVector(builder, linalgOp.getInput(1));
- Value c = transferReadVector(builder, linalgOp.getOutput(0));
- Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
- linalgOp.iterator_types());
- Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
- if (writeResult)
- linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
+ // Special function that describes how to vectorize the multiplication op in a
+ // linalg contraction.
+ CustomVectorizationHook vectorizeContraction =
+ [&](Operation *op,
+ const BlockAndValueMapping &bvm) -> VectorizationResult {
+ if (!isa<MulIOp, MulFOp>(op))
+ return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ auto outShape = linalgOp.getOutputShapedType(0).getShape();
+ auto vType = outShape.empty()
+ ? op->getResult(0).getType()
+ : VectorType::get(outShape, op->getResult(0).getType());
+ auto zero =
+ builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
+ Operation *contract = builder.create<vector::ContractionOp>(
+ loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
+ linalgOp.indexing_maps(), linalgOp.iterator_types());
+ return VectorizationResult{VectorizationStatus::NewOp, contract};
+ };
+ auto status =
+ vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
+ assert(succeeded(status) &&
+ "Unexpected vectorization failed despite preconditions");
}
/// Check whether there is any interleaved use of any `values` between `firstOp`
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 2133aad70bd0..aa249542a07d 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -183,13 +183,13 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -267,13 +267,13 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -307,10 +307,15 @@ func @matmul_tensors(
// CHECK-LABEL: func @matmul_tensors
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
-// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
-// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
-// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
+// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+//
+// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
+// a later canonicalization fuses the add into vector.contract.
+// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+// CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
+// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
// CHECK: return %[[W]] : tensor<8x12xf32>
More information about the Mlir-commits
mailing list