[Mlir-commits] [mlir] 4d21da0 - [mlir] Return vectorized values instead of replacing (#144158)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 24 12:06:44 PDT 2025
Author: Max191
Date: 2025-06-24T12:06:41-07:00
New Revision: 4d21da002a056c64231fb89ee9e4eba90080e9bb
URL: https://github.com/llvm/llvm-project/commit/4d21da002a056c64231fb89ee9e4eba90080e9bb
DIFF: https://github.com/llvm/llvm-project/commit/4d21da002a056c64231fb89ee9e4eba90080e9bb.diff
LOG: [mlir] Return vectorized values instead of replacing (#144158)
Updates the linalg::vectorize function to return a
`FailureOr<VectorizationResult>` containing the values to replace the
original operation, instead of directly replacing the original
operation. This aligns better with the style of transforms used with the
TilingInterface, and gives more control to users over the lowering,
since it allows for additional transformation of the IR before
replacement.
There was already a `VectorizationResult` defined, which was used for
the internal vectorize implementation using `CustomVectorizationHook`s,
so the old struct is renamed to `VectorizationHookResult`.
Note for integration: The replacement of the original operation is now
the responsibility of the caller, so wherever `linalg::vectorize` is
used, the caller must also do
`rewriter.replaceOp(op, vectorizeResults->replacements)`.
---------
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9db2a742a7d55..189438e9ad528 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -854,17 +854,23 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
/// to work (these are checked by the vectorizer itself).
bool hasVectorizationImpl(Operation *);
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the sizes
-/// must be smaller or equal than their counterpart interation space sizes, if
-/// static. `inputVectorShapes` also allows the vectorization of operations with
-/// dynamic shapes.
-LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes = {},
- ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false,
- bool flatten1DDepthwiseConv = false);
+/// Transformation information returned after vectorizing.
+struct VectorizationResult {
+ /// Results of the vectorization transform to replace the original operation.
+ SmallVector<Value> replacements;
+};
+/// Returns a `VectorizationResult` containing the results of the vectorized op,
+/// or failure if the transformation fails. If provided, `inputVectorSizes` are
+/// used to vectorize this operation. `inputVectorSizes` must match the rank of
+/// the iteration space of the operation and the input vector sizes must be
+/// greater than or equal to their counterpart iteration space sizes, if static.
+/// `inputVectorShapes` also allows the vectorization of operations with dynamic
+/// shapes.
+FailureOr<VectorizationResult>
+vectorize(RewriterBase &rewriter, Operation *op,
+ ArrayRef<int64_t> inputVectorSizes = {},
+ ArrayRef<bool> inputScalableVecDims = {},
+ bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2355edea2df6c..2b78e31558ea2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3823,9 +3823,14 @@ struct VectorizationPattern : public RewritePattern {
if (!linalg::hasVectorizationImpl(op))
return rewriter.notifyMatchFailure(op,
"Unsupported Op, cannot vectorize");
- return vectorize(rewriter, op, /*inputVectorSizes=*/{},
- /*inputScalableVecDims=*/{}, vectorizeNDExtract,
- flatten1DDepthwiseConv);
+ FailureOr<VectorizationResult> vectorResults =
+ vectorize(rewriter, op, /*inputVectorSizes=*/{},
+ /*inputScalableVecDims=*/{}, vectorizeNDExtract,
+ flatten1DDepthwiseConv);
+ if (failed(vectorResults))
+ return failure();
+ rewriter.replaceOp(op, vectorResults->replacements);
+ return success();
}
private:
@@ -3914,13 +3919,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
-
- if (failed(linalg::vectorize(rewriter, target, vectorSizes,
- getScalableSizes(),
- getVectorizeNdExtract().value_or(false)))) {
+ FailureOr<VectorizationResult> vectorResults =
+ linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
+ getVectorizeNdExtract().value_or(false));
+ if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
}
+ rewriter.replaceOp(target, vectorResults->replacements);
}
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff8e0b8977ae8..e6a19fb5f57be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -551,9 +551,10 @@ enum class Conv1DOpOrder {
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
};
-/// Helper data structure to represent the result of vectorization.
-/// In certain specific cases, like terminators, we do not want to propagate/
-enum VectorizationStatus {
+/// Helper data structure to represent the result of vectorization for a single
+/// operation. In certain specific cases, like terminators, we do not want to
+/// propagate.
+enum VectorizationHookStatus {
/// Op failed to vectorize.
Failure = 0,
/// Op vectorized and custom function took care of replacement logic
@@ -564,9 +565,12 @@ enum VectorizationStatus {
// TODO: support values if Op vectorized to Many-Ops whose results we need to
// aggregate for replacement.
};
-struct VectorizationResult {
+/// VectorizationHookResult contains the vectorized op returned from a
+/// CustomVectorizationHook. This is an internal implementation detail of
+/// linalg vectorization, not to be confused with VectorizationResult.
+struct VectorizationHookResult {
/// Return status from vectorizing the current op.
- enum VectorizationStatus status = VectorizationStatus::Failure;
+ enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
/// New vectorized operation to replace the current op.
/// Replacement behavior is specified by `status`.
Operation *newOp;
@@ -728,22 +732,22 @@ using CustomVectorizationPrecondition =
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook =
- std::function<VectorizationResult(Operation *, const IRMapping &)>;
+ std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
-/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
-/// should not try to map produced operations and instead return the results
-/// using the `newResults` vector making them available to the vectorization
-/// algorithm for RAUW. This function is meant to be used as a
+/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
+/// that it should not try to map produced operations and instead return the
+/// results using the `newResults` vector making them available to the
+/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
const IRMapping &bvm, VectorizationState &state,
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
@@ -755,20 +759,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
newResults.push_back(newResult);
}
- return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
- VectorizationState &state,
- Operation *op,
- LinalgOp linalgOp) {
+static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
+ VectorizationState &state,
+ Operation *op,
+ LinalgOp linalgOp) {
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
if (!indexOp)
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
// Compute the static loop sizes of the index op.
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -782,7 +786,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
// dimension of the iteration space since the vectorization algorithm in this
// case can handle the broadcast.
if (dim == targetShape.size() - 1)
- return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
// Otherwise permute the targetShape to move the index dimension last,
// broadcast the one-dimensional index vector to the permuted shape, and
// finally transpose the broadcasted index vector to undo the permutation.
@@ -800,7 +804,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
std::swap(transposition.back(), transposition[dim]);
auto transposeOp =
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
- return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
/// Helper function to check if the tensor.extract can be vectorized by the
@@ -1098,15 +1102,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
}
/// Helper function to vectorize the tensor.extract operations. Returns
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
auto loc = extractOp.getLoc();
// Compute the static loop sizes of the extract op.
@@ -1138,7 +1142,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
LDBG("Vectorised as gather load: " << extractOp << "\n");
- return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
}
// 2. Handle:
@@ -1202,7 +1206,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
- return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
+ maskedReadOp};
}
// 2b. Handle contiguous access.
@@ -1228,7 +1233,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
inBounds);
LDBG("Vectorised as contiguous load: " << extractOp);
- return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
+ transferReadOp};
}
/// Emit reduction operations if the shapes of the value to reduce is different
@@ -1268,9 +1274,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
-/// This function does not update `bvm` but returns a VectorizationStatus that
-/// instructs the caller what `bvm` update needs to occur.
-static VectorizationResult
+/// This function does not update `bvm` but returns a VectorizationHookStatus
+/// that instructs the caller what `bvm` update needs to occur.
+static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1279,8 +1285,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
for (auto &customFunc : customVectorizationHooks) {
- VectorizationResult result = customFunc(op, bvm);
- if (result.status == VectorizationStatus::Failure)
+ VectorizationHookResult result = customFunc(op, bvm);
+ if (result.status == VectorizationHookStatus::Failure)
continue;
return result;
}
@@ -1289,11 +1295,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
// Clone so that the constant is not confined to the linalgOp block .
if (isa<arith::ConstantOp, func::ConstantOp>(op))
- return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
+ rewriter.clone(*op)};
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
if (!OpTrait::hasElementwiseMappableTraits(op))
- return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
// 4 . Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1316,7 +1323,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
reductionOperands[0].second, bvm);
if (reduceOp)
- return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
+ return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
}
// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1356,8 +1363,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
: resultType);
}
// d. Build and return the new op.
- return VectorizationResult{
- VectorizationStatus::NewOp,
+ return VectorizationHookResult{
+ VectorizationHookStatus::NewOp,
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
resultTypes, op->getAttrs())};
}
@@ -1461,34 +1468,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<CustomVectorizationHook> hooks;
// 4a. Register CustomVectorizationHook for yieldOp.
CustomVectorizationHook vectorizeYield =
- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
};
hooks.push_back(vectorizeYield);
// 4b. Register CustomVectorizationHook for indexOp.
CustomVectorizationHook vectorizeIndex =
- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
};
hooks.push_back(vectorizeIndex);
// 4c. Register CustomVectorizationHook for extractOp.
CustomVectorizationHook vectorizeExtract =
- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
};
hooks.push_back(vectorizeExtract);
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
- VectorizationResult result =
+ VectorizationHookResult result =
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
- if (result.status == VectorizationStatus::Failure) {
+ if (result.status == VectorizationHookStatus::Failure) {
LDBG("failed to vectorize: " << op << "\n");
return failure();
}
- if (result.status == VectorizationStatus::NewOp) {
+ if (result.status == VectorizationHookStatus::NewOp) {
Operation *maybeMaskedOp =
state.maskOperation(rewriter, result.newOp, linalgOp);
LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2525,17 +2532,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes.
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract,
- bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult>
+mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+ ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims,
+ bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2617,12 +2618,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}
- if (!results.empty())
- rewriter.replaceOp(op, results);
- else
- rewriter.eraseOp(op);
-
- return success();
+ return VectorizationResult{results};
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
More information about the Mlir-commits
mailing list