[Mlir-commits] [mlir] [mlir] Return vectorized values instead of replacing (PR #144158)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 24 10:29:29 PDT 2025
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/144158
>From b8a783bf4e0aa2c478d9fcafe8e283fdc6198d17 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 19 Jun 2025 12:49:05 +0000
Subject: [PATCH 1/2] [mlir] Return vectorized values instead of replacing
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 27 ++++---
.../TransformOps/LinalgTransformOps.cpp | 20 +++--
.../Linalg/Transforms/Vectorization.cpp | 78 ++++++++-----------
3 files changed, 64 insertions(+), 61 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..248ffe4d53d91 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -771,17 +771,24 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
/// to work (these are checked by the vectorizer itself).
bool hasVectorizationImpl(Operation *);
+/// Transformation information returned after vectorizing.
+struct VectorizationResult {
+ /// Results of the vectorization transform to replace the original operation.
+ SmallVector<Value> replacements;
+};
/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the sizes
-/// must be smaller or equal than their counterpart interation space sizes, if
-/// static. `inputVectorShapes` also allows the vectorization of operations with
-/// dynamic shapes.
-LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes = {},
- ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false,
- bool flatten1DDepthwiseConv = false);
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorShapes`
+/// also allows the vectorization of operations with dynamic shapes. Returns
+/// a VectorizationResult containing the results of the vectorized op, or
+/// failure if the transformation fails.
+FailureOr<VectorizationResult>
+vectorize(RewriterBase &rewriter, Operation *op,
+ ArrayRef<int64_t> inputVectorSizes = {},
+ ArrayRef<bool> inputScalableVecDims = {},
+ bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..07434507b6eb2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3609,9 +3609,14 @@ struct VectorizationPattern : public RewritePattern {
if (!linalg::hasVectorizationImpl(op))
return rewriter.notifyMatchFailure(op,
"Unsupported Op, cannot vectorize");
- return vectorize(rewriter, op, /*inputVectorSizes=*/{},
- /*inputScalableVecDims=*/{}, vectorizeNDExtract,
- flatten1DDepthwiseConv);
+ FailureOr<VectorizationResult> vectorResults =
+ vectorize(rewriter, op, /*inputVectorSizes=*/{},
+ /*inputScalableVecDims=*/{}, vectorizeNDExtract,
+ flatten1DDepthwiseConv);
+ if (failed(vectorResults))
+ return failure();
+ rewriter.replaceOp(op, vectorResults->replacements);
+ return success();
}
private:
@@ -3700,13 +3705,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
-
- if (failed(linalg::vectorize(rewriter, target, vectorSizes,
- getScalableSizes(),
- getVectorizeNdExtract().value_or(false)))) {
+ FailureOr<VectorizationResult> vectorResults =
+ linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
+ getVectorizeNdExtract().value_or(false));
+ if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
}
+ rewriter.replaceOp(target, vectorResults->replacements);
}
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff28bd7c48342..88d49c7af4d60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -563,7 +563,7 @@ enum VectorizationStatus {
// TODO: support values if Op vectorized to Many-Ops whose results we need to
// aggregate for replacement.
};
-struct VectorizationResult {
+struct VectorizationHookResult {
/// Return status from vectorizing the current op.
enum VectorizationStatus status = VectorizationStatus::Failure;
/// New vectorized operation to replace the current op.
@@ -727,7 +727,7 @@ using CustomVectorizationPrecondition =
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook =
- std::function<VectorizationResult(Operation *, const IRMapping &)>;
+ std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
@@ -736,13 +736,13 @@ using CustomVectorizationHook =
/// using the `newResults` vector making them available to the vectorization
/// algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
const IRMapping &bvm, VectorizationState &state,
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
@@ -754,20 +754,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
newResults.push_back(newResult);
}
- return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+ return VectorizationHookResult{VectorizationStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
- VectorizationState &state,
- Operation *op,
- LinalgOp linalgOp) {
+static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
+ VectorizationState &state,
+ Operation *op,
+ LinalgOp linalgOp) {
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
if (!indexOp)
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
// Compute the static loop sizes of the index op.
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -781,7 +781,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
// dimension of the iteration space since the vectorization algorithm in this
// case can handle the broadcast.
if (dim == targetShape.size() - 1)
- return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
+ return VectorizationHookResult{VectorizationStatus::NewOp, indexSteps};
// Otherwise permute the targetShape to move the index dimension last,
// broadcast the one-dimensional index vector to the permuted shape, and
// finally transpose the broadcasted index vector to undo the permutation.
@@ -799,7 +799,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
std::swap(transposition.back(), transposition[dim]);
auto transposeOp =
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
- return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+ return VectorizationHookResult{VectorizationStatus::NewOp, transposeOp};
}
/// Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,12 +1100,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
auto loc = extractOp.getLoc();
// Compute the static loop sizes of the extract op.
@@ -1137,7 +1137,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
LDBG("Vectorised as gather load: " << extractOp << "\n");
- return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+ return VectorizationHookResult{VectorizationStatus::NewOp, gatherOp};
}
// 2. Handle:
@@ -1201,7 +1201,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
- return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
+ return VectorizationHookResult{VectorizationStatus::NewOp, maskedReadOp};
}
// 2b. Handle contiguous access.
@@ -1227,7 +1227,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
inBounds);
LDBG("Vectorised as contiguous load: " << extractOp);
- return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+ return VectorizationHookResult{VectorizationStatus::NewOp, transferReadOp};
}
/// Emit reduction operations if the shapes of the value to reduce is different
@@ -1269,7 +1269,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
-static VectorizationResult
+static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1278,7 +1278,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
for (auto &customFunc : customVectorizationHooks) {
- VectorizationResult result = customFunc(op, bvm);
+ VectorizationHookResult result = customFunc(op, bvm);
if (result.status == VectorizationStatus::Failure)
continue;
return result;
@@ -1288,11 +1288,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
// Clone so that the constant is not confined to the linalgOp block .
if (isa<arith::ConstantOp, func::ConstantOp>(op))
- return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
+ return VectorizationHookResult{VectorizationStatus::NewOp,
+ rewriter.clone(*op)};
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
if (!OpTrait::hasElementwiseMappableTraits(op))
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
// 4 . Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1315,7 +1316,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
reductionOperands[0].second, bvm);
if (reduceOp)
- return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
+ return VectorizationHookResult{VectorizationStatus::NewOp, reduceOp};
}
// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1355,7 +1356,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
: resultType);
}
// d. Build and return the new op.
- return VectorizationResult{
+ return VectorizationHookResult{
VectorizationStatus::NewOp,
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
resultTypes, op->getAttrs())};
@@ -1460,28 +1461,28 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<CustomVectorizationHook> hooks;
// 4a. Register CustomVectorizationHook for yieldOp.
CustomVectorizationHook vectorizeYield =
- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
};
hooks.push_back(vectorizeYield);
// 4b. Register CustomVectorizationHook for indexOp.
CustomVectorizationHook vectorizeIndex =
- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
};
hooks.push_back(vectorizeIndex);
// 4c. Register CustomVectorizationHook for extractOp.
CustomVectorizationHook vectorizeExtract =
- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
};
hooks.push_back(vectorizeExtract);
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
- VectorizationResult result =
+ VectorizationHookResult result =
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LDBG("failed to vectorize: " << op << "\n");
@@ -2522,17 +2523,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes.
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract,
- bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult>
+mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+ ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims,
+ bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2614,12 +2609,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}
- if (!results.empty())
- rewriter.replaceOp(op, results);
- else
- rewriter.eraseOp(op);
-
- return success();
+ return VectorizationResult({results});
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
>From 9da0c1804cbf00eafe34a8942bf3c1baf89d58aa Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 24 Jun 2025 16:47:44 +0000
Subject: [PATCH 2/2] address comments
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 15 ++---
.../Linalg/Transforms/Vectorization.cpp | 64 ++++++++++---------
2 files changed, 42 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 248ffe4d53d91..2863bb54184af 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -776,14 +776,13 @@ struct VectorizationResult {
/// Results of the vectorization transform to replace the original operation.
SmallVector<Value> replacements;
};
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes. Returns
-/// a VectorizationResult containing the results of the vectorized op, or
-/// failure if the transformation fails.
+/// Returns a `VectorizationResult` containing the results of the vectorized op,
+/// or failure if the transformation fails. If provided, `inputVectorSizes` are
+/// used to vectorize this operation. `inputVectorSizes` must match the rank of
+/// the iteration space of the operation and the input vector sizes must be
+/// greater than or equal to their counterpart iteration space sizes, if static.
+/// `inputVectorSizes` also allows the vectorization of operations with dynamic
+/// shapes.
FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 88d49c7af4d60..e0a0c4114ed97 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -550,9 +550,10 @@ enum class Conv1DOpOrder {
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
};
-/// Helper data structure to represent the result of vectorization.
-/// In certain specific cases, like terminators, we do not want to propagate/
-enum VectorizationStatus {
+/// Helper data structure to represent the result of vectorization for a single
+/// operation. In certain specific cases, like terminators, we do not want to
+/// propagate.
+enum VectorizationHookStatus {
/// Op failed to vectorize.
Failure = 0,
/// Op vectorized and custom function took care of replacement logic
@@ -563,9 +564,12 @@ enum VectorizationStatus {
// TODO: support values if Op vectorized to Many-Ops whose results we need to
// aggregate for replacement.
};
+/// VectorizationHookResult contains the vectorized op returned from a
+/// CustomVectorizationHook. This is an internal implementation detail of
+/// linalg vectorization, not to be confused with VectorizationResult.
struct VectorizationHookResult {
/// Return status from vectorizing the current op.
- enum VectorizationStatus status = VectorizationStatus::Failure;
+ enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
/// New vectorized operation to replace the current op.
/// Replacement behavior is specified by `status`.
Operation *newOp;
@@ -731,10 +735,10 @@ using CustomVectorizationHook =
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
-/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
-/// should not try to map produced operations and instead return the results
-/// using the `newResults` vector making them available to the vectorization
-/// algorithm for RAUW. This function is meant to be used as a
+/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
+/// that it should not try to map produced operations and instead return the
+/// results using the `newResults` vector making them available to the
+/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
@@ -742,7 +746,7 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
- return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
@@ -754,11 +758,11 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
newResults.push_back(newResult);
}
- return VectorizationHookResult{VectorizationStatus::NoReplace, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
@@ -767,7 +771,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
LinalgOp linalgOp) {
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
if (!indexOp)
- return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
// Compute the static loop sizes of the index op.
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -781,7 +785,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
// dimension of the iteration space since the vectorization algorithm in this
// case can handle the broadcast.
if (dim == targetShape.size() - 1)
- return VectorizationHookResult{VectorizationStatus::NewOp, indexSteps};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
// Otherwise permute the targetShape to move the index dimension last,
// broadcast the one-dimensional index vector to the permuted shape, and
// finally transpose the broadcasted index vector to undo the permutation.
@@ -799,7 +803,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
std::swap(transposition.back(), transposition[dim]);
auto transposeOp =
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
- return VectorizationHookResult{VectorizationStatus::NewOp, transposeOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
/// Helper function to check if the tensor.extract can be vectorized by the
@@ -1097,7 +1101,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
}
/// Helper function to vectorize the tensor.extract operations. Returns
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationHookResult
@@ -1105,7 +1109,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
- return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
auto loc = extractOp.getLoc();
// Compute the static loop sizes of the extract op.
@@ -1137,7 +1141,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
LDBG("Vectorised as gather load: " << extractOp << "\n");
- return VectorizationHookResult{VectorizationStatus::NewOp, gatherOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
}
// 2. Handle:
@@ -1201,7 +1205,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
- return VectorizationHookResult{VectorizationStatus::NewOp, maskedReadOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
+ maskedReadOp};
}
// 2b. Handle contiguous access.
@@ -1227,7 +1232,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
inBounds);
LDBG("Vectorised as contiguous load: " << extractOp);
- return VectorizationHookResult{VectorizationStatus::NewOp, transferReadOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
+ transferReadOp};
}
/// Emit reduction operations if the shapes of the value to reduce is different
@@ -1267,8 +1273,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
-/// This function does not update `bvm` but returns a VectorizationStatus that
-/// instructs the caller what `bvm` update needs to occur.
+/// This function does not update `bvm` but returns a VectorizationHookStatus
+/// that instructs the caller what `bvm` update needs to occur.
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
@@ -1279,7 +1285,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
if (!customVectorizationHooks.empty()) {
for (auto &customFunc : customVectorizationHooks) {
VectorizationHookResult result = customFunc(op, bvm);
- if (result.status == VectorizationStatus::Failure)
+ if (result.status == VectorizationHookStatus::Failure)
continue;
return result;
}
@@ -1288,12 +1294,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
// Clone so that the constant is not confined to the linalgOp block .
if (isa<arith::ConstantOp, func::ConstantOp>(op))
- return VectorizationHookResult{VectorizationStatus::NewOp,
+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
rewriter.clone(*op)};
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
if (!OpTrait::hasElementwiseMappableTraits(op))
- return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
// 4 . Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1316,7 +1322,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
reductionOperands[0].second, bvm);
if (reduceOp)
- return VectorizationHookResult{VectorizationStatus::NewOp, reduceOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
}
// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1357,7 +1363,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
}
// d. Build and return the new op.
return VectorizationHookResult{
- VectorizationStatus::NewOp,
+ VectorizationHookStatus::NewOp,
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
resultTypes, op->getAttrs())};
}
@@ -1484,11 +1490,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
for (Operation &op : block->getOperations()) {
VectorizationHookResult result =
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
- if (result.status == VectorizationStatus::Failure) {
+ if (result.status == VectorizationHookStatus::Failure) {
LDBG("failed to vectorize: " << op << "\n");
return failure();
}
- if (result.status == VectorizationStatus::NewOp) {
+ if (result.status == VectorizationHookStatus::NewOp) {
Operation *maybeMaskedOp =
state.maskOperation(rewriter, result.newOp, linalgOp);
LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2609,7 +2615,7 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}
- return VectorizationResult({results});
+ return VectorizationResult{results};
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
More information about the Mlir-commits
mailing list