[Mlir-commits] [mlir] [mlir] Return vectorized values instead of replacing (PR #144158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 24 10:29:29 PDT 2025


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/144158

From b8a783bf4e0aa2c478d9fcafe8e283fdc6198d17 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 19 Jun 2025 12:49:05 +0000
Subject: [PATCH 1/2] [mlir] Return vectorized values instead of replacing

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    | 27 ++++---
 .../TransformOps/LinalgTransformOps.cpp       | 20 +++--
 .../Linalg/Transforms/Vectorization.cpp       | 78 ++++++++-----------
 3 files changed, 64 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..248ffe4d53d91 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -771,17 +771,24 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 /// to work (these are checked by the vectorizer itself).
 bool hasVectorizationImpl(Operation *);
 
+/// Transformation information returned after vectorizing.
+struct VectorizationResult {
+  /// Results of the vectorization transform to replace the original operation.
+  SmallVector<Value> replacements;
+};
 /// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the sizes
-/// must be smaller or equal than their counterpart interation space sizes, if
-/// static. `inputVectorShapes` also allows the vectorization of operations with
-/// dynamic shapes.
-LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes = {},
-                        ArrayRef<bool> inputScalableVecDims = {},
-                        bool vectorizeNDExtract = false,
-                        bool flatten1DDepthwiseConv = false);
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorShapes`
+/// also allows the vectorization of operations with dynamic shapes. Returns
+/// a VectorizationResult containing the results of the vectorized op, or
+/// failure if the transformation fails.
+FailureOr<VectorizationResult>
+vectorize(RewriterBase &rewriter, Operation *op,
+          ArrayRef<int64_t> inputVectorSizes = {},
+          ArrayRef<bool> inputScalableVecDims = {},
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
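
With this header change, `vectorize` no longer rewrites the IR on behalf of its caller: it hands back the replacement values in `VectorizationResult` and leaves replacing (or erasing) the original op to the call site. A minimal caller sketch of the new contract, mirroring the pattern the patch adopts in LinalgTransformOps.cpp below; `vectorizeAndReplace` is an illustrative helper name, not part of the patch:

  // Vectorize `op` and perform the replacement that linalg::vectorize used
  // to do internally.
  static LogicalResult vectorizeAndReplace(RewriterBase &rewriter,
                                           Operation *op) {
    FailureOr<linalg::VectorizationResult> vectorResults =
        linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{},
                          /*inputScalableVecDims=*/{});
    if (failed(vectorResults))
      return failure();
    // The caller now owns the replacement step.
    rewriter.replaceOp(op, vectorResults->replacements);
    return success();
  }
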
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..07434507b6eb2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3609,9 +3609,14 @@ struct VectorizationPattern : public RewritePattern {
     if (!linalg::hasVectorizationImpl(op))
       return rewriter.notifyMatchFailure(op,
                                          "Unsupported Op, cannot vectorize");
-    return vectorize(rewriter, op, /*inputVectorSizes=*/{},
-                     /*inputScalableVecDims=*/{}, vectorizeNDExtract,
-                     flatten1DDepthwiseConv);
+    FailureOr<VectorizationResult> vectorResults =
+        vectorize(rewriter, op, /*inputVectorSizes=*/{},
+                  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
+                  flatten1DDepthwiseConv);
+    if (failed(vectorResults))
+      return failure();
+    rewriter.replaceOp(op, vectorResults->replacements);
+    return success();
   }
 
 private:
@@ -3700,13 +3705,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
-
-    if (failed(linalg::vectorize(rewriter, target, vectorSizes,
-                                 getScalableSizes(),
-                                 getVectorizeNdExtract().value_or(false)))) {
+    FailureOr<VectorizationResult> vectorResults =
+        linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
+                          getVectorizeNdExtract().value_or(false));
+    if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
     }
+    rewriter.replaceOp(target, vectorResults->replacements);
   }
 
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff28bd7c48342..88d49c7af4d60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -563,7 +563,7 @@ enum VectorizationStatus {
   // TODO: support values if Op vectorized to Many-Ops whose results we need to
   // aggregate for replacement.
 };
-struct VectorizationResult {
+struct VectorizationHookResult {
   /// Return status from vectorizing the current op.
   enum VectorizationStatus status = VectorizationStatus::Failure;
   /// New vectorized operation to replace the current op.
@@ -727,7 +727,7 @@ using CustomVectorizationPrecondition =
 // assuming all its vectorized operands are already in the IRMapping.
 // Return nullptr if the Operation cannot be vectorized.
 using CustomVectorizationHook =
-    std::function<VectorizationResult(Operation *, const IRMapping &)>;
+    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
 
 /// Helper function to vectorize the terminator of a `linalgOp`. New result
 /// vector values are appended to `newResults`. Return
@@ -736,13 +736,13 @@ using CustomVectorizationHook =
 /// using the `newResults` vector making them available to the vectorization
 /// algorithm for RAUW. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                      const IRMapping &bvm, VectorizationState &state,
                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
@@ -754,20 +754,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
       newResults.push_back(newResult);
   }
 
-  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+  return VectorizationHookResult{VectorizationStatus::NoReplace, nullptr};
 }
 
 /// Helper function to vectorize the index operations of a `linalgOp`. Return
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
-                                                VectorizationState &state,
-                                                Operation *op,
-                                                LinalgOp linalgOp) {
+static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                    VectorizationState &state,
+                                                    Operation *op,
+                                                    LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -781,7 +781,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
   if (dim == targetShape.size() - 1)
-    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
+    return VectorizationHookResult{VectorizationStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
@@ -799,7 +799,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
   std::swap(transposition.back(), transposition[dim]);
   auto transposeOp =
       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
-  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+  return VectorizationHookResult{VectorizationStatus::NewOp, transposeOp};
 }
 
 /// Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,12 +1100,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
@@ -1137,7 +1137,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
     LDBG("Vectorised as gather load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+    return VectorizationHookResult{VectorizationStatus::NewOp, gatherOp};
   }
 
   // 2. Handle:
@@ -1201,7 +1201,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
 
     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
+    return VectorizationHookResult{VectorizationStatus::NewOp, maskedReadOp};
   }
 
   // 2b. Handle contiguous access.
@@ -1227,7 +1227,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       inBounds);
 
   LDBG("Vectorised as contiguous load: " << extractOp);
-  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+  return VectorizationHookResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
@@ -1269,7 +1269,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 /// a topologically-sorted list of ops.
 /// This function does not update `bvm` but returns a VectorizationStatus that
 /// instructs the caller what `bvm` update needs to occur.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1278,7 +1278,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 1. Try to apply any CustomVectorizationHook.
   if (!customVectorizationHooks.empty()) {
     for (auto &customFunc : customVectorizationHooks) {
-      VectorizationResult result = customFunc(op, bvm);
+      VectorizationHookResult result = customFunc(op, bvm);
       if (result.status == VectorizationStatus::Failure)
         continue;
       return result;
@@ -1288,11 +1288,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
   // Clone so that the constant is not confined to the linalgOp block .
   if (isa<arith::ConstantOp, func::ConstantOp>(op))
-    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
+    return VectorizationHookResult{VectorizationStatus::NewOp,
+                                   rewriter.clone(*op)};
 
   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
   if (!OpTrait::hasElementwiseMappableTraits(op))
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
 
   // 4 . Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1315,7 +1316,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                        reductionOperands[0].second, bvm);
     if (reduceOp)
-      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
+      return VectorizationHookResult{VectorizationStatus::NewOp, reduceOp};
   }
 
   // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1355,7 +1356,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
             : resultType);
   }
   //   d. Build and return the new op.
-  return VectorizationResult{
+  return VectorizationHookResult{
       VectorizationStatus::NewOp,
       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                       resultTypes, op->getAttrs())};
@@ -1460,28 +1461,28 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   SmallVector<CustomVectorizationHook> hooks;
   // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
   };
   hooks.push_back(vectorizeYield);
 
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
     return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
 
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
-    VectorizationResult result =
+    VectorizationHookResult result =
         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LDBG("failed to vectorize: " << op << "\n");
@@ -2522,17 +2523,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes.
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                                      ArrayRef<int64_t> inputVectorSizes,
-                                      ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract,
-                                      bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult>
+mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        ArrayRef<bool> inputScalableVecDims,
+                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2614,12 +2609,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     return failure();
   }
 
-  if (!results.empty())
-    rewriter.replaceOp(op, results);
-  else
-    rewriter.eraseOp(op);
-
-  return success();
+  return VectorizationResult({results});
 }
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
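
The replacement logic removed here distinguished ops that produced replacement values from ops that produced none (the latter were erased rather than replaced). Under the new interface that case surfaces as an empty `replacements` vector, e.g. for linalg ops with memref semantics, so a call site that wants to preserve the old behavior verbatim can branch on it. An illustrative sketch, not part of the patch, continuing from the caller sketch above (`op`, `rewriter`, and `vectorResults` as before):

  // Keep the old replace-or-erase behavior at the call site.
  if (vectorResults->replacements.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, vectorResults->replacements);
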

From 9da0c1804cbf00eafe34a8942bf3c1baf89d58aa Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 24 Jun 2025 16:47:44 +0000
Subject: [PATCH 2/2] address comments

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    | 15 ++---
 .../Linalg/Transforms/Vectorization.cpp       | 64 ++++++++++---------
 2 files changed, 42 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 248ffe4d53d91..2863bb54184af 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -776,14 +776,13 @@ struct VectorizationResult {
   /// Results of the vectorization transform to replace the original operation.
   SmallVector<Value> replacements;
 };
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes. Returns
-/// a VectorizationResult containing the results of the vectorized op, or
-/// failure if the transformation fails.
+/// Returns a `VectorizationResult` containing the results of the vectorized op,
+/// or failure if the transformation fails. If provided, `inputVectorSizes` are
+/// used to vectorize this operation. `inputVectorSizes` must match the rank of
+/// the iteration space of the operation and the input vector sizes must be
+/// greater than or equal to their counterpart iteration space sizes, if static.
+/// `inputVectorShapes` also allows the vectorization of operations with dynamic
+/// shapes.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 88d49c7af4d60..e0a0c4114ed97 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -550,9 +550,10 @@ enum class Conv1DOpOrder {
   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
 };
 
-/// Helper data structure to represent the result of vectorization.
-/// In certain specific cases, like terminators, we do not want to propagate/
-enum VectorizationStatus {
+/// Helper data structure to represent the result of vectorization for a single
+/// operation. In certain specific cases, like terminators, we do not want to
+/// propagate.
+enum VectorizationHookStatus {
   /// Op failed to vectorize.
   Failure = 0,
   /// Op vectorized and custom function took care of replacement logic
@@ -563,9 +564,12 @@ enum VectorizationStatus {
   // TODO: support values if Op vectorized to Many-Ops whose results we need to
   // aggregate for replacement.
 };
+/// VectorizationHookResult contains the vectorized op returned from a
+/// CustomVectorizationHook. This is an internal implementation detail of
+/// linalg vectorization, not to be confused with VectorizationResult.
 struct VectorizationHookResult {
   /// Return status from vectorizing the current op.
-  enum VectorizationStatus status = VectorizationStatus::Failure;
+  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
   /// New vectorized operation to replace the current op.
   /// Replacement behavior is specified by `status`.
   Operation *newOp;
@@ -731,10 +735,10 @@ using CustomVectorizationHook =
 
 /// Helper function to vectorize the terminator of a `linalgOp`. New result
 /// vector values are appended to `newResults`. Return
-/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
-/// should not try to map produced operations and instead return the results
-/// using the `newResults` vector making them available to the vectorization
-/// algorithm for RAUW. This function is meant to be used as a
+/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
+/// that it should not try to map produced operations and instead return the
+/// results using the `newResults` vector making them available to the
+/// vectorization algorithm for RAUW. This function is meant to be used as a
 /// CustomVectorizationHook.
 static VectorizationHookResult
 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
@@ -742,7 +746,7 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
-    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
@@ -754,11 +758,11 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
       newResults.push_back(newResult);
   }
 
-  return VectorizationHookResult{VectorizationStatus::NoReplace, nullptr};
+  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
 }
 
 /// Helper function to vectorize the index operations of a `linalgOp`. Return
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
 static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
@@ -767,7 +771,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
-    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -781,7 +785,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
   if (dim == targetShape.size() - 1)
-    return VectorizationHookResult{VectorizationStatus::NewOp, indexSteps};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
@@ -799,7 +803,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
   std::swap(transposition.back(), transposition[dim]);
   auto transposeOp =
       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
-  return VectorizationHookResult{VectorizationStatus::NewOp, transposeOp};
+  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
 }
 
 /// Helper function to check if the tensor.extract can be vectorized by the
@@ -1097,7 +1101,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 }
 
 /// Helper function to vectorize the tensor.extract operations. Returns
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
 static VectorizationHookResult
@@ -1105,7 +1109,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
-    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
@@ -1137,7 +1141,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
     LDBG("Vectorised as gather load: " << extractOp << "\n");
-    return VectorizationHookResult{VectorizationStatus::NewOp, gatherOp};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
   }
 
   // 2. Handle:
@@ -1201,7 +1205,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
 
     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
-    return VectorizationHookResult{VectorizationStatus::NewOp, maskedReadOp};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp,
+                                   maskedReadOp};
   }
 
   // 2b. Handle contiguous access.
@@ -1227,7 +1232,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       inBounds);
 
   LDBG("Vectorised as contiguous load: " << extractOp);
-  return VectorizationHookResult{VectorizationStatus::NewOp, transferReadOp};
+  return VectorizationHookResult{VectorizationHookStatus::NewOp,
+                                 transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
@@ -1267,8 +1273,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 /// This function assumes all operands of `op` have been vectorized and are in
 /// the `bvm` mapping. As a consequence, this function is meant to be called  on
 /// a topologically-sorted list of ops.
-/// This function does not update `bvm` but returns a VectorizationStatus that
-/// instructs the caller what `bvm` update needs to occur.
+/// This function does not update `bvm` but returns a VectorizationHookStatus
+/// that instructs the caller what `bvm` update needs to occur.
 static VectorizationHookResult
 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
@@ -1279,7 +1285,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   if (!customVectorizationHooks.empty()) {
     for (auto &customFunc : customVectorizationHooks) {
       VectorizationHookResult result = customFunc(op, bvm);
-      if (result.status == VectorizationStatus::Failure)
+      if (result.status == VectorizationHookStatus::Failure)
         continue;
       return result;
     }
@@ -1288,12 +1294,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
   // Clone so that the constant is not confined to the linalgOp block .
   if (isa<arith::ConstantOp, func::ConstantOp>(op))
-    return VectorizationHookResult{VectorizationStatus::NewOp,
+    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                    rewriter.clone(*op)};
 
   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
   if (!OpTrait::hasElementwiseMappableTraits(op))
-    return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
 
   // 4 . Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1316,7 +1322,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                        reductionOperands[0].second, bvm);
     if (reduceOp)
-      return VectorizationHookResult{VectorizationStatus::NewOp, reduceOp};
+      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
   }
 
   // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1357,7 +1363,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   }
   //   d. Build and return the new op.
   return VectorizationHookResult{
-      VectorizationStatus::NewOp,
+      VectorizationHookStatus::NewOp,
       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                       resultTypes, op->getAttrs())};
 }
@@ -1484,11 +1490,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   for (Operation &op : block->getOperations()) {
     VectorizationHookResult result =
         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
-    if (result.status == VectorizationStatus::Failure) {
+    if (result.status == VectorizationHookStatus::Failure) {
       LDBG("failed to vectorize: " << op << "\n");
       return failure();
     }
-    if (result.status == VectorizationStatus::NewOp) {
+    if (result.status == VectorizationHookStatus::NewOp) {
       Operation *maybeMaskedOp =
           state.maskOperation(rewriter, result.newOp, linalgOp);
       LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2609,7 +2615,7 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     return failure();
   }
 
-  return VectorizationResult({results});
+  return VectorizationResult{results};
 }
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
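
For readers following the rename: `VectorizationHookResult`/`VectorizationHookStatus` remain the internal per-op protocol used by `CustomVectorizationHook` inside Vectorization.cpp, while `VectorizationResult` is the new public return type of `linalg::vectorize`. A minimal sketch of a hook written against the renamed types, following the lambdas registered in `vectorizeAsLinalgGeneric`; `MyOp` is a placeholder and the clone is a stand-in for building the real vector form:

  CustomVectorizationHook vectorizeMyOp =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    auto myOp = dyn_cast<MyOp>(op);
    if (!myOp)
      return VectorizationHookResult{VectorizationHookStatus::Failure,
                                     nullptr};
    // Illustration only: clone the op outside the linalgOp block, as the
    // driver does for constants; a real hook would build a vector form from
    // the operands already mapped in `bvm`.
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};
  };
  hooks.push_back(vectorizeMyOp);
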


