[Mlir-commits] [mlir] 96891f0 - Reland: [mlir][Vector][Affine] Improve affine vectorizer algorithm

Diego Caballero llvmlistbot at llvm.org
Thu Mar 11 14:21:06 PST 2021


Author: Diego Caballero
Date: 2021-03-12T00:19:50+02:00
New Revision: 96891f041850186dc77b0f7e740b93da43d65bd4

URL: https://github.com/llvm/llvm-project/commit/96891f041850186dc77b0f7e740b93da43d65bd4
DIFF: https://github.com/llvm/llvm-project/commit/96891f041850186dc77b0f7e740b93da43d65bd4.diff

LOG: Reland: [mlir][Vector][Affine] Improve affine vectorizer algorithm

This patch replaces the root-terminal vectorization approach implemented in the
Affine vectorizer with a topological order approach that vectorizes all the
operations within the target loop nest. These are the most important changes
introduced by the new algorithm:
  * Removed tracking of root and terminal ops. Existing vectorization
    functionality is preserved and extended so that loop nests without
    root-terminal chains can be vectorized.
  * Vectorizing a loop nest now only requires a single topological traversal.
  * A new vector loop nest is incrementally built along the vectorization
    process. The original scalar loop is kept intact. No cloning guard is needed
    to recover the scalar loop if vectorization fails. This approach also
    simplifies the challenging task of replacing a loop operation amid the
    vectorization process without invalidating the analysis information that
    depends on the original loop.
  * Vectorization of specific operations has been implemented as independent
    functions, preparing them to be moved to a potential vectorization interface.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D97442

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 666603250f0a..56f8f6211ccc 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -19,6 +19,7 @@ namespace mlir {
 class AffineApplyOp;
 class AffineForOp;
 class AffineMap;
+class Block;
 class Location;
 class OpBuilder;
 class Operation;
@@ -98,8 +99,10 @@ Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
 /// Note that loopToVectorDim is a whole function map from which only enclosing
 /// loop information is extracted.
 ///
-/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at
-/// most one invariant index along each AffineForOp of `loopToVectorDim`).
+/// Prerequisites: `indices` belong to a vectorizable load or store operation
+/// (i.e. at most one invariant index along each AffineForOp of
+/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
+/// load or store operation.
 ///
 /// Example 1:
 /// The following MLIR snippet:
@@ -151,7 +154,10 @@ Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
 /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
 ///
 AffineMap
-makePermutationMap(Operation *op, ArrayRef<Value> indices,
+makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
+                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
+AffineMap
+makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
                    const DenseMap<Operation *, unsigned> &loopToVectorDim);
 
 /// Build the default minor identity map suitable for a vector transfer. This

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index b3ebc342d5a4..eaf7da48e304 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -14,28 +14,13 @@
 #include "PassDetail.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/NestedMatcher.h"
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/FoldUtils.h"
-
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace vector;
@@ -252,61 +237,38 @@ using namespace vector;
 ///     fastest varying) ;
 ///  2. analyzing those patterns for profitability (TODO: and
 ///     interference);
-///  3. Then, for each pattern in order:
-///    a. applying iterative rewriting of the loop and the load operations in
-///       inner-to-outer order. Rewriting is implemented by coarsening the loops
-///       and turning load operations into opaque vector.transfer_read ops;
-///    b. keeping track of the load operations encountered as "roots" and the
-///       store operations as "terminals";
-///    c. traversing the use-def chains starting from the roots and iteratively
-///       propagating vectorized values. Scalar values that are encountered
-///       during this process must come from outside the scope of the current
-///       pattern (TODO: enforce this and generalize). Such a scalar value
-///       is vectorized only if it is a constant (into a vector splat). The
-///       non-constant case is not supported for now and results in the pattern
-///       failing to vectorize;
-///    d. performing a second traversal on the terminals (store ops) to
-///       rewriting the scalar value they write to memory into vector form.
-///       If the scalar value has been vectorized previously, we simply replace
-///       it by its vector form. Otherwise, if the scalar value is a constant,
-///       it is vectorized into a splat. In all other cases, vectorization for
-///       the pattern currently fails.
-///    e. if everything under the root AffineForOp in the current pattern
-///       vectorizes properly, we commit that loop to the IR. Otherwise we
-///       discard it and restore a previously cloned version of the loop. Thanks
-///       to the recursive scoping nature of matchers and captured patterns,
-///       this is transparently achieved by a simple RAII implementation.
-///    f. vectorization is applied on the next pattern in the list. Because
+///  3. then, for each pattern in order:
+///    a. applying iterative rewriting of the loops and all their nested
+///       operations in topological order. Rewriting is implemented by
+///       coarsening the loops and converting operations and operands to their
+///       vector forms. Processing operations in topological order is relatively
+///       simple due to the structured nature of the control-flow
+///       representation. This order ensures that all the operands of a given
+///       operation have been vectorized before the operation itself in a single
+///       traversal, except for operands defined outside of the loop nest. The
+///       algorithm can convert the following operations to their vector form:
+///         * Affine load and store operations are converted to opaque vector
+///           transfer read and write operations.
+///         * Scalar constant operations/operands are converted to vector
+///           constant operations (splat).
+///         * Uniform operands (only operands defined outside of the loop nest,
+///           for now) are broadcasted to a vector.
+///           TODO: Support more uniform cases.
+///         * The remaining operations in the loop nest are vectorized by
+///           widening their scalar types to vector types.
+///         * TODO: Add vectorization support for loops with 'iter_args' and
+///           more complex loops with divergent lbs and/or ubs.
+///    b. if everything under the root AffineForOp in the current pattern
+///       is vectorized properly, we commit that loop to the IR and remove the
+///       scalar loop. Otherwise, we discard the vectorized loop and keep the
+///       original scalar loop.
+///    c. vectorization is applied on the next pattern in the list. Because
 ///       pattern interference avoidance is not yet implemented and that we do
 ///       not support further vectorizing an already vector load we need to
 ///       re-verify that the pattern is still vectorizable. This is expected to
 ///       make cost models more difficult to write and is subject to improvement
 ///       in the future.
 ///
-/// Points c. and d. above are worth additional comment. In most passes that
-/// do not change the type of operands, it is usually preferred to eagerly
-/// `replaceAllUsesWith`. Unfortunately this does not work for vectorization
-/// because during the use-def chain traversal, all the operands of an operation
-/// must be available in vector form. Trying to propagate eagerly makes the IR
-/// temporarily invalid and results in errors such as:
-///   `vectorize.mlir:308:13: error: 'addf' op requires the same type for all
-///   operands and results
-///      %s5 = addf %a5, %b5 : f32`
-///
-/// Lastly, we show a minimal example for which use-def chains rooted in load /
-/// vector.transfer_read are not enough. This is what motivated splitting
-/// terminal processing out of the use-def chains starting from loads. In the
-/// following snippet, there is simply no load::
-/// ```mlir
-/// func @fill(%A : memref<128xf32>) -> () {
-///   %f1 = constant 1.0 : f32
-///   affine.for %i0 = 0 to 32 {
-///     affine.store %f1, %A[%i0] : memref<128xf32, 0>
-///   }
-///   return
-/// }
-/// ```
-///
 /// Choice of loop transformation to support the algorithm:
 /// =======================================================
 /// The choice of loop transformation to apply for coarsening vectorized loops
@@ -527,7 +489,6 @@ using namespace vector;
 #define DEBUG_TYPE "early-vect"
 
 using llvm::dbgs;
-using llvm::SetVector;
 
 /// Forward declaration.
 static FilterFunctionType
@@ -632,199 +593,196 @@ static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches,
 namespace {
 
 struct VectorizationState {
-  /// Adds an entry of pre/post vectorization operations in the state.
-  void registerReplacement(Operation *key, Operation *value);
-  /// When the current vectorization pattern is successful, this erases the
-  /// operations that were marked for erasure in the proper order and resets
-  /// the internal state for the next pattern.
-  void finishVectorizationPattern();
-
-  // In-order tracking of original Operation that have been vectorized.
-  // Erase in reverse order.
-  SmallVector<Operation *, 16> toErase;
-  // Set of Operation that have been vectorized (the values in the
-  // vectorizationMap for hashed access). The vectorizedSet is used in
-  // particular to filter the operations that have already been vectorized by
-  // this pattern, when iterating over nested loops in this pattern.
-  DenseSet<Operation *> vectorizedSet;
-  // Map of old scalar Operation to new vectorized Operation.
-  DenseMap<Operation *, Operation *> vectorizationMap;
-  // Map of old scalar Value to new vectorized Value.
-  DenseMap<Value, Value> replacementMap;
+
+  VectorizationState(MLIRContext *context) : builder(context) {}
+
+  /// Registers the vector replacement of a scalar operation and its result
+  /// values. Both operations must have the same number of results.
+  ///
+  /// This utility is used to register the replacement for the vast majority of
+  /// the vectorized operations.
+  ///
+  /// Example:
+  ///   * 'replaced': %0 = addf %1, %2 : f32
+  ///   * 'replacement': %0 = addf %1, %2 : vector<128xf32>
+  void registerOpVectorReplacement(Operation *replaced, Operation *replacement);
+
+  /// Registers the vector replacement of a scalar value. The replacement
+  /// operation should have a single result, which replaces the scalar value.
+  ///
+  /// This utility is used to register the vector replacement of block arguments
+  /// and operation results which are not directly vectorized (i.e., their
+  /// scalar version still exists after vectorization), like uniforms.
+  ///
+  /// Example:
+  ///   * 'replaced': block argument or operation outside of the vectorized
+  ///     loop.
+  ///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
+  void registerValueVectorReplacement(Value replaced, Operation *replacement);
+
+  /// Registers the scalar replacement of a scalar value. 'replacement' must be
+  /// scalar. Both values must be block arguments. Operation results should be
+  /// replaced using the 'registerOp*' utilities.
+  ///
+  /// This utility is used to register the replacement of block arguments
+  /// that are within the loop to be vectorized and will continue being scalar
+  /// within the vector loop.
+  ///
+  /// Example:
+  ///   * 'replaced': induction variable of a loop to be vectorized.
+  ///   * 'replacement': new induction variable in the new vector loop.
+  void registerValueScalarReplacement(BlockArgument replaced,
+                                      BlockArgument replacement);
+
+  /// Returns in 'replacedVals' the scalar replacement for values in
+  /// 'inputVals'.
+  void getScalarValueReplacementsFor(ValueRange inputVals,
+                                     SmallVectorImpl<Value> &replacedVals);
+
+  /// Erases the scalar loop nest after its successful vectorization.
+  void finishVectorizationPattern(AffineForOp rootLoop);
+
+  // Used to build and insert all the new operations created. The insertion
+  // point is preserved and updated along the vectorization process.
+  OpBuilder builder;
+
+  // Maps input scalar operations to their vector counterparts.
+  DenseMap<Operation *, Operation *> opVectorReplacement;
+  // Maps input scalar values to their vector counterparts.
+  BlockAndValueMapping valueVectorReplacement;
+  // Maps input scalar values to their new scalar counterparts in the vector
+  // loop nest.
+  BlockAndValueMapping valueScalarReplacement;
+
+  // Maps the newly created vector loops to their vector dimension.
+  DenseMap<Operation *, unsigned> vecLoopToVecDim;
+
   // The strategy drives which loop to vectorize by which amount.
   const VectorizationStrategy *strategy;
-  // Use-def roots. These represent the starting points for the worklist in the
-  // vectorizeNonTerminals function. They consist of the subset of load
-  // operations that have been vectorized. They can be retrieved from
-  // `vectorizationMap` but it is convenient to keep track of them in a separate
-  // data structure.
-  DenseSet<Operation *> roots;
-  // Terminal operations for the worklist in the vectorizeNonTerminals
-  // function. They consist of the subset of store operations that have been
-  // vectorized. They can be retrieved from `vectorizationMap` but it is
-  // convenient to keep track of them in a separate data structure. Since they
-  // do not necessarily belong to use-def chains starting from loads (e.g
-  // storing a constant), we need to handle them in a post-pass.
-  DenseSet<Operation *> terminals;
-  // Checks that the type of `op` is AffineStoreOp and adds it to the terminals
-  // set.
-  void registerTerminal(Operation *op);
-  // Folder used to factor out constant creation.
-  OperationFolder *folder;
 
 private:
-  void registerReplacement(Value key, Value value);
+  /// Internal implementation to map input scalar values to new vector or scalar
+  /// values.
+  void registerValueVectorReplacementImpl(Value replaced, Value replacement);
+  void registerValueScalarReplacementImpl(Value replaced, Value replacement);
 };
 
 } // end namespace
 
-void VectorizationState::registerReplacement(Operation *key, Operation *value) {
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
-  LLVM_DEBUG(key->print(dbgs()));
-  LLVM_DEBUG(dbgs() << "  into  ");
-  LLVM_DEBUG(value->print(dbgs()));
-  assert(key->getNumResults() == 1 && "already registered");
-  assert(value->getNumResults() == 1 && "already registered");
-  assert(vectorizedSet.count(value) == 0 && "already registered");
-  assert(vectorizationMap.count(key) == 0 && "already registered");
-  toErase.push_back(key);
-  vectorizedSet.insert(value);
-  vectorizationMap.insert(std::make_pair(key, value));
-  registerReplacement(key->getResult(0), value->getResult(0));
-  if (isa<AffineLoadOp>(key)) {
-    assert(roots.count(key) == 0 && "root was already inserted previously");
-    roots.insert(key);
-  }
+/// Registers the vector replacement of a scalar operation and its result
+/// values. Both operations must have the same number of results.
+///
+/// This utility is used to register the replacement for the vast majority of
+/// the vectorized operations.
+///
+/// Example:
+///   * 'replaced': %0 = addf %1, %2 : f32
+///   * 'replacement': %0 = addf %1, %2 : vector<128xf32>
+void VectorizationState::registerOpVectorReplacement(Operation *replaced,
+                                                     Operation *replacement) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n");
+  LLVM_DEBUG(dbgs() << *replaced << "\n");
+  LLVM_DEBUG(dbgs() << "into\n");
+  LLVM_DEBUG(dbgs() << *replacement << "\n");
+
+  assert(replaced->getNumResults() <= 1 && "Unsupported multi-result op");
+  assert(replaced->getNumResults() == replacement->getNumResults() &&
+         "Unexpected replaced and replacement results");
+  assert(opVectorReplacement.count(replaced) == 0 && "already registered");
+  opVectorReplacement[replaced] = replacement;
+
+  if (replaced->getNumResults() > 0)
+    registerValueVectorReplacementImpl(replaced->getResult(0),
+                                       replacement->getResult(0));
 }
 
-void VectorizationState::registerTerminal(Operation *op) {
-  assert(isa<AffineStoreOp>(op) && "terminal must be a AffineStoreOp");
-  assert(terminals.count(op) == 0 &&
-         "terminal was already inserted previously");
-  terminals.insert(op);
+/// Registers the vector replacement of a scalar value. The replacement
+/// operation should have a single result, which replaces the scalar value.
+///
+/// This utility is used to register the vector replacement of block arguments
+/// and operation results which are not directly vectorized (i.e., their
+/// scalar version still exists after vectorization), like uniforms.
+///
+/// Example:
+///   * 'replaced': block argument or operation outside of the vectorized loop.
+///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
+void VectorizationState::registerValueVectorReplacement(
+    Value replaced, Operation *replacement) {
+  assert(replacement->getNumResults() == 1 &&
+         "Expected single-result replacement");
+  if (Operation *defOp = replaced.getDefiningOp())
+    registerOpVectorReplacement(defOp, replacement);
+  else
+    registerValueVectorReplacementImpl(replaced, replacement->getResult(0));
 }
 
-void VectorizationState::finishVectorizationPattern() {
-  while (!toErase.empty()) {
-    auto *op = toErase.pop_back_val();
-    LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
-    LLVM_DEBUG(op->print(dbgs()));
-    op->erase();
-  }
+void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
+                                                            Value replacement) {
+  assert(!valueVectorReplacement.contains(replaced) &&
+         "Vector replacement already registered");
+  assert(replacement.getType().isa<VectorType>() &&
+         "Expected vector type in vector replacement");
+  valueVectorReplacement.map(replaced, replacement);
 }
 
-void VectorizationState::registerReplacement(Value key, Value value) {
-  assert(replacementMap.count(key) == 0 && "replacement already registered");
-  replacementMap.insert(std::make_pair(key, value));
+/// Registers the scalar replacement of a scalar value. 'replacement' must be
+/// scalar. Both values must be block arguments. Operation results should be
+/// replaced using the 'registerOp*' utilities.
+///
+/// This utility is used to register the replacement of block arguments
+/// that are within the loop to be vectorized and will continue being scalar
+/// within the vector loop.
+///
+/// Example:
+///   * 'replaced': induction variable of a loop to be vectorized.
+///   * 'replacement': new induction variable in the new vector loop.
+void VectorizationState::registerValueScalarReplacement(
+    BlockArgument replaced, BlockArgument replacement) {
+  registerValueScalarReplacementImpl(replaced, replacement);
+}
+
+void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
+                                                            Value replacement) {
+  assert(!valueScalarReplacement.contains(replaced) &&
+         "Scalar value replacement already registered");
+  assert(!replacement.getType().isa<VectorType>() &&
+         "Expected scalar type in scalar replacement");
+  valueScalarReplacement.map(replaced, replacement);
+}
+
+/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
+void VectorizationState::getScalarValueReplacementsFor(
+    ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
+  for (Value inputVal : inputVals)
+    replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal));
+}
+
+/// Erases a loop nest, including all its nested operations.
+static void eraseLoopNest(AffineForOp forOp) {
+  LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n");
+  forOp.erase();
+}
+
+/// Erases the scalar loop nest after its successful vectorization.
+void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n");
+  eraseLoopNest(rootLoop);
 }
 
 // Apply 'map' with 'mapOperands' returning resulting values in 'results'.
 static void computeMemoryOpIndices(Operation *op, AffineMap map,
                                    ValueRange mapOperands,
+                                   VectorizationState &state,
                                    SmallVectorImpl<Value> &results) {
-  OpBuilder builder(op);
   for (auto resultExpr : map.getResults()) {
     auto singleResMap =
         AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
-    auto afOp =
-        builder.create<AffineApplyOp>(op->getLoc(), singleResMap, mapOperands);
+    auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                    mapOperands);
     results.push_back(afOp);
   }
 }
 
-////// TODO: Hoist to a VectorizationMaterialize.cpp when appropriate. ////
-
-/// Handles the vectorization of load and store MLIR operations.
-///
-/// AffineLoadOp operations are the roots of the vectorizeNonTerminals call.
-/// They are vectorized immediately. The resulting vector.transfer_read is
-/// immediately registered to replace all uses of the AffineLoadOp in this
-/// pattern's scope.
-///
-/// AffineStoreOp are the terminals of the vectorizeNonTerminals call. They
-/// need to be vectorized late once all the use-def chains have been traversed.
-/// Additionally, they may have ssa-values operands which come from outside the
-/// scope of the current pattern.
-/// Such special cases force us to delay the vectorization of the stores until
-/// the last step. Here we merely register the store operation.
-template <typename LoadOrStoreOpPointer>
-static LogicalResult vectorizeRootOrTerminal(Value iv,
-                                             LoadOrStoreOpPointer memoryOp,
-                                             VectorizationState *state) {
-  auto memRefType = memoryOp.getMemRef().getType().template cast<MemRefType>();
-
-  auto elementType = memRefType.getElementType();
-  // TODO: ponder whether we want to further vectorize a vector value.
-  assert(VectorType::isValidElementType(elementType) &&
-         "Not a valid vector element type");
-  auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
-
-  // Materialize a MemRef with 1 vector.
-  auto *opInst = memoryOp.getOperation();
-  // For now, vector.transfers must be aligned, operate only on indices with an
-  // identity subset of AffineMap and do not change layout.
-  // TODO: increase the expressiveness power of vector.transfer operations
-  // as needed by various targets.
-  if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
-    OpBuilder b(opInst);
-    ValueRange mapOperands = load.getMapOperands();
-    SmallVector<Value, 8> indices;
-    indices.reserve(load.getMemRefType().getRank());
-    if (load.getAffineMap() !=
-        b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
-      computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
-    } else {
-      indices.append(mapOperands.begin(), mapOperands.end());
-    }
-    auto permutationMap =
-        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
-    if (!permutationMap)
-      return failure();
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
-    LLVM_DEBUG(permutationMap.print(dbgs()));
-    auto transfer = b.create<vector::TransferReadOp>(
-        opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices,
-        permutationMap);
-    state->registerReplacement(opInst, transfer.getOperation());
-  } else {
-    state->registerTerminal(opInst);
-  }
-  return success();
-}
-/// end TODO: Hoist to a VectorizationMaterialize.cpp when appropriate. ///
-
-/// Coarsens the loops bounds and transforms all remaining load and store
-/// operations into the appropriate vector.transfer.
-static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
-                                          VectorizationState *state) {
-  loop.setStep(step);
-
-  FilterFunctionType notVectorizedThisPattern = [state](Operation &op) {
-    if (!matcher::isLoadOrStore(op)) {
-      return false;
-    }
-    return state->vectorizationMap.count(&op) == 0 &&
-           state->vectorizedSet.count(&op) == 0 &&
-           state->roots.count(&op) == 0 && state->terminals.count(&op) == 0;
-  };
-  auto loadAndStores = matcher::Op(notVectorizedThisPattern);
-  SmallVector<NestedMatch, 8> loadAndStoresMatches;
-  loadAndStores.match(loop.getOperation(), &loadAndStoresMatches);
-  for (auto ls : loadAndStoresMatches) {
-    auto *opInst = ls.getMatchedOperation();
-    auto load = dyn_cast<AffineLoadOp>(opInst);
-    auto store = dyn_cast<AffineStoreOp>(opInst);
-    LLVM_DEBUG(opInst->print(dbgs()));
-    LogicalResult result =
-        load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state)
-             : vectorizeRootOrTerminal(loop.getInductionVar(), store, state);
-    if (failed(result)) {
-      return failure();
-    }
-  }
-  return success();
-}
-
 /// Returns a FilterFunctionType that can be used in NestedPattern to match a
 /// loop whose underlying load/store accesses are either invariant or all
 // varying along the `fastestVaryingMemRefDimension`.
@@ -846,68 +804,6 @@ isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
   };
 }
 
-/// Apply vectorization of `loop` according to `state`. `loops` are processed in
-/// inner-to-outer order to ensure that all the children loops have already been
-/// vectorized before vectorizing the parent loop.
-static LogicalResult
-vectorizeLoopsAndLoads(std::vector<SmallVector<AffineForOp, 2>> &loops,
-                       VectorizationState *state) {
-  // Vectorize loops in inner-to-outer order. If any children fails, the parent
-  // will fail too.
-  for (auto &loopsInLevel : llvm::reverse(loops)) {
-    for (AffineForOp loop : loopsInLevel) {
-      // 1. This loop may have been omitted from vectorization for various
-      // reasons (e.g. due to the performance model or pattern depth > vector
-      // size).
-      auto it = state->strategy->loopToVectorDim.find(loop.getOperation());
-      if (it == state->strategy->loopToVectorDim.end())
-        continue;
-
-      // 2. Actual inner-to-outer transformation.
-      auto vectorDim = it->second;
-      assert(vectorDim < state->strategy->vectorSizes.size() &&
-             "vector dim overflow");
-      //   a. get actual vector size
-      auto vectorSize = state->strategy->vectorSizes[vectorDim];
-      //   b. loop transformation for early vectorization is still subject to
-      //     exploratory tradeoffs (see top of the file). Apply coarsening,
-      //     i.e.:
-      //        | ub -> ub
-      //        | step -> step * vectorSize
-      LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
-                        << " : \n"
-                        << loop);
-      if (failed(
-              vectorizeAffineForOp(loop, loop.getStep() * vectorSize, state)))
-        return failure();
-    } // end for.
-  }
-
-  return success();
-}
-
-/// Tries to transform a scalar constant into a vector splat of that constant.
-/// Returns the vectorized splat operation if the constant is a valid vector
-/// element type.
-/// If `type` is not a valid vector type or if the scalar constant is not a
-/// valid vector element type, returns nullptr.
-static Value vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
-  if (!type || !type.isa<VectorType>() ||
-      !VectorType::isValidElementType(constant.getType())) {
-    return nullptr;
-  }
-  OpBuilder b(op);
-  Location loc = op->getLoc();
-  auto vectorType = type.cast<VectorType>();
-  auto attr = DenseElementsAttr::get(vectorType, constant.getValue());
-  auto *constantOpInst = constant.getOperation();
-
-  OperationState state(loc, constantOpInst->getName().getStringRef(), {},
-                       {vectorType}, {b.getNamedAttr("value", attr)});
-
-  return b.createOperation(state)->getResult(0);
-}
-
 /// Returns the vector type resulting from applying the provided vectorization
 /// strategy on the scalar type.
 static VectorType getVectorType(Type scalarTy,
@@ -916,6 +812,24 @@ static VectorType getVectorType(Type scalarTy,
   return VectorType::get(strategy->vectorSizes, scalarTy);
 }
 
+/// Tries to transform a scalar constant into a vector constant. Returns the
+/// vector constant if the scalar type is valid vector element type. Returns
+/// nullptr, otherwise.
+static ConstantOp vectorizeConstant(ConstantOp constOp,
+                                    VectorizationState &state) {
+  Type scalarTy = constOp.getType();
+  if (!VectorType::isValidElementType(scalarTy))
+    return nullptr;
+
+  auto vecTy = getVectorType(scalarTy, state.strategy);
+  auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue());
+  auto newConstOp = state.builder.create<ConstantOp>(constOp.getLoc(), vecAttr);
+
+  // Register vector replacement for future uses in the scope.
+  state.registerOpVectorReplacement(constOp, newConstOp);
+  return newConstOp;
+}
+
 /// Returns true if the provided value is vector uniform given the vectorization
 /// strategy.
 // TODO: For now, only values that are invariants to all the loops in the
@@ -932,32 +846,35 @@ static bool isUniformDefinition(Value value,
 
 /// Generates a broadcast op for the provided uniform value using the
 /// vectorization strategy in 'state'.
-static Value vectorizeUniform(Value value, VectorizationState *state) {
-  OpBuilder builder(value.getContext());
-  builder.setInsertionPointAfterValue(value);
-
-  auto vectorTy = getVectorType(value.getType(), state->strategy);
-  auto bcast = builder.create<BroadcastOp>(value.getLoc(), vectorTy, value);
-
-  // Add broadcast to the replacement map to reuse it for other uses.
-  state->replacementMap[value] = bcast;
-  return bcast;
+static Operation *vectorizeUniform(Value uniformVal,
+                                   VectorizationState &state) {
+  OpBuilder::InsertionGuard guard(state.builder);
+  // Broadcast the scalar replacement of `uniformVal` (e.g., the new iv of an
+  // enclosing loop that is not a vector dimension) so the broadcast never
+  // refers to the scalar loop nest, which is erased once vectorization
+  // succeeds. Values without a scalar replacement are expected to be returned
+  // unchanged, as the handling of load/store map operands already assumes.
+  SmallVector<Value, 8> scalarRepl;
+  state.getScalarValueReplacementsFor(uniformVal, scalarRepl);
+  Value uniformScalarRepl = scalarRepl.front();
+  state.builder.setInsertionPointAfterValue(uniformScalarRepl);
+
+  auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
+  auto bcastOp = state.builder.create<BroadcastOp>(
+      uniformVal.getLoc(), vectorTy, uniformScalarRepl);
+  state.registerValueVectorReplacement(uniformVal, bcastOp);
+  return bcastOp;
 }
 
-/// Tries to vectorize a given operand `op` of Operation `op` during
-/// def-chain propagation or during terminal vectorization, by applying the
-/// following logic:
-/// 1. if the defining operation is part of the vectorizedSet (i.e. vectorized
-///    useby -def propagation), `op` is already in the proper vector form;
-/// 2. otherwise, the `op` may be in some other vector form that fails to
-///    vectorize atm (i.e. broadcasting required), returns nullptr to indicate
-///    failure;
-/// 3. if the `op` is a constant, returns the vectorized form of the constant;
-/// 4. if the `op` is uniform, returns a vector broadcast of the `op`;
-/// 5. non-constant scalars are currently non-vectorizable, in particular to
-///    guard against vectorizing an index which may be loop-variant and needs
-///    special handling.
-///
+/// Tries to vectorize a given `operand` by applying the following logic:
+/// 1. if the defining operation has been already vectorized, `operand` is
+///    already in the proper vector form;
+/// 2. if the `operand` is a constant, returns the vectorized form of the
+///    constant;
+/// 3. if the `operand` is uniform, returns a vector broadcast of it;
+/// 4. otherwise, the vectorization of `operand` is not supported.
+/// Newly created vector operations are registered in `state` as replacement
+/// for their scalar counterparts.
 /// In particular this logic captures some of the use cases where definitions
 /// that are not scoped under the current pattern are needed to vectorize.
 /// One such example is top level function constants that need to be splatted.
@@ -966,112 +875,217 @@ static Value vectorizeUniform(Value value, VectorizationState *state) {
 /// vectorization is possible with the above logic. Returns nullptr otherwise.
 ///
 /// TODO: handle more complex cases.
-static Value vectorizeOperand(Value operand, Operation *op,
-                              VectorizationState *state) {
-  LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: " << operand);
-  // 1. If this value has already been vectorized this round, we are done.
-  if (state->vectorizedSet.count(operand.getDefiningOp()) > 0) {
-    LLVM_DEBUG(dbgs() << " -> already vector operand");
-    return operand;
+static Value vectorizeOperand(Value operand, VectorizationState &state) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand);
+  // If this value is already vectorized, we are done.
+  if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) {
+    LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl);
+    return vecRepl;
   }
-  // 1.b. Delayed on-demand replacement of a use.
-  //    Note that we cannot just call replaceAllUsesWith because it may result
-  //    in ops with mixed types, for ops whose operands have not all yet
-  //    been vectorized. This would be invalid IR.
-  auto it = state->replacementMap.find(operand);
-  if (it != state->replacementMap.end()) {
-    auto res = it->second;
-    LLVM_DEBUG(dbgs() << "-> delayed replacement by: " << res);
-    return res;
+
+  // A vector operand that is not in the replacement map should never reach
+  // this point. Reaching this point could mean that the code was already
+  // vectorized and we shouldn't try to vectorize already vectorized code.
+  assert(!operand.getType().isa<VectorType>() &&
+         "Vector op not found in replacement map");
+
+  // Vectorize constant (fails for types that aren't valid vector elements).
+  if (auto constOp = operand.getDefiningOp<ConstantOp>()) {
+    ConstantOp vecConstant = vectorizeConstant(constOp, state);
+    if (!vecConstant) {
+      LLVM_DEBUG(dbgs() << "-> non-vectorizable constant\n");
+      return nullptr;
+    }
+    LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant);
+    return vecConstant.getResult();
   }
-  // 2. TODO: broadcast needed.
-  if (operand.getType().isa<VectorType>()) {
-    LLVM_DEBUG(dbgs() << "-> non-vectorizable");
-    return nullptr;
+
+  // Vectorize uniform values.
+  if (isUniformDefinition(operand, state.strategy)) {
+    Operation *vecUniform = vectorizeUniform(operand, state);
+    LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform);
+    return vecUniform->getResult(0);
   }
-  // 3. vectorize constant.
-  if (auto constant = operand.getDefiningOp<ConstantOp>())
-    return vectorizeConstant(op, constant,
-                             getVectorType(operand.getType(), state->strategy));
 
-  // 4. Uniform values.
-  if (isUniformDefinition(operand, state->strategy))
-    return vectorizeUniform(operand, state);
+  // Check for unsupported block argument scenarios. A supported block argument
+  // should have been vectorized already.
+  if (!operand.getDefiningOp())
+    LLVM_DEBUG(dbgs() << "-> unsupported block argument\n");
+  else
+    // Generic unsupported case.
+    LLVM_DEBUG(dbgs() << "-> non-vectorizable\n");
 
-  // 5. currently non-vectorizable.
-  LLVM_DEBUG(dbgs() << "-> non-vectorizable: " << operand);
   return nullptr;
 }
 
-/// Encodes Operation-specific behavior for vectorization. In general we assume
-/// that all operands of an op must be vectorized but this is not always true.
-/// In the future, it would be nice to have a trait that describes how a
-/// particular operation vectorizes. For now we implement the case distinction
-/// here.
-/// Returns a vectorized form of an operation or nullptr if vectorization fails.
-// TODO: consider adding a trait to Op to describe how it gets vectorized.
-// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
-// do one-off logic here; ideally it would be TableGen'd.
-static Operation *vectorizeOneOperation(Operation *opInst,
-                                        VectorizationState *state) {
-  // Sanity checks.
-  assert(!isa<AffineLoadOp>(opInst) &&
-         "all loads must have already been fully vectorized independently");
-  assert(!isa<vector::TransferReadOp>(opInst) &&
-         "vector.transfer_read cannot be further vectorized");
-  assert(!isa<vector::TransferWriteOp>(opInst) &&
-         "vector.transfer_write cannot be further vectorized");
+/// Vectorizes an affine load with the vectorization strategy in 'state' by
+/// generating a 'vector.transfer_read' op with the proper permutation map
+/// inferred from the indices of the load. The new 'vector.transfer_read' is
+/// registered as replacement of the scalar load. Returns the newly created
+/// 'vector.transfer_read' if vectorization was successful. Returns nullptr,
+/// otherwise (i.e., when no permutation map can be computed).
+static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
+                                      VectorizationState &state) {
+  MemRefType memRefType = loadOp.getMemRefType();
+  Type elementType = memRefType.getElementType();
+  auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
+
+  // Replace map operands with operands from the vector loop nest.
+  SmallVector<Value, 8> mapOperands;
+  state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands);
+
+  // Compute indices for the transfer op. AffineApplyOp's may be generated.
+  SmallVector<Value, 8> indices;
+  indices.reserve(memRefType.getRank());
+  if (loadOp.getAffineMap() !=
+      state.builder.getMultiDimIdentityMap(memRefType.getRank()))
+    computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
+                           indices);
+  else
+    indices.append(mapOperands.begin(), mapOperands.end());
+
+  // Compute permutation map from vector loops enclosing the insertion point.
+  auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
+                                           indices, state.vecLoopToVecDim);
+  if (!permutationMap) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n");
+    return nullptr;
+  }
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+  LLVM_DEBUG(permutationMap.print(dbgs()));
 
-  if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
-    OpBuilder b(opInst);
-    auto memRef = store.getMemRef();
-    auto value = store.getValueToStore();
-    auto vectorValue = vectorizeOperand(value, opInst, state);
-    if (!vectorValue)
-      return nullptr;
+  auto transfer = state.builder.create<vector::TransferReadOp>(
+      loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
 
-    ValueRange mapOperands = store.getMapOperands();
-    SmallVector<Value, 8> indices;
-    indices.reserve(store.getMemRefType().getRank());
-    if (store.getAffineMap() !=
-        b.getMultiDimIdentityMap(store.getMemRefType().getRank())) {
-      computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
-                             indices);
-    } else {
-      indices.append(mapOperands.begin(), mapOperands.end());
-    }
+  // Register replacement for future uses in the scope.
+  state.registerOpVectorReplacement(loadOp, transfer);
+  return transfer;
+}
 
-    auto permutationMap =
-        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
-    if (!permutationMap)
-      return nullptr;
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
-    LLVM_DEBUG(permutationMap.print(dbgs()));
-    auto transfer = b.create<vector::TransferWriteOp>(
-        opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
-    auto *res = transfer.getOperation();
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
-    // "Terminals" (i.e. AffineStoreOps) are erased on the spot.
-    opInst->erase();
-    return res;
-  }
-  if (opInst->getNumRegions() != 0)
+/// Vectorizes an affine store with the vectorization strategy in 'state' by
+/// generating a 'vector.transfer_write' op with the proper permutation map
+/// inferred from the indices of the store. The new 'vector.transfer_write' is
+/// registered as replacement of the scalar store. Returns the newly created
+/// 'vector.transfer_write' if vectorization was successful. Returns nullptr,
+/// otherwise (i.e., the stored value or the permutation map can't be handled).
+static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
+                                       VectorizationState &state) {
+  MemRefType memRefType = storeOp.getMemRefType();
+  Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state);
+  if (!vectorValue)
+    return nullptr;
+
+  // Replace map operands with operands from the vector loop nest.
+  SmallVector<Value, 8> mapOperands;
+  state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands);
+
+  // Compute indices for the transfer op. AffineApplyOp's may be generated.
+  SmallVector<Value, 8> indices;
+  indices.reserve(memRefType.getRank());
+  if (storeOp.getAffineMap() !=
+      state.builder.getMultiDimIdentityMap(memRefType.getRank()))
+    computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state,
+                           indices);
+  else
+    indices.append(mapOperands.begin(), mapOperands.end());
+
+  // Compute permutation map using the information of new vector loops.
+  auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
+                                           indices, state.vecLoopToVecDim);
+  if (!permutationMap)
+    return nullptr;
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+  LLVM_DEBUG(permutationMap.print(dbgs()));
+
+  auto transfer = state.builder.create<vector::TransferWriteOp>(
+      storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
+      permutationMap);
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer);
+
+  // Register replacement for future uses in the scope.
+  state.registerOpVectorReplacement(storeOp, transfer);
+  return transfer;
+}
+
+/// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is
+/// created and registered as replacement for the scalar loop. The builder's
+/// insertion point is set to the new loop's body so that subsequent vectorized
+/// operations are inserted into the new loop. If the loop is a vector
+/// dimension, the step of the newly created loop will reflect the vectorization
+/// factor used to vectorize that dimension.
+// TODO: Add support for 'iter_args'. Related operands and results will be
+// vectorized at this point.
+static Operation *vectorizeAffineForOp(AffineForOp forOp,
+                                       VectorizationState &state) {
+  // 'iter_args' not supported yet.
+  if (forOp.getNumIterOperands() > 0)
     return nullptr;
 
+  // If we are vectorizing a vector dimension, compute a new step for the new
+  // vectorized loop using the vectorization factor for the vector dimension.
+  // Otherwise, propagate the step of the scalar loop.
+  const VectorizationStrategy &strategy = *state.strategy;
+  auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
+  bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
+  unsigned newStep;
+  if (isLoopVecDim) {
+    unsigned vectorDim = loopToVecDimIt->second;
+    assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow");
+    int64_t forOpVecFactor = strategy.vectorSizes[vectorDim];
+    newStep = forOp.getStep() * forOpVecFactor;
+  } else {
+    newStep = forOp.getStep();
+  }
+
+  auto vecForOp = state.builder.create<AffineForOp>(
+      forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
+      forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
+      forOp.getIterOperands(),
+      /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
+        // Make sure we don't create a default terminator in the loop body as
+        // the proper terminator will be added during vectorization.
+        return;
+      });
+
+  // Register loop-related replacements:
+  //   1) The new vectorized loop is registered as vector replacement of the
+  //      scalar loop.
+  //      TODO: Support reductions along the vector dimension.
+  //   2) The new iv of the vectorized loop is registered as scalar replacement
+  //      since a scalar copy of the iv will prevail in the vectorized loop.
+  //      TODO: A vector replacement will also be added in the future when
+  //      vectorization of linear ops is supported.
+  //   3) TODO: Support 'iter_args' along non-vector dimensions.
+  state.registerOpVectorReplacement(forOp, vecForOp);
+  state.registerValueScalarReplacement(forOp.getInductionVar(),
+                                       vecForOp.getInductionVar());
+  // Map the new vectorized loop to its vector dimension.
+  if (isLoopVecDim)
+    state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;
+
+  // Change insertion point so that upcoming vectorized instructions are
+  // inserted into the vectorized loop's body.
+  state.builder.setInsertionPointToStart(vecForOp.getBody());
+  return vecForOp;
+}
+
+/// Vectorizes an arbitrary operation by plain widening: result types become
+/// vectors of the strategy's sizes and operands are replaced by their vector
+/// counterparts. Returns nullptr if any operand cannot be vectorized.
+static Operation *widenOp(Operation *op, VectorizationState &state) {
   SmallVector<Type, 8> vectorTypes;
-  for (auto v : opInst->getResults()) {
+  for (Value result : op->getResults())
     vectorTypes.push_back(
-        VectorType::get(state->strategy->vectorSizes, v.getType()));
-  }
+        VectorType::get(state.strategy->vectorSizes, result.getType()));
+
   SmallVector<Value, 8> vectorOperands;
-  for (auto v : opInst->getOperands()) {
-    vectorOperands.push_back(vectorizeOperand(v, opInst, state));
-  }
-  // Check whether a single operand is null. If so, vectorization failed.
-  bool success = llvm::all_of(vectorOperands, [](Value op) { return op; });
-  if (!success) {
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
-    return nullptr;
+  for (Value operand : op->getOperands()) {
+    Value vecOperand = vectorizeOperand(operand, state);
+    if (!vecOperand) {
+      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n");
+      return nullptr;
+    }
+    vectorOperands.push_back(vecOperand);
   }
 
   // Create a clone of the op with the proper operands and return types.
@@ -1079,59 +1089,64 @@ static Operation *vectorizeOneOperation(Operation *opInst,
   // name that works both in scalar mode and vector mode.
   // TODO: Is it worth considering an Operation.clone operation which
   // changes the type so we can promote an Operation with less boilerplate?
-  OpBuilder b(opInst);
-  OperationState newOp(opInst->getLoc(), opInst->getName().getStringRef(),
-                       vectorOperands, vectorTypes, opInst->getAttrs(),
-                       /*successors=*/{}, /*regions=*/{});
-  return b.createOperation(newOp);
+  OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
+                            vectorOperands, vectorTypes, op->getAttrs(),
+                            /*successors=*/{}, /*regions=*/{});
+  Operation *vecOp = state.builder.createOperation(vecOpState);
+  state.registerOpVectorReplacement(op, vecOp);
+  return vecOp;
 }
 
-/// Iterates over the forward slice from the loads in the vectorization pattern
-/// and rewrites them using their vectorized counterpart by:
-///   1. Create the forward slice starting from the loads in the vectorization
-///   pattern.
-///   2. Topologically sorts the forward slice.
-///   3. For each operation in the slice, create the vector form of this
-///   operation, replacing each operand by a replacement operands retrieved from
-///   replacementMap. If any such replacement is missing, vectorization fails.
-static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
-  // 1. create initial worklist with the uses of the roots.
-  SetVector<Operation *> worklist;
-  // Note: state->roots have already been vectorized and must not be vectorized
-  // again. This fits `getForwardSlice` which does not insert `op` in the
-  // result.
-  // Note: we have to exclude terminals because some of their defs may not be
-  // nested under the vectorization pattern (e.g. constants defined in an
-  // encompassing scope).
-  // TODO: Use a backward slice for terminals, avoid special casing and
-  // merge implementations.
-  for (auto *op : state->roots) {
-    getForwardSlice(op, &worklist, [state](Operation *op) {
-      return state->terminals.count(op) == 0; // propagate if not terminal
-    });
-  }
-  // We merged multiple slices, topological order may not hold anymore.
-  worklist = topologicalSort(worklist);
-
-  for (unsigned i = 0; i < worklist.size(); ++i) {
-    auto *op = worklist[i];
-    LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
-    LLVM_DEBUG(op->print(dbgs()));
-
-    // Create vector form of the operation.
-    // Insert it just before op, on success register op as replaced.
-    auto *vectorizedInst = vectorizeOneOperation(op, state);
-    if (!vectorizedInst) {
-      return failure();
-    }
+/// Vectorizes an operand-less yield (no 'iter_args' support yet) by widening
+/// it, then sets the builder's insertion point after the vectorized parent op
+/// to continue vectorizing the operations that follow the parent op.
+static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
+                                         VectorizationState &state) {
+  // 'iter_args' not supported yet.
+  if (yieldOp.getNumOperands() > 0)
+    return nullptr;
 
-    // 3. Register replacement for future uses in the scope.
-    //    Note that we cannot just call replaceAllUsesWith because it may
-    //    result in ops with mixed types, for ops whose operands have not all
-    //    yet been vectorized. This would be invalid IR.
-    state->registerReplacement(op, vectorizedInst);
-  }
-  return success();
+  // Vectorize the yield op and change the insertion point right after the new
+  // parent op.
+  Operation *newYieldOp = widenOp(yieldOp, state);
+  Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp();
+  state.builder.setInsertionPointAfter(newParentOp);
+  return newYieldOp;
+}
+
+/// Encodes Operation-specific behavior for vectorization. Affine loads,
+/// stores, loops and yields, as well as constants, have dedicated handlers.
+/// Any other operation is vectorized by plain widening unless it has regions,
+/// which is not supported. In the future, it would be nice to have a trait
+/// that describes how a particular operation vectorizes. Returns a vectorized
+/// form of the operation or nullptr if vectorization fails.
 // TODO: consider adding a trait to Op to describe how it gets vectorized.
 // Maybe some Ops are not vectorizable or require some tricky logic, we cannot
 // do one-off logic here; ideally it would be TableGen'd.
+static Operation *vectorizeOneOperation(Operation *op,
+                                        VectorizationState &state) {
   // Sanity checks.
+  assert(!isa<vector::TransferReadOp>(op) &&
+         "vector.transfer_read cannot be further vectorized");
+  assert(!isa<vector::TransferWriteOp>(op) &&
+         "vector.transfer_write cannot be further vectorized");
+
+  if (auto loadOp = dyn_cast<AffineLoadOp>(op))
+    return vectorizeAffineLoad(loadOp, state);
+  if (auto storeOp = dyn_cast<AffineStoreOp>(op))
+    return vectorizeAffineStore(storeOp, state);
+  if (auto forOp = dyn_cast<AffineForOp>(op))
+    return vectorizeAffineForOp(forOp, state);
+  if (auto yieldOp = dyn_cast<AffineYieldOp>(op))
+    return vectorizeAffineYieldOp(yieldOp, state);
+  if (auto constant = dyn_cast<ConstantOp>(op))
+    return vectorizeConstant(constant, state);
+
+  // Other ops with regions are not supported.
+  if (op->getNumRegions() != 0)
+    return nullptr;
+
+  return widenOp(op, state);
 }
 
 /// Recursive implementation to convert all the nested loops in 'match' to a 2D
@@ -1171,10 +1186,9 @@ vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
                   const VectorizationStrategy &strategy) {
   assert(loops[0].size() == 1 && "Expected single root loop");
   AffineForOp rootLoop = loops[0][0];
-  OperationFolder folder(rootLoop.getContext());
-  VectorizationState state;
+  VectorizationState state(rootLoop.getContext());
+  state.builder.setInsertionPointAfter(rootLoop);
   state.strategy = &strategy;
-  state.folder = &folder;
 
   // Since patterns are recursive, they can very well intersect.
   // Since we do not want a fully greedy strategy in general, we decouple
@@ -1188,70 +1202,48 @@ vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
     return failure();
   }
 
-  /// Sets up error handling for this root loop. This is how the root match
-  /// maintains a clone for handling failure and restores the proper state via
-  /// RAII.
-  auto *loopInst = rootLoop.getOperation();
-  OpBuilder builder(loopInst);
-  auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
-  struct Guard {
-    LogicalResult failure() {
-      loop.getInductionVar().replaceAllUsesWith(clonedLoop.getInductionVar());
-      loop.erase();
-      return mlir::failure();
-    }
-    LogicalResult success() {
-      clonedLoop.erase();
-      return mlir::success();
-    }
-    AffineForOp loop;
-    AffineForOp clonedLoop;
-  } guard{rootLoop, clonedLoop};
-
   //////////////////////////////////////////////////////////////////////////////
-  // Start vectorizing.
-  // From now on, any error triggers the scope guard above.
+  // Vectorize the scalar loop nest following a topological order. A new vector
+  // loop nest with the vectorized operations is created along the process. If
+  // vectorization succeeds, the scalar loop nest is erased. If vectorization
+  // fails, the vector loop nest is erased and the scalar loop nest is not
+  // modified.
   //////////////////////////////////////////////////////////////////////////////
-  // 1. Vectorize all the loop candidates, in inner-to-outer order.
-  // This also vectorizes the roots (AffineLoadOp) as well as registers the
-  // terminals (AffineStoreOp) for post-processing vectorization (we need to
-  // wait for all use-def chains into them to be vectorized first).
-  if (failed(vectorizeLoopsAndLoads(loops, &state))) {
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root vectorizeLoop");
-    return guard.failure();
-  }
 
-  // 2. Vectorize operations reached by use-def chains from root except the
-  // terminals (store operations) that need to be post-processed separately.
-  // TODO: add more as we expand.
-  if (failed(vectorizeNonTerminals(&state))) {
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeNonTerminals");
-    return guard.failure();
-  }
+  auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op);
+    Operation *vectorOp = vectorizeOneOperation(op, state);
+    if (!vectorOp)
+      return WalkResult::interrupt();
 
-  // 3. Post-process terminals.
-  // Note: we have to post-process terminals because some of their defs may not
-  // be nested under the vectorization pattern (e.g. constants defined in an
-  // encompassing scope).
-  // TODO: Use a backward slice for terminals, avoid special casing and
-  // merge implementations.
-  for (auto *op : state.terminals) {
-    if (!vectorizeOneOperation(op, &state)) { // nullptr == failure
-      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminals");
-      return guard.failure();
-    }
+    return WalkResult::advance();
+  });
+
+  if (opVecResult.wasInterrupted()) {
+    LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: "
+                      << rootLoop << "\n");
+    // Erase vector loop nest if it was created.
+    auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop);
+    if (vecRootLoopIt != state.opVectorReplacement.end())
+      eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second));
+
+    return failure();
   }
 
-  // 4. Finish this vectorization pattern.
+  assert(state.opVectorReplacement.count(rootLoop) == 1 &&
+         "Expected vector replacement for loop nest");
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
-  state.finishVectorizationPattern();
-  return guard.success();
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n"
+                    << *state.opVectorReplacement[rootLoop]);
+
+  // Finish this vectorization pattern.
+  state.finishVectorizationPattern(rootLoop);
+  return success();
 }
 
-/// Vectorization is a recursive procedure where anything below can fail. The
-/// root match thus needs to maintain a clone for handling failure. Each root
-/// may succeed independently but will otherwise clean after itself if anything
-/// below it fails.
+/// Extracts the matched loops and vectorizes them following a topological
+/// order. A new vector loop nest will be created if vectorization succeeds. The
+/// original loop nest won't be modified in any case.
 static LogicalResult vectorizeRootMatch(NestedMatch m,
                                         const VectorizationStrategy &strategy) {
   std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index ef3ef3db1f81..b68886cb3e40 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -211,29 +211,31 @@ static AffineMap makePermutationMap(
 /// TODO: could also be implemented as a collect parents followed by a
 /// filter and made available outside this file.
 template <typename T>
-static SetVector<Operation *> getParentsOfType(Operation *op) {
+static SetVector<Operation *> getParentsOfType(Block *block) {
   SetVector<Operation *> res;
-  auto *current = op;
-  while (auto *parent = current->getParentOp()) {
-    if (auto typedParent = dyn_cast<T>(parent)) {
-      assert(res.count(parent) == 0 && "Already inserted");
-      res.insert(parent);
+  // `block` is null when the Operation * overload of makePermutationMap gets
+  // a detached op; such an op has no enclosing parents.
+  Operation *current = block ? block->getParentOp() : nullptr;
+  while (current) {
+    if (isa<T>(current)) {
+      assert(res.count(current) == 0 && "Already inserted");
+      res.insert(current);
     }
-    current = parent;
+    current = current->getParentOp();
   }
   return res;
 }
 
 /// Returns the enclosing AffineForOp, from closest to farthest.
-static SetVector<Operation *> getEnclosingforOps(Operation *op) {
-  return getParentsOfType<AffineForOp>(op);
+static SetVector<Operation *> getEnclosingforOps(Block *block) {
+  return getParentsOfType<AffineForOp>(block);
 }
 
 AffineMap mlir::makePermutationMap(
-    Operation *op, ArrayRef<Value> indices,
+    Block *insertPoint, ArrayRef<Value> indices,
     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
-  auto enclosingLoops = getEnclosingforOps(op);
+  auto enclosingLoops = getEnclosingforOps(insertPoint);
   for (auto *forInst : enclosingLoops) {
     auto it = loopToVectorDim.find(forInst);
     if (it != loopToVectorDim.end()) {
@@ -243,6 +243,12 @@ AffineMap mlir::makePermutationMap(
   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
+AffineMap mlir::makePermutationMap(
+    Operation *op, ArrayRef<Value> indices,
+    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+  return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
+}
+
 AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
                                             VectorType vectorType) {
   int64_t elementVectorRank = 0;

diff  --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
index 86749e2c7bab..528169e26d11 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -1,16 +1,8 @@
-// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0" | FileCheck %s
+// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0" -split-input-file | FileCheck %s
 
-// Permutation maps used in vectorization.
-// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
 // CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
 
-#map0 = affine_map<(d0) -> (d0)>
-#mapadd1 = affine_map<(d0) -> (d0 + 1)>
-#mapadd2 = affine_map<(d0) -> (d0 + 2)>
-#mapadd3 = affine_map<(d0) -> (d0 + 3)>
-#set0 = affine_set<(i) : (i >= 0)>
-
-// Maps introduced to vectorize fastest varying memory index.
 // CHECK-LABEL: func @vec1d_1
 func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -37,6 +29,8 @@ func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec1d_2
 func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -61,6 +55,8 @@ func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec1d_3
 func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -90,6 +86,8 @@ func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vector_add_2d
 func @vector_add_2d(%M : index, %N : index) -> f32 {
   %A = alloc (%M, %N) : memref<?x?xf32, 0>
@@ -142,6 +140,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 {
   return %res : f32
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_1
 func @vec_rejected_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -164,6 +164,8 @@ func @vec_rejected_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_2
 func @vec_rejected_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -186,6 +188,8 @@ func @vec_rejected_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_3
 func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -213,6 +217,8 @@ func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_4
 func @vec_rejected_4(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -238,6 +244,8 @@ func @vec_rejected_4(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_5
 func @vec_rejected_5(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -264,6 +272,8 @@ func @vec_rejected_5(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_6
 func @vec_rejected_6(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -292,6 +302,8 @@ func @vec_rejected_6(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_7
 func @vec_rejected_7(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -315,6 +327,11 @@ func @vec_rejected_7(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
+// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
+
 // CHECK-LABEL: func @vec_rejected_8
 func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -344,6 +361,11 @@ func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
+// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
+
 // CHECK-LABEL: func @vec_rejected_9
 func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -373,6 +395,10 @@ func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
+#set0 = affine_set<(i) : (i >= 0)>
+
 // CHECK-LABEL: func @vec_rejected_10
 func @vec_rejected_10(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -397,6 +423,8 @@ func @vec_rejected_10(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_11
 func @vec_rejected_11(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
   // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -424,7 +452,9 @@ func @vec_rejected_11(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
   return
 }
 
-// This should not vectorize due to the sequential dependence in the scf.
+// -----
+
+// This should not vectorize due to the sequential dependence in the loop.
 // CHECK-LABEL: @vec_rejected_sequential
 func @vec_rejected_sequential(%A : memref<?xf32>) {
   %c0 = constant 0 : index
@@ -437,3 +467,66 @@ func @vec_rejected_sequential(%A : memref<?xf32>) {
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: @vec_no_load_store_ops
+func @vec_no_load_store_ops(%a: f32, %b: f32) {
+ %cst = constant 0.000000e+00 : f32
+ // The loop body has no loads or stores: its only op is an addf whose
+ // operands are the function arguments %a and %b, defined outside the loop.
+ affine.for %i = 0 to 128 {
+   %add = addf %a, %b : f32
+ }
+ // Expected output: both out-of-loop operands become vector.broadcast ops,
+ // and the addf inside the stepped loop consumes those broadcast values.
+ // CHECK-DAG:  %[[bc1:.*]] = vector.broadcast
+ // CHECK-DAG:  %[[bc0:.*]] = vector.broadcast
+ // CHECK:      affine.for %{{.*}} = 0 to 128 step
+ // CHECK-NEXT:   [[add:.*]] addf %[[bc0]], %[[bc1]]
+
+ return
+}
+
+// -----
+
+// This should not be vectorized due to the unsupported block argument (%i).
+// Support for operands with linear evolution is needed.
+// CHECK-LABEL: @vec_rejected_unsupported_block_arg
+func @vec_rejected_unsupported_block_arg(%A : memref<512xi32>) {
+  affine.for %i = 0 to 512 {
+    // The stored value is computed directly from the induction variable %i
+    // (via index_cast), so it differs on every iteration of the loop.
+    // CHECK-NOT: vector
+    %idx = std.index_cast %i : index to i32
+    affine.store %idx, %A[%i] : memref<512xi32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vec_rejected_unsupported_reduction
+func @vec_rejected_unsupported_reduction(%in: memref<128x256xf32>, %out: memref<256xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ affine.for %i = 0 to 256 {
+   // The inner loop carries a value through iter_args: each iteration adds
+   // the loaded element to the running sum %red_iter (initialized to %cst).
+   // The outer loop's store then writes the inner loop's result.
+   // CHECK-NOT: vector
+   %final_red = affine.for %j = 0 to 128 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%j, %i] : memref<128x256xf32>
+     %add = addf %red_iter, %ld : f32
+     affine.yield %add : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @vec_rejected_unsupported_last_value
+func @vec_rejected_unsupported_last_value(%in: memref<128x256xf32>, %out: memref<256xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ affine.for %i = 0 to 256 {
+   // The inner loop yields the element loaded in the current iteration and
+   // never reads %last_iter, so its result is the value from the last
+   // iteration (or %cst if no iteration ran).
+   // CHECK-NOT: vector
+   %last_val = affine.for %j = 0 to 128 iter_args(%last_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%j, %i] : memref<128x256xf32>
+     affine.yield %ld : f32
+   }
+   affine.store %last_val, %out[%i] : memref<256xf32>
+ }
+ return
+}


        


More information about the Mlir-commits mailing list