[Mlir-commits] [mlir] 79da91c - Revert "[mlir][Vector][Affine] Improve affine vectorizer algorithm"

Alex Zinenko llvmlistbot at llvm.org
Wed Mar 10 11:26:29 PST 2021


Author: Alex Zinenko
Date: 2021-03-10T20:25:49+01:00
New Revision: 79da91c59aeea3b126ec9a0d6f403a8c0f59e4dc

URL: https://github.com/llvm/llvm-project/commit/79da91c59aeea3b126ec9a0d6f403a8c0f59e4dc
DIFF: https://github.com/llvm/llvm-project/commit/79da91c59aeea3b126ec9a0d6f403a8c0f59e4dc.diff

LOG: Revert "[mlir][Vector][Affine] Improve affine vectorizer algorithm"

This reverts commit 95db7b4aeaad590f37720898e339a6d54313422f.

This breaks the vectorize_2d.mlir and vectorize_3d.mlir tests under ASAN
(use after free).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 56f8f6211ccc..666603250f0a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -19,7 +19,6 @@ namespace mlir {
 class AffineApplyOp;
 class AffineForOp;
 class AffineMap;
-class Block;
 class Location;
 class OpBuilder;
 class Operation;
@@ -99,10 +98,8 @@ Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
 /// Note that loopToVectorDim is a whole-function map from which only enclosing
 /// loop information is extracted.
 ///
-/// Prerequisites: `indices` belong to a vectorizable load or store operation
-/// (i.e. at most one invariant index along each AffineForOp of
-/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
-/// load or store operation.
+/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at
+/// most one invariant index along each AffineForOp of `loopToVectorDim`).
 ///
 /// Example 1:
 /// The following MLIR snippet:
@@ -154,10 +151,7 @@ Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
 /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
 ///
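+///
+/// As an additional sketch (not verbatim output): for a 2-D loop nest where
+/// only the innermost loop is vectorized (say by 128), the permutation map
+/// computed for an access `%A[%i0, %i1]` is typically `(d0, d1) -> (d1)`.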
 AffineMap
-makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
-                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
-AffineMap
-makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
+makePermutationMap(Operation *op, ArrayRef<Value> indices,
                    const DenseMap<Operation *, unsigned> &loopToVectorDim);
 
 /// Build the default minor identity map suitable for a vector transfer. This

diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index c27b3e3ab7cc..3540b5106620 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -14,13 +14,28 @@
 #include "PassDetail.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Types.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace vector;
@@ -237,38 +252,61 @@ using namespace vector;
 ///     fastest varying);
 ///  2. analyzing those patterns for profitability (TODO: and
 ///     interference);
-///  3. then, for each pattern in order:
-///    a. applying iterative rewriting of the loops and all their nested
-///       operations in topological order. Rewriting is implemented by
-///       coarsening the loops and converting operations and operands to their
-///       vector forms. Processing operations in topological order is relatively
-///       simple due to the structured nature of the control-flow
-///       representation. This order ensures that all the operands of a given
-///       operation have been vectorized before the operation itself in a single
-///       traversal, except for operands defined outside of the loop nest. The
-///       algorithm can convert the following operations to their vector form:
-///         * Affine load and store operations are converted to opaque vector
-///           transfer read and write operations.
-///         * Scalar constant operations/operands are converted to vector
-///           constant operations (splat).
-///         * Uniform operands (only operands defined outside of the loop nest,
-///           for now) are broadcasted to a vector.
-///           TODO: Support more uniform cases.
-///         * The remaining operations in the loop nest are vectorized by
-///           widening their scalar types to vector types.
-///         * TODO: Add vectorization support for loops with 'iter_args' and
-///           more complex loops with divergent lbs and/or ubs.
-///    b. if everything under the root AffineForOp in the current pattern
-///       is vectorized properly, we commit that loop to the IR and remove the
-///       scalar loop. Otherwise, we discard the vectorized loop and keep the
-///       original scalar loop.
-///    c. vectorization is applied on the next pattern in the list. Because
+///  3. then, for each pattern in order:
+///    a. applying iterative rewriting of the loop and the load operations in
+///       inner-to-outer order. Rewriting is implemented by coarsening the loops
+///       and turning load operations into opaque vector.transfer_read ops;
+///    b. keeping track of the load operations encountered as "roots" and the
+///       store operations as "terminals";
+///    c. traversing the use-def chains starting from the roots and iteratively
+///       propagating vectorized values. Scalar values that are encountered
+///       during this process must come from outside the scope of the current
+///       pattern (TODO: enforce this and generalize). Such a scalar value
+///       is vectorized only if it is a constant (into a vector splat). The
+///       non-constant case is not supported for now and results in the pattern
+///       failing to vectorize;
+///    d. performing a second traversal on the terminals (store ops) to
+///       rewrite the scalar value they write to memory into vector form.
+///       If the scalar value has been vectorized previously, we simply replace
+///       it by its vector form. Otherwise, if the scalar value is a constant,
+///       it is vectorized into a splat. In all other cases, vectorization for
+///       the pattern currently fails.
+///    e. if everything under the root AffineForOp in the current pattern
+///       vectorizes properly, we commit that loop to the IR. Otherwise we
+///       discard it and restore a previously cloned version of the loop. Thanks
+///       to the recursive scoping nature of matchers and captured patterns,
+///       this is transparently achieved by a simple RAII implementation.
+///    f. vectorization is applied on the next pattern in the list. Because
 ///       pattern interference avoidance is not yet implemented and we do not
 ///       support further vectorizing an already vectorized load, we need to
 ///       re-verify that the pattern is still vectorizable. This is expected to
 ///       make cost models more difficult to write and is subject to improvement
 ///       in the future.
 ///
+/// Points c. and d. above are worth additional comment. In most passes that
+/// do not change the type of operands, it is usually preferred to eagerly
+/// `replaceAllUsesWith`. Unfortunately this does not work for vectorization
+/// because during the use-def chain traversal, all the operands of an operation
+/// must be available in vector form. Trying to propagate eagerly makes the IR
+/// temporarily invalid and results in errors such as:
+///   `vectorize.mlir:308:13: error: 'addf' op requires the same type for all
+///   operands and results
+///      %s5 = addf %a5, %b5 : f32`
+///
+/// Lastly, we show a minimal example for which use-def chains rooted in load /
+/// vector.transfer_read are not enough. This is what motivated splitting
+/// terminal processing out of the use-def chains starting from loads. In the
+/// following snippet, there is simply no load:
+/// ```mlir
+/// func @fill(%A : memref<128xf32>) -> () {
+///   %f1 = constant 1.0 : f32
+///   affine.for %i0 = 0 to 32 {
+///     affine.store %f1, %A[%i0] : memref<128xf32, 0>
+///   }
+///   return
+/// }
+/// ```
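+///
+/// As a sketch of the expected rewrite (assuming a vector size of 32 along
+/// %i0; the exact IR produced may differ), terminal processing turns the
+/// store above into a splat constant fed into a vector transfer:
+/// ```mlir
+/// func @fill(%A : memref<128xf32>) -> () {
+///   %cst = constant dense<1.0> : vector<32xf32>
+///   affine.for %i0 = 0 to 32 step 32 {
+///     vector.transfer_write %cst, %A[%i0] : vector<32xf32>, memref<128xf32>
+///   }
+///   return
+/// }
+/// ```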
+///
 /// Choice of loop transformation to support the algorithm:
 /// =======================================================
 /// The choice of loop transformation to apply for coarsening vectorized loops
@@ -489,6 +527,7 @@ using namespace vector;
 #define DEBUG_TYPE "early-vect"
 
 using llvm::dbgs;
+using llvm::SetVector;
 
 /// Forward declaration.
 static FilterFunctionType
@@ -593,196 +632,199 @@ static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches,
 namespace {
 
 struct VectorizationState {
-
-  VectorizationState(MLIRContext *context) : builder(context) {}
-
-  /// Registers the vector replacement of a scalar operation and its result
-  /// values. Both operations must have the same number of results.
-  ///
-  /// This utility is used to register the replacement for the vast majority of
-  /// the vectorized operations.
-  ///
-  /// Example:
-  ///   * 'replaced': %0 = addf %1, %2 : f32
-  ///   * 'replacement': %0 = addf %1, %2 : vector<128xf32>
-  void registerOpVectorReplacement(Operation *replaced, Operation *replacement);
-
-  /// Registers the vector replacement of a scalar value. The replacement
-  /// operation should have a single result, which replaces the scalar value.
-  ///
-  /// This utility is used to register the vector replacement of block arguments
-  /// and operation results which are not directly vectorized (i.e., their
-  /// scalar version still exists after vectorization), like uniforms.
-  ///
-  /// Example:
-  ///   * 'replaced': block argument or operation outside of the vectorized
-  ///     loop.
-  ///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
-  void registerValueVectorReplacement(Value replaced, Operation *replacement);
-
-  /// Registers the scalar replacement of a scalar value. 'replacement' must be
-  /// scalar. Both values must be block arguments. Operation results should be
-  /// replaced using the 'registerOp*' utilities.
-  ///
-  /// This utility is used to register the replacement of block arguments
-  /// that are within the loop to be vectorized and will continue being scalar
-  /// within the vector loop.
-  ///
-  /// Example:
-  ///   * 'replaced': induction variable of a loop to be vectorized.
-  ///   * 'replacement': new induction variable in the new vector loop.
-  void registerValueScalarReplacement(BlockArgument replaced,
-                                      BlockArgument replacement);
-
-  /// Returns in 'replacedVals' the scalar replacement for values in
-  /// 'inputVals'.
-  void getScalarValueReplacementsFor(ValueRange inputVals,
-                                     SmallVectorImpl<Value> &replacedVals);
-
-  /// Erases the scalar loop nest after its successful vectorization.
-  void finishVectorizationPattern(AffineForOp rootLoop);
-
-  // Used to build and insert all the new operations created. The insertion
-  // point is preserved and updated along the vectorization process.
-  OpBuilder builder;
-
-  // Maps input scalar operations to their vector counterparts.
-  DenseMap<Operation *, Operation *> opVectorReplacement;
-  // Maps input scalar values to their vector counterparts.
-  BlockAndValueMapping valueVectorReplacement;
-  // Maps input scalar values to their new scalar counterparts in the vector
-  // loop nest.
-  BlockAndValueMapping valueScalarReplacement;
-
-  // Maps the newly created vector loops to their vector dimension.
-  DenseMap<Operation *, unsigned> vecLoopToVecDim;
-
+  /// Registers an entry mapping a pre-vectorization operation to its
+  /// post-vectorization counterpart in the state.
+  void registerReplacement(Operation *key, Operation *value);
+  /// When the current vectorization pattern is successful, this erases the
+  /// operations that were marked for erasure in the proper order and resets
+  /// the internal state for the next pattern.
+  void finishVectorizationPattern();
+
+  // In-order tracking of original Operations that have been vectorized.
+  // Erase in reverse order.
+  SmallVector<Operation *, 16> toErase;
+  // Set of Operations that have been vectorized (the values in the
+  // vectorizationMap for hashed access). The vectorizedSet is used in
+  // particular to filter the operations that have already been vectorized by
+  // this pattern, when iterating over nested loops in this pattern.
+  DenseSet<Operation *> vectorizedSet;
+  // Map of old scalar Operation to new vectorized Operation.
+  DenseMap<Operation *, Operation *> vectorizationMap;
+  // Map of old scalar Value to new vectorized Value.
+  DenseMap<Value, Value> replacementMap;
   // The strategy drives which loop to vectorize by which amount.
   const VectorizationStrategy *strategy;
+  // Use-def roots. These represent the starting points for the worklist in the
+  // vectorizeNonTerminals function. They consist of the subset of load
+  // operations that have been vectorized. They can be retrieved from
+  // `vectorizationMap` but it is convenient to keep track of them in a separate
+  // data structure.
+  DenseSet<Operation *> roots;
+  // Terminal operations for the worklist in the vectorizeNonTerminals
+  // function. They consist of the subset of store operations that have been
+  // vectorized. They can be retrieved from `vectorizationMap` but it is
+  // convenient to keep track of them in a separate data structure. Since they
+  // do not necessarily belong to use-def chains starting from loads (e.g.
+  // storing a constant), we need to handle them in a post-pass.
+  DenseSet<Operation *> terminals;
+  // Checks that the type of `op` is AffineStoreOp and adds it to the terminals
+  // set.
+  void registerTerminal(Operation *op);
+  // Folder used to factor out constant creation.
+  OperationFolder *folder;
 
 private:
-  /// Internal implementation to map input scalar values to new vector or scalar
-  /// values.
-  void registerValueVectorReplacementImpl(Value replaced, Value replacement);
-  void registerValueScalarReplacementImpl(Value replaced, Value replacement);
+  void registerReplacement(Value key, Value value);
 };
 
 } // end namespace
 
-/// Registers the vector replacement of a scalar operation and its result
-/// values. Both operations must have the same number of results.
-///
-/// This utility is used to register the replacement for the vast majority of
-/// the vectorized operations.
-///
-/// Example:
-///   * 'replaced': %0 = addf %1, %2 : f32
-///   * 'replacement': %0 = addf %1, %2 : vector<128xf32>
-void VectorizationState::registerOpVectorReplacement(Operation *replaced,
-                                                     Operation *replacement) {
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n");
-  LLVM_DEBUG(dbgs() << *replaced << "\n");
-  LLVM_DEBUG(dbgs() << "into\n");
-  LLVM_DEBUG(dbgs() << *replacement << "\n");
-
-  assert(replaced->getNumResults() <= 1 && "Unsupported multi-result op");
-  assert(replaced->getNumResults() == replacement->getNumResults() &&
-         "Unexpected replaced and replacement results");
-  assert(opVectorReplacement.count(replaced) == 0 && "already registered");
-  opVectorReplacement[replaced] = replacement;
-
-  if (replaced->getNumResults() > 0)
-    registerValueVectorReplacementImpl(replaced->getResult(0),
-                                       replacement->getResult(0));
-}
-
-/// Registers the vector replacement of a scalar value. The replacement
-/// operation should have a single result, which replaces the scalar value.
-///
-/// This utility is used to register the vector replacement of block arguments
-/// and operation results which are not directly vectorized (i.e., their
-/// scalar version still exists after vectorization), like uniforms.
-///
-/// Example:
-///   * 'replaced': block argument or operation outside of the vectorized loop.
-///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
-void VectorizationState::registerValueVectorReplacement(
-    Value replaced, Operation *replacement) {
-  assert(replacement->getNumResults() == 1 &&
-         "Expected single-result replacement");
-  if (Operation *defOp = replaced.getDefiningOp())
-    registerOpVectorReplacement(defOp, replacement);
-  else
-    registerValueVectorReplacementImpl(replaced, replacement->getResult(0));
-}
-
-void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
-                                                            Value replacement) {
-  assert(!valueVectorReplacement.contains(replaced) &&
-         "Vector replacement already registered");
-  assert(replacement.getType().isa<VectorType>() &&
-         "Expected vector type in vector replacement");
-  valueVectorReplacement.map(replaced, replacement);
-}
-
-/// Registers the scalar replacement of a scalar value. 'replacement' must be
-/// scalar. Both values must be block arguments. Operation results should be
-/// replaced using the 'registerOp*' utilities.
-///
-/// This utility is used to register the replacement of block arguments
-/// that are within the loop to be vectorized and will continue being scalar
-/// within the vector loop.
-///
-/// Example:
-///   * 'replaced': induction variable of a loop to be vectorized.
-///   * 'replacement': new induction variable in the new vector loop.
-void VectorizationState::registerValueScalarReplacement(
-    BlockArgument replaced, BlockArgument replacement) {
-  registerValueScalarReplacementImpl(replaced, replacement);
-}
-
-void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
-                                                            Value replacement) {
-  assert(!valueScalarReplacement.contains(replaced) &&
-         "Scalar value replacement already registered");
-  assert(!replacement.getType().isa<VectorType>() &&
-         "Expected scalar type in scalar replacement");
-  valueScalarReplacement.map(replaced, replacement);
+void VectorizationState::registerReplacement(Operation *key, Operation *value) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
+  LLVM_DEBUG(key->print(dbgs()));
+  LLVM_DEBUG(dbgs() << "  into  ");
+  LLVM_DEBUG(value->print(dbgs()));
+  assert(key->getNumResults() == 1 && "expected single-result key");
+  assert(value->getNumResults() == 1 && "expected single-result value");
+  assert(vectorizedSet.count(value) == 0 && "already registered");
+  assert(vectorizationMap.count(key) == 0 && "already registered");
+  toErase.push_back(key);
+  vectorizedSet.insert(value);
+  vectorizationMap.insert(std::make_pair(key, value));
+  registerReplacement(key->getResult(0), value->getResult(0));
+  if (isa<AffineLoadOp>(key)) {
+    assert(roots.count(key) == 0 && "root was already inserted previously");
+    roots.insert(key);
+  }
 }
 
-/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
-void VectorizationState::getScalarValueReplacementsFor(
-    ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
-  for (Value inputVal : inputVals)
-    replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal));
+void VectorizationState::registerTerminal(Operation *op) {
+  assert(isa<AffineStoreOp>(op) && "terminal must be an AffineStoreOp");
+  assert(terminals.count(op) == 0 &&
+         "terminal was already inserted previously");
+  terminals.insert(op);
 }
 
-/// Erases a loop nest, including all its nested operations.
-static void eraseLoopNest(AffineForOp forOp) {
-  LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n");
-  forOp.erase();
+void VectorizationState::finishVectorizationPattern() {
+  while (!toErase.empty()) {
+    auto *op = toErase.pop_back_val();
+    LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
+    LLVM_DEBUG(op->print(dbgs()));
+    op->erase();
+  }
 }
 
-/// Erases the scalar loop nest after its successful vectorization.
-void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) {
-  LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n");
-  eraseLoopNest(rootLoop);
+void VectorizationState::registerReplacement(Value key, Value value) {
+  assert(replacementMap.count(key) == 0 && "replacement already registered");
+  replacementMap.insert(std::make_pair(key, value));
 }
 
 // Applies 'map' to 'mapOperands', returning the resulting values in 'results'.
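+// For example (a sketch, not verbatim output): applying the map
+// (d0, d1) -> (d0 + d1) to operands (%i, %j) materializes one single-result
+// affine.apply per map result:
+//   %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)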
 static void computeMemoryOpIndices(Operation *op, AffineMap map,
                                    ValueRange mapOperands,
-                                   VectorizationState &state,
                                    SmallVectorImpl<Value> &results) {
+  OpBuilder builder(op);
   for (auto resultExpr : map.getResults()) {
     auto singleResMap =
         AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
-    auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
-                                                    mapOperands);
+    auto afOp =
+        builder.create<AffineApplyOp>(op->getLoc(), singleResMap, mapOperands);
     results.push_back(afOp);
   }
 }
 
+////// TODO: Hoist to a VectorizationMaterialize.cpp when appropriate. ////
+
+/// Handles the vectorization of load and store MLIR operations.
+///
+/// AffineLoadOp operations are the roots of the vectorizeNonTerminals call.
+/// They are vectorized immediately. The resulting vector.transfer_read is
+/// immediately registered to replace all uses of the AffineLoadOp in this
+/// pattern's scope.
+///
+/// AffineStoreOp are the terminals of the vectorizeNonTerminals call. They
+/// need to be vectorized late once all the use-def chains have been traversed.
+/// Additionally, they may have SSA value operands which come from outside the
+/// scope of the current pattern.
+/// Such special cases force us to delay the vectorization of the stores until
+/// the last step. Here we merely register the store operation.
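+///
+/// For instance (a sketch, assuming a vector size of 128 along %i0; padding
+/// and permutation map details elided):
+///   %0 = affine.load %A[%i0] : memref<?xf32>
+/// becomes roughly
+///   %0 = vector.transfer_read %A[%i0] ... : memref<?xf32>, vector<128xf32>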
+template <typename LoadOrStoreOpPointer>
+static LogicalResult vectorizeRootOrTerminal(Value iv,
+                                             LoadOrStoreOpPointer memoryOp,
+                                             VectorizationState *state) {
+  auto memRefType = memoryOp.getMemRef().getType().template cast<MemRefType>();
+
+  auto elementType = memRefType.getElementType();
+  // TODO: ponder whether we want to further vectorize a vector value.
+  assert(VectorType::isValidElementType(elementType) &&
+         "Not a valid vector element type");
+  auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
+
+  // Materialize a MemRef with 1 vector.
+  auto *opInst = memoryOp.getOperation();
+  // For now, vector.transfers must be aligned, operate only on indices with an
+  // identity subset of AffineMap and do not change layout.
+  // TODO: increase the expressiveness power of vector.transfer operations
+  // as needed by various targets.
+  if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
+    OpBuilder b(opInst);
+    ValueRange mapOperands = load.getMapOperands();
+    SmallVector<Value, 8> indices;
+    indices.reserve(load.getMemRefType().getRank());
+    if (load.getAffineMap() !=
+        b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
+      computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
+    } else {
+      indices.append(mapOperands.begin(), mapOperands.end());
+    }
+    auto permutationMap =
+        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
+    if (!permutationMap)
+      return failure();
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+    LLVM_DEBUG(permutationMap.print(dbgs()));
+    auto transfer = b.create<vector::TransferReadOp>(
+        opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices,
+        permutationMap);
+    state->registerReplacement(opInst, transfer.getOperation());
+  } else {
+    state->registerTerminal(opInst);
+  }
+  return success();
+}
+/// end TODO: Hoist to a VectorizationMaterialize.cpp when appropriate. ///
+
+/// Coarsens the loop bounds and transforms all remaining load and store
+/// operations into the appropriate vector.transfer.
+static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
+                                          VectorizationState *state) {
+  loop.setStep(step);
+
+  FilterFunctionType notVectorizedThisPattern = [state](Operation &op) {
+    if (!matcher::isLoadOrStore(op)) {
+      return false;
+    }
+    return state->vectorizationMap.count(&op) == 0 &&
+           state->vectorizedSet.count(&op) == 0 &&
+           state->roots.count(&op) == 0 && state->terminals.count(&op) == 0;
+  };
+  auto loadAndStores = matcher::Op(notVectorizedThisPattern);
+  SmallVector<NestedMatch, 8> loadAndStoresMatches;
+  loadAndStores.match(loop.getOperation(), &loadAndStoresMatches);
+  for (auto ls : loadAndStoresMatches) {
+    auto *opInst = ls.getMatchedOperation();
+    auto load = dyn_cast<AffineLoadOp>(opInst);
+    auto store = dyn_cast<AffineStoreOp>(opInst);
+    LLVM_DEBUG(opInst->print(dbgs()));
+    LogicalResult result =
+        load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state)
+             : vectorizeRootOrTerminal(loop.getInductionVar(), store, state);
+    if (failed(result)) {
+      return failure();
+    }
+  }
+  return success();
+}
+
 /// Returns a FilterFunctionType that can be used in NestedPattern to match a
 /// loop whose underlying load/store accesses are either invariant or all
 /// varying along the `fastestVaryingMemRefDimension`.
@@ -804,6 +846,68 @@ isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
   };
 }
 
+/// Applies vectorization to `loops` according to `state`. The loops are
+/// processed in inner-to-outer order to ensure that all the child loops have
+/// already been vectorized before vectorizing the parent loop.
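+///
+/// For example (a sketch): with a vector size of 128 mapped to `%i0`,
+/// coarsening rewrites `affine.for %i0 = 0 to 1024` into
+/// `affine.for %i0 = 0 to 1024 step 128`, leaving the bounds unchanged.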
+static LogicalResult
+vectorizeLoopsAndLoads(std::vector<SmallVector<AffineForOp, 2>> &loops,
+                       VectorizationState *state) {
+  // Vectorize loops in inner-to-outer order. If any child fails, the parent
+  // will fail too.
+  for (auto &loopsInLevel : llvm::reverse(loops)) {
+    for (AffineForOp loop : loopsInLevel) {
+      // 1. This loop may have been omitted from vectorization for various
+      // reasons (e.g. due to the performance model or pattern depth > vector
+      // size).
+      auto it = state->strategy->loopToVectorDim.find(loop.getOperation());
+      if (it == state->strategy->loopToVectorDim.end())
+        continue;
+
+      // 2. Actual inner-to-outer transformation.
+      auto vectorDim = it->second;
+      assert(vectorDim < state->strategy->vectorSizes.size() &&
+             "vector dim overflow");
+      //   a. get actual vector size
+      auto vectorSize = state->strategy->vectorSizes[vectorDim];
+      //   b. loop transformation for early vectorization is still subject to
+      //     exploratory tradeoffs (see top of the file). Apply coarsening,
+      //     i.e.:
+      //        | ub -> ub
+      //        | step -> step * vectorSize
+      LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
+                        << " : \n"
+                        << loop);
+      if (failed(
+              vectorizeAffineForOp(loop, loop.getStep() * vectorSize, state)))
+        return failure();
+    } // end for.
+  }
+
+  return success();
+}
+
+/// Tries to transform a scalar constant into a vector splat of that constant.
+/// Returns the vectorized splat operation on success. Returns nullptr if
+/// `type` is not a valid vector type or if the scalar constant is not a
+/// valid vector element type.
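+///
+/// For example (a sketch, assuming a 1-D vector size of 128): the scalar
+/// `%f1 = constant 1.0 : f32` is rewritten as the splat
+/// `%v = constant dense<1.0> : vector<128xf32>`.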
+static Value vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
+  if (!type || !type.isa<VectorType>() ||
+      !VectorType::isValidElementType(constant.getType())) {
+    return nullptr;
+  }
+  OpBuilder b(op);
+  Location loc = op->getLoc();
+  auto vectorType = type.cast<VectorType>();
+  auto attr = DenseElementsAttr::get(vectorType, constant.getValue());
+  auto *constantOpInst = constant.getOperation();
+
+  OperationState state(loc, constantOpInst->getName().getStringRef(), {},
+                       {vectorType}, {b.getNamedAttr("value", attr)});
+
+  return b.createOperation(state)->getResult(0);
+}
+
 /// Returns the vector type resulting from applying the provided vectorization
 /// strategy on the scalar type.
 static VectorType getVectorType(Type scalarTy,
@@ -812,24 +916,6 @@ static VectorType getVectorType(Type scalarTy,
   return VectorType::get(strategy->vectorSizes, scalarTy);
 }
 
-/// Tries to transform a scalar constant into a vector constant. Returns the
-/// vector constant if the scalar type is a valid vector element type. Returns
-/// nullptr otherwise.
-static ConstantOp vectorizeConstant(ConstantOp constOp,
-                                    VectorizationState &state) {
-  Type scalarTy = constOp.getType();
-  if (!VectorType::isValidElementType(scalarTy))
-    return nullptr;
-
-  auto vecTy = getVectorType(scalarTy, state.strategy);
-  auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue());
-  auto newConstOp = state.builder.create<ConstantOp>(constOp.getLoc(), vecAttr);
-
-  // Register vector replacement for future uses in the scope.
-  state.registerOpVectorReplacement(constOp, newConstOp);
-  return newConstOp;
-}
-
 /// Returns true if the provided value is vector uniform given the vectorization
 /// strategy.
 // TODO: For now, only values that are invariants to all the loops in the
@@ -846,27 +932,32 @@ static bool isUniformDefinition(Value value,
 
 /// Generates a broadcast op for the provided uniform value using the
 /// vectorization strategy in 'state'.
-static Operation *vectorizeUniform(Value uniformVal,
-                                   VectorizationState &state) {
-  OpBuilder::InsertionGuard guard(state.builder);
-  state.builder.setInsertionPointAfterValue(uniformVal);
-
-  auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
-  auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(),
-                                                   vectorTy, uniformVal);
-  state.registerValueVectorReplacement(uniformVal, bcastOp);
-  return bcastOp;
+static Value vectorizeUniform(Value value, VectorizationState *state) {
+  OpBuilder builder(value.getContext());
+  builder.setInsertionPointAfterValue(value);
+
+  auto vectorTy = getVectorType(value.getType(), state->strategy);
+  auto bcast = builder.create<BroadcastOp>(value.getLoc(), vectorTy, value);
+
+  // Add broadcast to the replacement map to reuse it for other uses.
+  state->replacementMap[value] = bcast;
+  return bcast;
 }
 
-/// Tries to vectorize a given `operand` by applying the following logic:
-/// 1. if the defining operation has been already vectorized, `operand` is
-///    already in the proper vector form;
-/// 2. if the `operand` is a constant, returns the vectorized form of the
-///    constant;
-/// 3. if the `operand` is uniform, returns a vector broadcast of the `op`;
-/// 4. otherwise, the vectorization of `operand` is not supported.
-/// Newly created vector operations are registered in `state` as replacement
-/// for their scalar counterparts.
+/// Tries to vectorize a given `operand` of Operation `op` during use-def
+/// chain propagation or during terminal vectorization, by applying the
+/// following logic:
+/// 1. if the defining operation is part of the vectorizedSet (i.e. vectorized
+///    by use-def propagation), `operand` is already in the proper vector form;
+/// 2. otherwise, `operand` may be in some other vector form that fails to
+///    vectorize at this time (i.e. a broadcast would be required); returns
+///    nullptr to indicate failure;
+/// 3. if `operand` is a constant, returns the vectorized form of the constant;
+/// 4. if `operand` is uniform, returns a vector broadcast of it;
+/// 5. non-constant scalars are currently non-vectorizable, in particular to
+///    guard against vectorizing an index which may be loop-variant and needs
+///    special handling.
+///
 /// In particular this logic captures some of the use cases where definitions
 /// that are not scoped under the current pattern are needed to vectorize.
 /// One such example is top level function constants that need to be splatted.
@@ -875,213 +966,112 @@ static Operation *vectorizeUniform(Value uniformVal,
 /// vectorization is possible with the above logic. Returns nullptr otherwise.
 ///
 /// TODO: handle more complex cases.
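+///
+/// For example (a sketch): a uniform scalar `%s : f32` defined above the loop
+/// nest is rewritten for a 128-wide pattern as
+/// `%b = vector.broadcast %s : f32 to vector<128xf32>`.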
-static Value vectorizeOperand(Value operand, VectorizationState &state) {
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand);
-  // If this value is already vectorized, we are done.
-  if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) {
-    LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl);
-    return vecRepl;
+static Value vectorizeOperand(Value operand, Operation *op,
+                              VectorizationState *state) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: " << operand);
+  // 1. If this value has already been vectorized this round, we are done.
+  if (state->vectorizedSet.count(operand.getDefiningOp()) > 0) {
+    LLVM_DEBUG(dbgs() << " -> already vector operand");
+    return operand;
   }
-
-  // An vector operand that is not in the replacement map should never reach
-  // this point. Reaching this point could mean that the code was already
-  // vectorized and we shouldn't try to vectorize already vectorized code.
-  assert(!operand.getType().isa<VectorType>() &&
-         "Vector op not found in replacement map");
-
-  // Vectorize constant.
-  if (auto constOp = operand.getDefiningOp<ConstantOp>()) {
-    ConstantOp vecConstant = vectorizeConstant(constOp, state);
-    LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant);
-    return vecConstant.getResult();
+  // 1.b. Delayed on-demand replacement of a use.
+  //    Note that we cannot just call replaceAllUsesWith because it may result
+  //    in ops with mixed types, for ops whose operands have not all yet
+  //    been vectorized. This would be invalid IR.
+  auto it = state->replacementMap.find(operand);
+  if (it != state->replacementMap.end()) {
+    auto res = it->second;
+    LLVM_DEBUG(dbgs() << "-> delayed replacement by: " << res);
+    return res;
   }
-
-  // Vectorize uniform values.
-  if (isUniformDefinition(operand, state.strategy)) {
-    Operation *vecUniform = vectorizeUniform(operand, state);
-    LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform);
-    return vecUniform->getResult(0);
-  }
-
-  // Check for unsupported block argument scenarios. A supported block argument
-  // should have been vectorized already.
-  if (!operand.getDefiningOp())
-    LLVM_DEBUG(dbgs() << "-> unsupported block argument\n");
-  else
-    // Generic unsupported case.
-    LLVM_DEBUG(dbgs() << "-> non-vectorizable\n");
-
-  return nullptr;
-}
-
-/// Vectorizes an affine load with the vectorization strategy in 'state' by
-/// generating a 'vector.transfer_read' op with the proper permutation map
-/// inferred from the indices of the load. The new 'vector.transfer_read' is
-/// registered as replacement of the scalar load. Returns the newly created
-/// 'vector.transfer_read' if vectorization was successful. Returns nullptr,
-/// otherwise.
-static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
-                                      VectorizationState &state) {
-  MemRefType memRefType = loadOp.getMemRefType();
-  Type elementType = memRefType.getElementType();
-  auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
-
-  // Replace map operands with operands from the vector loop nest.
-  SmallVector<Value, 8> mapOperands;
-  state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands);
-
-  // Compute indices for the transfer op. AffineApplyOp's may be generated.
-  SmallVector<Value, 8> indices;
-  indices.reserve(memRefType.getRank());
-  if (loadOp.getAffineMap() !=
-      state.builder.getMultiDimIdentityMap(memRefType.getRank()))
-    computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
-                           indices);
-  else
-    indices.append(mapOperands.begin(), mapOperands.end());
-
-  // Compute permutation map using the information of new vector loops.
-  auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
-                                           indices, state.vecLoopToVecDim);
-  if (!permutationMap) {
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n");
+  // 2. TODO: broadcast needed.
+  if (operand.getType().isa<VectorType>()) {
+    LLVM_DEBUG(dbgs() << "-> non-vectorizable");
     return nullptr;
   }
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
-  LLVM_DEBUG(permutationMap.print(dbgs()));
+  // 3. vectorize constant.
+  if (auto constant = operand.getDefiningOp<ConstantOp>())
+    return vectorizeConstant(op, constant,
+                             getVectorType(operand.getType(), state->strategy));
 
-  auto transfer = state.builder.create<vector::TransferReadOp>(
-      loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
+  // 4. Uniform values.
+  if (isUniformDefinition(operand, state->strategy))
+    return vectorizeUniform(operand, state);
 
-  // Register replacement for future uses in the scope.
-  state.registerOpVectorReplacement(loadOp, transfer);
-  return transfer;
+  // 5. currently non-vectorizable.
+  LLVM_DEBUG(dbgs() << "-> non-vectorizable: " << operand);
+  return nullptr;
 }
 
-/// Vectorizes an affine store with the vectorization strategy in 'state' by
-/// generating a 'vector.transfer_write' op with the proper permutation map
-/// inferred from the indices of the store. The new 'vector.transfer_store' is
-/// registered as replacement of the scalar load. Returns the newly created
-/// 'vector.transfer_write' if vectorization was successful. Returns nullptr,
-/// otherwise.
-static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
-                                       VectorizationState &state) {
-  MemRefType memRefType = storeOp.getMemRefType();
-  Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state);
-  if (!vectorValue)
-    return nullptr;
-
-  // Replace map operands with operands from the vector loop nest.
-  SmallVector<Value, 8> mapOperands;
-  state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands);
-
-  // Compute indices for the transfer op. AffineApplyOp's may be generated.
-  SmallVector<Value, 8> indices;
-  indices.reserve(memRefType.getRank());
-  if (storeOp.getAffineMap() !=
-      state.builder.getMultiDimIdentityMap(memRefType.getRank()))
-    computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state,
-                           indices);
-  else
-    indices.append(mapOperands.begin(), mapOperands.end());
-
-  // Compute permutation map using the information of new vector loops.
-  auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
-                                           indices, state.vecLoopToVecDim);
-  if (!permutationMap)
-    return nullptr;
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
-  LLVM_DEBUG(permutationMap.print(dbgs()));
-
-  auto transfer = state.builder.create<vector::TransferWriteOp>(
-      storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
-      permutationMap);
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer);
+/// Encodes Operation-specific behavior for vectorization. In general we assume
+/// that all operands of an op must be vectorized but this is not always true.
+/// In the future, it would be nice to have a trait that describes how a
+/// particular operation vectorizes. For now we implement the case distinction
+/// here.
+/// Returns a vectorized form of an operation or nullptr if vectorization fails.
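+///
+/// For example (a sketch): once both operands of `%s = addf %a, %b : f32`
+/// have vector replacements, the op is re-created by plain type widening as
+/// `%s = addf %a, %b : vector<128xf32>`.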
+// TODO: consider adding a trait to Op to describe how it gets vectorized.
+// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
+// do one-off logic here; ideally it would be TableGen'd.
+static Operation *vectorizeOneOperation(Operation *opInst,
+                                        VectorizationState *state) {
+  // Sanity checks.
+  assert(!isa<AffineLoadOp>(opInst) &&
+         "all loads must have already been fully vectorized independently");
+  assert(!isa<vector::TransferReadOp>(opInst) &&
+         "vector.transfer_read cannot be further vectorized");
+  assert(!isa<vector::TransferWriteOp>(opInst) &&
+         "vector.transfer_write cannot be further vectorized");
 
-  // Register replacement for future uses in the scope.
-  state.registerOpVectorReplacement(storeOp, transfer);
-  return transfer;
-}
+  if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
+    OpBuilder b(opInst);
+    auto memRef = store.getMemRef();
+    auto value = store.getValueToStore();
+    auto vectorValue = vectorizeOperand(value, opInst, state);
+    if (!vectorValue)
+      return nullptr;
 
-/// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is
-/// created and registered as replacement for the scalar loop. The builder's
-/// insertion point is set to the new loop's body so that subsequent vectorized
-/// operations are inserted into the new loop. If the loop is a vector
-/// dimension, the step of the newly created loop will reflect the vectorization
-/// factor used to vectorized that dimension.
-// TODO: Add support for 'iter_args'. Related operands and results will be
-// vectorized at this point.
-static Operation *vectorizeAffineForOp(AffineForOp forOp,
-                                       VectorizationState &state) {
-  // 'iter_args' not supported yet.
-  if (forOp.getNumIterOperands() > 0)
-    return nullptr;
+    ValueRange mapOperands = store.getMapOperands();
+    SmallVector<Value, 8> indices;
+    indices.reserve(store.getMemRefType().getRank());
+    if (store.getAffineMap() !=
+        b.getMultiDimIdentityMap(store.getMemRefType().getRank())) {
+      computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
+                             indices);
+    } else {
+      indices.append(mapOperands.begin(), mapOperands.end());
+    }
 
-  // If we are vectorizing a vector dimension, compute a new step for the new
-  // vectorized loop using the vectorization factor for the vector dimension.
-  // Otherwise, propagate the step of the scalar loop.
-  const VectorizationStrategy &strategy = *state.strategy;
-  auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
-  bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
-  unsigned newStep;
-  if (isLoopVecDim) {
-    unsigned vectorDim = loopToVecDimIt->second;
-    assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow");
-    int64_t forOpVecFactor = strategy.vectorSizes[vectorDim];
-    newStep = forOp.getStep() * forOpVecFactor;
-  } else {
-    newStep = forOp.getStep();
+    auto permutationMap =
+        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
+    if (!permutationMap)
+      return nullptr;
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+    LLVM_DEBUG(permutationMap.print(dbgs()));
+    auto transfer = b.create<vector::TransferWriteOp>(
+        opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
+    auto *res = transfer.getOperation();
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
+    // "Terminals" (i.e. AffineStoreOps) are erased on the spot.
+    opInst->erase();
+    return res;
   }
+  if (opInst->getNumRegions() != 0)
+    return nullptr;
 
-  auto vecForOp = state.builder.create<AffineForOp>(
-      forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
-      forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
-      forOp.getIterOperands(),
-      /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
-        // Make sure we don't create a default terminator in the loop body as
-        // the proper terminator will be added during vectorization.
-        return;
-      });
-
-  // Register loop-related replacements:
-  //   1) The new vectorized loop is registered as vector replacement of the
-  //      scalar loop.
-  //      TODO: Support reductions along the vector dimension.
-  //   2) The new iv of the vectorized loop is registered as scalar replacement
-  //      since a scalar copy of the iv will prevail in the vectorized loop.
-  //      TODO: A vector replacement will also be added in the future when
-  //      vectorization of linear ops is supported.
-  //   3) TODO: Support 'iter_args' along non-vector dimensions.
-  state.registerOpVectorReplacement(forOp, vecForOp);
-  state.registerValueScalarReplacement(forOp.getInductionVar(),
-                                       vecForOp.getInductionVar());
-  // Map the new vectorized loop to its vector dimension.
-  if (isLoopVecDim)
-    state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;
-
-  // Change insertion point so that upcoming vectorized instructions are
-  // inserted into the vectorized loop's body.
-  state.builder.setInsertionPointToStart(vecForOp.getBody());
-  return vecForOp;
-}
-
-/// Vectorizes arbitrary operation by plain widening. We apply generic type
-/// widening of all its results and retrieve the vector counterparts for all its
-/// operands.
-static Operation *widenOp(Operation *op, VectorizationState &state) {
   SmallVector<Type, 8> vectorTypes;
-  for (Value result : op->getResults())
+  for (auto v : opInst->getResults()) {
     vectorTypes.push_back(
-        VectorType::get(state.strategy->vectorSizes, result.getType()));
-
+        VectorType::get(state->strategy->vectorSizes, v.getType()));
+  }
   SmallVector<Value, 8> vectorOperands;
-  for (Value operand : op->getOperands()) {
-    Value vecOperand = vectorizeOperand(operand, state);
-    if (!vecOperand) {
-      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n");
-      return nullptr;
-    }
-    vectorOperands.push_back(vecOperand);
+  for (auto v : opInst->getOperands()) {
+    vectorOperands.push_back(vectorizeOperand(v, opInst, state));
+  }
+  // Check whether any operand is null. If so, vectorization failed.
+  bool success = llvm::all_of(vectorOperands, [](Value op) { return op; });
+  if (!success) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
+    return nullptr;
   }
 
   // Create a clone of the op with the proper operands and return types.
@@ -1089,64 +1079,59 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
   // name that works both in scalar mode and vector mode.
   // TODO: Is it worth considering an Operation.clone operation which
   // changes the type so we can promote an Operation with less boilerplate?
-  OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
-                            vectorOperands, vectorTypes, op->getAttrs(),
-                            /*successors=*/{}, /*regions=*/{});
-  Operation *vecOp = state.builder.createOperation(vecOpState);
-  state.registerOpVectorReplacement(op, vecOp);
-  return vecOp;
+  OpBuilder b(opInst);
+  OperationState newOp(opInst->getLoc(), opInst->getName().getStringRef(),
+                       vectorOperands, vectorTypes, opInst->getAttrs(),
+                       /*successors=*/{}, /*regions=*/{});
+  return b.createOperation(newOp);
 }
 
-/// Vectorizes a yield operation by widening its types. The builder's insertion
-/// point is set after the vectorized parent op to continue vectorizing the
-/// operations after the parent op.
-static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
-                                         VectorizationState &state) {
-  // 'iter_args' not supported yet.
-  if (yieldOp.getNumOperands() > 0)
-    return nullptr;
-
-  // Vectorize the yield op and change the insertion point right after the new
-  // parent op.
-  Operation *newYieldOp = widenOp(yieldOp, state);
-  Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp();
-  state.builder.setInsertionPointAfter(newParentOp);
-  return newYieldOp;
-}
-
-/// Encodes Operation-specific behavior for vectorization. In general we
-/// assume that all operands of an op must be vectorized but this is not
-/// always true. In the future, it would be nice to have a trait that
-/// describes how a particular operation vectorizes. For now we implement the
-/// case distinction here. Returns a vectorized form of an operation or
-/// nullptr if vectorization fails.
-// TODO: consider adding a trait to Op to describe how it gets vectorized.
-// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
-// do one-off logic here; ideally it would be TableGen'd.
-static Operation *vectorizeOneOperation(Operation *op,
-                                        VectorizationState &state) {
-  // Sanity checks.
-  assert(!isa<vector::TransferReadOp>(op) &&
-         "vector.transfer_read cannot be further vectorized");
-  assert(!isa<vector::TransferWriteOp>(op) &&
-         "vector.transfer_write cannot be further vectorized");
-
-  if (auto loadOp = dyn_cast<AffineLoadOp>(op))
-    return vectorizeAffineLoad(loadOp, state);
-  if (auto storeOp = dyn_cast<AffineStoreOp>(op))
-    return vectorizeAffineStore(storeOp, state);
-  if (auto forOp = dyn_cast<AffineForOp>(op))
-    return vectorizeAffineForOp(forOp, state);
-  if (auto yieldOp = dyn_cast<AffineYieldOp>(op))
-    return vectorizeAffineYieldOp(yieldOp, state);
-  if (auto constant = dyn_cast<ConstantOp>(op))
-    return vectorizeConstant(constant, state);
-
-  // Other ops with regions are not supported.
-  if (op->getNumRegions() != 0)
-    return nullptr;
+/// Iterates over the forward slice from the loads in the vectorization pattern
+/// and rewrites the operations using their vectorized counterparts by:
+///   1. Creating the forward slice starting from the loads in the
+///   vectorization pattern.
+///   2. Topologically sorting the forward slice.
+///   3. For each operation in the slice, creating the vector form of the
+///   operation, replacing each operand by its replacement retrieved from
+///   replacementMap. If any such replacement is missing, vectorization fails.
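+///
+/// For example (a sketch): in a pattern containing
+///   `%a = affine.load ...`, `%b = affine.load ...` and
+///   `%s = addf %a, %b : f32`,
+/// the two loads are the roots; the forward slice reaches `%s`, which is then
+/// rewritten with vector<128xf32> operands (assuming a 128-wide strategy).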
+static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
+  // 1. Create the initial worklist with the uses of the roots.
+  SetVector<Operation *> worklist;
+  // Note: state->roots have already been vectorized and must not be vectorized
+  // again. This fits `getForwardSlice` which does not insert `op` in the
+  // result.
+  // Note: we have to exclude terminals because some of their defs may not be
+  // nested under the vectorization pattern (e.g. constants defined in an
+  // encompassing scope).
+  // TODO: Use a backward slice for terminals, avoid special casing and
+  // merge implementations.
+  for (auto *op : state->roots) {
+    getForwardSlice(op, &worklist, [state](Operation *op) {
+      return state->terminals.count(op) == 0; // propagate if not terminal
+    });
+  }
+  // We merged multiple slices, so the topological order may not hold anymore.
+  worklist = topologicalSort(worklist);
+
+  for (unsigned i = 0; i < worklist.size(); ++i) {
+    auto *op = worklist[i];
+    LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
+    LLVM_DEBUG(op->print(dbgs()));
+
+    // Create vector form of the operation.
+    // Insert it just before op, on success register op as replaced.
+    auto *vectorizedInst = vectorizeOneOperation(op, state);
+    if (!vectorizedInst) {
+      return failure();
+    }
 
-  return widenOp(op, state);
+    // 3. Register replacement for future uses in the scope.
+    //    Note that we cannot just call replaceAllUsesWith because it may
+    //    result in ops with mixed types, for ops whose operands have not all
+    //    yet been vectorized. This would be invalid IR.
+    state->registerReplacement(op, vectorizedInst);
+  }
+  return success();
 }
 
 /// Recursive implementation to convert all the nested loops in 'match' to a 2D
@@ -1186,9 +1171,10 @@ vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
                   const VectorizationStrategy &strategy) {
   assert(loops[0].size() == 1 && "Expected single root loop");
   AffineForOp rootLoop = loops[0][0];
-  VectorizationState state(rootLoop.getContext());
-  state.builder.setInsertionPointAfter(rootLoop);
+  OperationFolder folder(rootLoop.getContext());
+  VectorizationState state;
   state.strategy = &strategy;
+  state.folder = &folder;
 
   // Since patterns are recursive, they can very well intersect.
   // Since we do not want a fully greedy strategy in general, we decouple
@@ -1202,48 +1188,70 @@ vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
     return failure();
   }
 
+  /// Sets up error handling for this root loop. This is how the root match
+  /// maintains a clone for handling failure and restores the proper state via
+  /// RAII.
+  auto *loopInst = rootLoop.getOperation();
+  OpBuilder builder(loopInst);
+  auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
+  struct Guard {
+    LogicalResult failure() {
+      loop.getInductionVar().replaceAllUsesWith(clonedLoop.getInductionVar());
+      loop.erase();
+      return mlir::failure();
+    }
+    LogicalResult success() {
+      clonedLoop.erase();
+      return mlir::success();
+    }
+    AffineForOp loop;
+    AffineForOp clonedLoop;
+  } guard{rootLoop, clonedLoop};
+
   //////////////////////////////////////////////////////////////////////////////
-  // Vectorize the scalar loop nest following a topological order. A new vector
-  // loop nest with the vectorized operations is created along the process. If
-  // vectorization succeeds, the scalar loop nest is erased. If vectorization
-  // fails, the vector loop nest is erased and the scalar loop nest is not
-  // modified.
+  // Start vectorizing.
+  // From now on, any error triggers the scope guard above.
   //////////////////////////////////////////////////////////////////////////////
+  // 1. Vectorize all the loop candidates, in inner-to-outer order.
+  // This also vectorizes the roots (AffineLoadOp) as well as registers the
+  // terminals (AffineStoreOp) for post-processing vectorization (we need to
+  // wait for all use-def chains into them to be vectorized first).
+  if (failed(vectorizeLoopsAndLoads(loops, &state))) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root vectorizeLoop");
+    return guard.failure();
+  }
 
-  auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) {
-    LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op);
-    Operation *vectorOp = vectorizeOneOperation(op, state);
-    if (!vectorOp)
-      return WalkResult::interrupt();
-
-    return WalkResult::advance();
-  });
-
-  if (opVecResult.wasInterrupted()) {
-    LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: "
-                      << rootLoop << "\n");
-    // Erase vector loop nest if it was created.
-    auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop);
-    if (vecRootLoopIt != state.opVectorReplacement.end())
-      eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second));
+  // 2. Vectorize operations reached by use-def chains from root except the
+  // terminals (store operations) that need to be post-processed separately.
+  // TODO: add more as we expand.
+  if (failed(vectorizeNonTerminals(&state))) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeNonTerminals");
+    return guard.failure();
+  }
 
-    return failure();
+  // 3. Post-process terminals.
+  // Note: we have to post-process terminals because some of their defs may not
+  // be nested under the vectorization pattern (e.g. constants defined in an
+  // encompassing scope).
+  // TODO: Use a backward slice for terminals, avoid special casing and
+  // merge implementations.
+  for (auto *op : state.terminals) {
+    if (!vectorizeOneOperation(op, &state)) { // nullptr == failure
+      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminals");
+      return guard.failure();
+    }
   }
 
-  assert(state.opVectorReplacement.count(rootLoop) == 1 &&
-         "Expected vector replacement for loop nest");
+  // 4. Finish this vectorization pattern.
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n"
-                    << *state.opVectorReplacement[rootLoop]);
-
-  // Finish this vectorization pattern.
-  state.finishVectorizationPattern(rootLoop);
-  return success();
+  state.finishVectorizationPattern();
+  return guard.success();
 }
 
-/// Extracts the matched loops and vectorizes them following a topological
-/// order. A new vector loop nest will be created if vectorization succeeds. The
-/// original loop nest won't be modified in any case.
+/// Vectorization is a recursive procedure where anything below can fail. The
+/// root match thus needs to maintain a clone for handling failure. Each root
+/// may succeed independently but will otherwise clean up after itself if
+/// anything below it fails.
 static LogicalResult vectorizeRootMatch(NestedMatch m,
                                         const VectorizationStrategy &strategy) {
   std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;
@@ -1264,7 +1272,7 @@ static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops,
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n");
-    LLVM_DEBUG(dbgs() << *parentOp << "\n");
+    LLVM_DEBUG(parentOp->print(dbgs()));
 
     unsigned patternDepth = pat.getDepth();
 

diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index b68886cb3e40..ef3ef3db1f81 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -211,29 +211,29 @@ static AffineMap makePermutationMap(
 /// TODO: could also be implemented as a collect parents followed by a
 /// filter and made available outside this file.
 template <typename T>
-static SetVector<Operation *> getParentsOfType(Block *block) {
+static SetVector<Operation *> getParentsOfType(Operation *op) {
   SetVector<Operation *> res;
-  auto *current = block->getParentOp();
-  while (current) {
-    if (auto typedParent = dyn_cast<T>(current)) {
-      assert(res.count(current) == 0 && "Already inserted");
-      res.insert(current);
+  auto *current = op;
+  while (auto *parent = current->getParentOp()) {
+    if (auto typedParent = dyn_cast<T>(parent)) {
+      assert(res.count(parent) == 0 && "Already inserted");
+      res.insert(parent);
     }
-    current = current->getParentOp();
+    current = parent;
   }
   return res;
 }
 
 /// Returns the enclosing AffineForOp, from closest to farthest.
-static SetVector<Operation *> getEnclosingforOps(Block *block) {
-  return getParentsOfType<AffineForOp>(block);
+static SetVector<Operation *> getEnclosingforOps(Operation *op) {
+  return getParentsOfType<AffineForOp>(op);
 }
 
 AffineMap mlir::makePermutationMap(
-    Block *insertPoint, ArrayRef<Value> indices,
+    Operation *op, ArrayRef<Value> indices,
     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
-  auto enclosingLoops = getEnclosingforOps(insertPoint);
+  auto enclosingLoops = getEnclosingforOps(op);
   for (auto *forInst : enclosingLoops) {
     auto it = loopToVectorDim.find(forInst);
     if (it != loopToVectorDim.end()) {
@@ -243,12 +243,6 @@ AffineMap mlir::makePermutationMap(
   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
-AffineMap mlir::makePermutationMap(
-    Operation *op, ArrayRef<Value> indices,
-    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
-  return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
-}
-
 AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
                                             VectorType vectorType) {
   int64_t elementVectorRank = 0;
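
The VectorUtils side of the revert is mechanical: getParentsOfType and
makePermutationMap walk up from an Operation again instead of from a Block
insertion point, and the Block overload disappears. The traversal itself is
just "collect matching ancestors, closest first"; here is a toy analogue
(Node and collectAncestors are invented for this sketch; the real code uses
Operation::getParentOp and an llvm::SetVector):

  #include <iostream>
  #include <string>
  #include <vector>

  // Toy stand-in for an MLIR operation: a kind string plus a parent link.
  struct Node {
    std::string kind;
    Node *parent = nullptr;
  };

  // Collect every ancestor of `op` whose kind matches, from closest to
  // farthest -- the order getEnclosingforOps documents.
  std::vector<Node *> collectAncestors(Node *op, const std::string &kind) {
    std::vector<Node *> res;
    for (Node *cur = op->parent; cur; cur = cur->parent)
      if (cur->kind == kind)
        res.push_back(cur);
    return res;
  }

  int main() {
    Node func{"func"};
    Node outer{"affine.for", &func};
    Node inner{"affine.for", &outer};
    Node load{"affine.load", &inner};
    // Prints "inner" then "outer": closest enclosing loop first.
    for (Node *n : collectAncestors(&load, "affine.for"))
      std::cout << (n == &inner ? "inner" : "outer") << "\n";
  }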

diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
index 528169e26d11..86749e2c7bab 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -1,8 +1,16 @@
-// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0" | FileCheck %s
 
-// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+// Permutation maps used in vectorization.
 // CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
+// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+
+#map0 = affine_map<(d0) -> (d0)>
+#mapadd1 = affine_map<(d0) -> (d0 + 1)>
+#mapadd2 = affine_map<(d0) -> (d0 + 2)>
+#mapadd3 = affine_map<(d0) -> (d0 + 3)>
+#set0 = affine_set<(i) : (i >= 0)>
 
+// Maps introduced to vectorize the fastest-varying memory index.
 // CHECK-LABEL: func @vec1d_1
 func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -29,8 +37,6 @@ func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec1d_2
 func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -55,8 +61,6 @@ func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec1d_3
 func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -86,8 +90,6 @@ func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vector_add_2d
 func @vector_add_2d(%M : index, %N : index) -> f32 {
   %A = alloc (%M, %N) : memref<?x?xf32, 0>
@@ -140,8 +142,6 @@ func @vector_add_2d(%M : index, %N : index) -> f32 {
   return %res : f32
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_1
 func @vec_rejected_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -164,8 +164,6 @@ func @vec_rejected_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_2
 func @vec_rejected_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -188,8 +186,6 @@ func @vec_rejected_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_3
 func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -217,8 +213,6 @@ func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_4
 func @vec_rejected_4(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -244,8 +238,6 @@ func @vec_rejected_4(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_5
 func @vec_rejected_5(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -272,8 +264,6 @@ func @vec_rejected_5(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_6
 func @vec_rejected_6(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -302,8 +292,6 @@ func @vec_rejected_6(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_7
 func @vec_rejected_7(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -327,11 +315,6 @@ func @vec_rejected_7(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
-// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
-
 // CHECK-LABEL: func @vec_rejected_8
 func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -361,11 +344,6 @@ func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
-// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
-
 // CHECK-LABEL: func @vec_rejected_9
 func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -395,10 +373,6 @@ func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
-#set0 = affine_set<(i) : (i >= 0)>
-
 // CHECK-LABEL: func @vec_rejected_10
 func @vec_rejected_10(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -423,8 +397,6 @@ func @vec_rejected_10(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    return
 }
 
-// -----
-
 // CHECK-LABEL: func @vec_rejected_11
 func @vec_rejected_11(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
   // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -452,9 +424,7 @@ func @vec_rejected_11(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
   return
 }
 
-// -----
-
-// This should not vectorize due to the sequential dependence in the loop.
+// This should not vectorize due to the sequential dependence in the scf.
 // CHECK-LABEL: @vec_rejected_sequential
 func @vec_rejected_sequential(%A : memref<?xf32>) {
   %c0 = constant 0 : index
@@ -467,66 +437,3 @@ func @vec_rejected_sequential(%A : memref<?xf32>) {
   }
   return
 }
-
-// -----
-
-// CHECK-LABEL: @vec_no_load_store_ops
-func @vec_no_load_store_ops(%a: f32, %b: f32) {
- %cst = constant 0.000000e+00 : f32
- affine.for %i = 0 to 128 {
-   %add = addf %a, %b : f32
- }
- // CHECK-DAG:  %[[bc1:.*]] = vector.broadcast
- // CHECK-DAG:  %[[bc0:.*]] = vector.broadcast
- // CHECK:      affine.for %{{.*}} = 0 to 128 step
- // CHECK-NEXT:   [[add:.*]] addf %[[bc0]], %[[bc1]]
-
- return
-}
-
-// -----
-
-// This should not be vectorized due to the unsupported block argument (%i).
-// Support for operands with linear evolution is needed.
-// CHECK-LABEL: @vec_rejected_unsupported_block_arg
-func @vec_rejected_unsupported_block_arg(%A : memref<512xi32>) {
-  affine.for %i = 0 to 512 {
-    // CHECK-NOT: vector
-    %idx = std.index_cast %i : index to i32
-    affine.store %idx, %A[%i] : memref<512xi32>
-  }
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @vec_rejected_unsupported_reduction
-func @vec_rejected_unsupported_reduction(%in: memref<128x256xf32>, %out: memref<256xf32>) {
- %cst = constant 0.000000e+00 : f32
- affine.for %i = 0 to 256 {
-   // CHECK-NOT: vector
-   %final_red = affine.for %j = 0 to 128 iter_args(%red_iter = %cst) -> (f32) {
-     %ld = affine.load %in[%j, %i] : memref<128x256xf32>
-     %add = addf %red_iter, %ld : f32
-     affine.yield %add : f32
-   }
-   affine.store %final_red, %out[%i] : memref<256xf32>
- }
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vec_rejected_unsupported_last_value
-func @vec_rejected_unsupported_last_value(%in: memref<128x256xf32>, %out: memref<256xf32>) {
- %cst = constant 0.000000e+00 : f32
- affine.for %i = 0 to 256 {
-   // CHECK-NOT: vector
-   %last_val = affine.for %j = 0 to 128 iter_args(%last_iter = %cst) -> (f32) {
-     %ld = affine.load %in[%j, %i] : memref<128x256xf32>
-     affine.yield %ld : f32
-   }
-   affine.store %last_val, %out[%i] : memref<256xf32>
- }
- return
-}
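
On the test side, the RUN line loses -split-input-file, which is why every
"// -----" separator above is deleted and the shared map and set definitions
move back to the top of the file: with the flag, mlir-opt splits the input
at those markers and processes each chunk as an independent module; without
it the whole file is parsed once. Conceptually the splitting amounts to the
following (an illustrative sketch, not mlir-opt's actual implementation):

  #include <iostream>
  #include <sstream>
  #include <string>
  #include <vector>

  // Split `input` on lines that consist exactly of `marker`, mimicking how
  // -split-input-file carves a test file into independently processed
  // chunks at "// -----" lines.
  std::vector<std::string> splitInputFile(const std::string &input,
                                          const std::string &marker) {
    std::vector<std::string> chunks(1);
    std::istringstream stream(input);
    for (std::string line; std::getline(stream, line);) {
      if (line == marker)
        chunks.emplace_back(); // start a new independent chunk
      else
        chunks.back() += line + "\n";
    }
    return chunks;
  }

  int main() {
    std::string file = "func @a() { }\n// -----\nfunc @b() { }\n";
    for (const std::string &chunk : splitInputFile(file, "// -----"))
      std::cout << "--- chunk ---\n" << chunk;
  }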


        

