[Mlir-commits] [mlir] [mlir] Expose linalg vectorization without replacement (PR #144158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 19 05:49:36 PDT 2025


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/144158

From b8a783bf4e0aa2c478d9fcafe8e283fdc6198d17 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 19 Jun 2025 12:49:05 +0000
Subject: [PATCH] [mlir] Return vectorized values instead of replacing

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    | 27 ++++---
 .../TransformOps/LinalgTransformOps.cpp       | 20 +++--
 .../Linalg/Transforms/Vectorization.cpp       | 78 ++++++++-----------
 3 files changed, 64 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..248ffe4d53d91 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -771,17 +771,24 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 /// to work (these are checked by the vectorizer itself).
 bool hasVectorizationImpl(Operation *);
 
+/// Transformation information returned after vectorizing.
+struct VectorizationResult {
+  /// Results of the vectorization transform to replace the original operation.
+  SmallVector<Value> replacements;
+};
 /// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the sizes
-/// must be smaller or equal than their counterpart interation space sizes, if
-/// static. `inputVectorShapes` also allows the vectorization of operations with
-/// dynamic shapes.
-LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes = {},
-                        ArrayRef<bool> inputScalableVecDims = {},
-                        bool vectorizeNDExtract = false,
-                        bool flatten1DDepthwiseConv = false);
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorSizes`
+/// also allows the vectorization of operations with dynamic shapes. Returns
+/// a VectorizationResult containing the results of the vectorized op, or
+/// failure if the transformation fails.
+FailureOr<VectorizationResult>
+vectorize(RewriterBase &rewriter, Operation *op,
+          ArrayRef<int64_t> inputVectorSizes = {},
+          ArrayRef<bool> inputScalableVecDims = {},
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
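
For downstream users, a minimal sketch of how the new entry point might be called, assuming the caller wants control over when (or whether) the original op gets replaced; the function and variable names below are illustrative and not part of the patch:

// Hypothetical caller built against the patched Transforms.h.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

static LogicalResult vectorizeAndReplace(RewriterBase &rewriter,
                                         Operation *op) {
  if (!linalg::hasVectorizationImpl(op))
    return failure();
  // With no explicit vector sizes, the vector shape is derived from the
  // op's static iteration space.
  FailureOr<linalg::VectorizationResult> result =
      linalg::vectorize(rewriter, op);
  if (failed(result))
    return failure();
  // The transform no longer rewrites the IR itself; the caller owns the
  // replacement of the original op with the vectorized values.
  rewriter.replaceOp(op, result->replacements);
  return success();
}
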
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..07434507b6eb2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3609,9 +3609,14 @@ struct VectorizationPattern : public RewritePattern {
     if (!linalg::hasVectorizationImpl(op))
       return rewriter.notifyMatchFailure(op,
                                          "Unsupported Op, cannot vectorize");
-    return vectorize(rewriter, op, /*inputVectorSizes=*/{},
-                     /*inputScalableVecDims=*/{}, vectorizeNDExtract,
-                     flatten1DDepthwiseConv);
+    FailureOr<VectorizationResult> vectorResults =
+        vectorize(rewriter, op, /*inputVectorSizes=*/{},
+                  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
+                  flatten1DDepthwiseConv);
+    if (failed(vectorResults))
+      return failure();
+    rewriter.replaceOp(op, vectorResults->replacements);
+    return success();
   }
 
 private:
@@ -3700,13 +3705,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
-
-    if (failed(linalg::vectorize(rewriter, target, vectorSizes,
-                                 getScalableSizes(),
-                                 getVectorizeNdExtract().value_or(false)))) {
+    FailureOr<VectorizationResult> vectorResults =
+        linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
+                          getVectorizeNdExtract().value_or(false));
+    if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
     }
+    rewriter.replaceOp(target, vectorResults->replacements);
   }
 
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff28bd7c48342..88d49c7af4d60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -563,7 +563,7 @@ enum VectorizationStatus {
   // TODO: support values if Op vectorized to Many-Ops whose results we need to
   // aggregate for replacement.
 };
-struct VectorizationResult {
+struct VectorizationHookResult {
   /// Return status from vectorizing the current op.
   enum VectorizationStatus status = VectorizationStatus::Failure;
   /// New vectorized operation to replace the current op.
@@ -727,7 +727,7 @@ using CustomVectorizationPrecondition =
 // assuming all its vectorized operands are already in the IRMapping.
 // Return nullptr if the Operation cannot be vectorized.
 using CustomVectorizationHook =
-    std::function<VectorizationResult(Operation *, const IRMapping &)>;
+    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
 
 /// Helper function to vectorize the terminator of a `linalgOp`. New result
 /// vector values are appended to `newResults`. Return
@@ -736,13 +736,13 @@ using CustomVectorizationHook =
 /// using the `newResults` vector making them available to the vectorization
 /// algorithm for RAUW. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                      const IRMapping &bvm, VectorizationState &state,
                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
@@ -754,20 +754,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
       newResults.push_back(newResult);
   }
 
-  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+  return VectorizationHookResult{VectorizationStatus::NoReplace, nullptr};
 }
 
 /// Helper function to vectorize the index operations of a `linalgOp`. Return
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
-                                                VectorizationState &state,
-                                                Operation *op,
-                                                LinalgOp linalgOp) {
+static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                    VectorizationState &state,
+                                                    Operation *op,
+                                                    LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -781,7 +781,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
   if (dim == targetShape.size() - 1)
-    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
+    return VectorizationHookResult{VectorizationStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
@@ -799,7 +799,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
   std::swap(transposition.back(), transposition[dim]);
   auto transposeOp =
       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
-  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+  return VectorizationHookResult{VectorizationStatus::NewOp, transposeOp};
 }
 
 /// Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,12 +1100,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
@@ -1137,7 +1137,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
     LDBG("Vectorised as gather load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+    return VectorizationHookResult{VectorizationStatus::NewOp, gatherOp};
   }
 
   // 2. Handle:
@@ -1201,7 +1201,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
 
     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
+    return VectorizationHookResult{VectorizationStatus::NewOp, maskedReadOp};
   }
 
   // 2b. Handle contiguous access.
@@ -1227,7 +1227,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       inBounds);
 
   LDBG("Vectorised as contiguous load: " << extractOp);
-  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+  return VectorizationHookResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
@@ -1269,7 +1269,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 /// a topologically-sorted list of ops.
 /// This function does not update `bvm` but returns a VectorizationStatus that
 /// instructs the caller what `bvm` update needs to occur.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1278,7 +1278,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 1. Try to apply any CustomVectorizationHook.
   if (!customVectorizationHooks.empty()) {
     for (auto &customFunc : customVectorizationHooks) {
-      VectorizationResult result = customFunc(op, bvm);
+      VectorizationHookResult result = customFunc(op, bvm);
       if (result.status == VectorizationStatus::Failure)
         continue;
       return result;
@@ -1288,11 +1288,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
   // Clone so that the constant is not confined to the linalgOp block .
   if (isa<arith::ConstantOp, func::ConstantOp>(op))
-    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
+    return VectorizationHookResult{VectorizationStatus::NewOp,
+                                   rewriter.clone(*op)};
 
   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
   if (!OpTrait::hasElementwiseMappableTraits(op))
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
 
   // 4 . Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1315,7 +1316,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                        reductionOperands[0].second, bvm);
     if (reduceOp)
-      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
+      return VectorizationHookResult{VectorizationStatus::NewOp, reduceOp};
   }
 
   // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1355,7 +1356,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
             : resultType);
   }
   //   d. Build and return the new op.
-  return VectorizationResult{
+  return VectorizationHookResult{
       VectorizationStatus::NewOp,
       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                       resultTypes, op->getAttrs())};
@@ -1460,28 +1461,28 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   SmallVector<CustomVectorizationHook> hooks;
   // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
   };
   hooks.push_back(vectorizeYield);
 
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
     return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
 
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
-    VectorizationResult result =
+    VectorizationHookResult result =
         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LDBG("failed to vectorize: " << op << "\n");
@@ -2522,17 +2523,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes.
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                                      ArrayRef<int64_t> inputVectorSizes,
-                                      ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract,
-                                      bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult>
+mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        ArrayRef<bool> inputScalableVecDims,
+                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2614,12 +2609,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     return failure();
   }
 
-  if (!results.empty())
-    rewriter.replaceOp(op, results);
-  else
-    rewriter.eraseOp(op);
-
-  return success();
+  return VectorizationResult({results});
 }
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
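
Since vectorize() no longer replaces or erases the original op, callers that relied on the old in-place behavior need to reproduce it themselves. A minimal sketch mirroring the logic removed from Vectorization.cpp above (variable names are illustrative):

FailureOr<linalg::VectorizationResult> vectorResults =
    linalg::vectorize(rewriter, op, inputVectorSizes, inputScalableVecDims,
                      vectorizeNDExtract, flatten1DDepthwiseConv);
if (failed(vectorResults))
  return failure();
// Pre-patch behavior: ops that produce results are replaced by the
// vectorized values; result-less ops (e.g. memref semantics) are erased.
if (!vectorResults->replacements.empty())
  rewriter.replaceOp(op, vectorResults->replacements);
else
  rewriter.eraseOp(op);
return success();
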


