[Mlir-commits] [mlir] 0a2a260 - [mlir][Linalg] Refactor Linalg vectorization for better reuse and extensibility.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Feb 2 03:34:48 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-02T11:31:09Z
New Revision: 0a2a260aab177bfbdef5829ea16e39323ce50916

URL: https://github.com/llvm/llvm-project/commit/0a2a260aab177bfbdef5829ea16e39323ce50916
DIFF: https://github.com/llvm/llvm-project/commit/0a2a260aab177bfbdef5829ea16e39323ce50916.diff

LOG: [mlir][Linalg] Refactor Linalg vectorization for better reuse and extensibility.

This revision unifies Linalg vectorization and paves the way for vectorization of Linalg ops with mixed-precision operations.
The new algorithm traverses the ops in the linalg block in order and avoids recursion.
It uses a BlockAndValueMapping to keep track of vectorized operations.

The revision makes the following modifications but is otherwise NFC:
1. vector.transfer_read are created eagerly and may appear in a different order than the original order.
2. a more progressive vectorization to vector.contract results in only the multiply operation being converted to `vector.contract %a, %b, %zero`, where `%zero` is a
constant of the proper type. Later vector canonicalizations are assumed to rewrite vector.contract %a, %b, %zero + add to a proper accumulate form.

Differential revision: https://reviews.llvm.org/D95797

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index fa1aba8fd157..047e5ee045df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -23,6 +23,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -36,6 +38,275 @@ using llvm::dbgs;
 
 #define DEBUG_TYPE "linalg-vectorization"
 
+/// Helper data structure to represent the result of vectorization.
+/// In certain cases, like terminators, produced ops must not be mapped.
+enum VectorizationStatus {
+  /// Op failed to vectorize.
+  Failure = 0,
+  /// Op vectorized and custom function took care of replacement logic
+  NoReplace,
+  /// Op vectorized into a new Op whose results will replace original Op's
+  /// results.
+  NewOp
+  // TODO: support values if Op vectorized to Many-Ops whose results we need to
+  // aggregate for replacement.
+};
+struct VectorizationResult {
+  /// Return status from vectorizing the current op.
+  enum VectorizationStatus status = VectorizationStatus::Failure;
+  /// New vectorized operation to replace the current op.
+  /// Replacement behavior is specified by `status`.
+  Operation *newOp;
+};
+
+/// Return a vector type of the same shape and element type as the (assumed)
+/// ShapedType of `v`.
+static VectorType extractVectorTypeFromShapedValue(Value v) {
+  auto st = v.getType().cast<ShapedType>();
+  if (st.isa<MemRefType>() && st.getShape().empty())
+    return VectorType();
+  return VectorType::get(st.getShape(), st.getElementType());
+}
+
+/// Build a vector.transfer_read from `source` at indices set to all `0`.
+/// If source has rank zero, build an std.load.
+/// Return the produced value.
+static Value buildVectorRead(OpBuilder &builder, Value source) {
+  edsc::ScopedContext scope(builder);
+  auto shapedType = source.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+    return vector_transfer_read(vectorType, source, indices);
+  }
+  return std_load(source);
+}
+
+/// Build a vector.transfer_write of `value` into `dest` at indices set to all
+/// `0`. If `dest` has rank zero, build an std.store.
+/// Return the produced value or null if no value is produced.
+static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
+  edsc::ScopedContext scope(builder);
+  Operation *write;
+  auto shapedType = dest.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+    if (vectorType != value.getType())
+      value = vector_broadcast(vectorType, value);
+    write = vector_transfer_write(value, dest, indices);
+  } else {
+    write = std_store(value, dest);
+  }
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
+  if (!write->getResults().empty())
+    return write->getResult(0);
+  return Value();
+}
+
+/// If value of assumed VectorType has a shape different than `shape`, build and
+/// return a new vector.broadcast to `shape`.
+/// Otherwise, just return value.
+static Value broadcastIfNeeded(OpBuilder &builder, Value value,
+                               ArrayRef<int64_t> shape) {
+  auto vecType = value.getType().dyn_cast<VectorType>();
+  if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
+    return value;
+  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
+                                                   : value.getType());
+  return builder.create<vector::BroadcastOp>(
+      builder.getInsertionPoint()->getLoc(), newVecType, value);
+}
+
+// Custom vectorization function type. Produce a vector form of Operation*
+// assuming all its vectorized operands are already in the BlockAndValueMapping.
+// Return a `Failure` status if the Operation cannot be vectorized.
+using CustomVectorizationHook = std::function<VectorizationResult(
+    Operation *, const BlockAndValueMapping &)>;
+
+/// Helper function to vectorize the terminator of a `linalgOp`. New result
+/// vector values are appended to `results`.
+/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
+/// that it should not try to map produced operations: this is the purpose of
+/// the `results` argument to capture such values and make them available for
+/// RAUW to the vectorization algorithm.
+/// This function is meant to be used as a CustomVectorizationHook.
+static VectorizationResult
+vectorizeLinalgYield(OpBuilder &builder, Operation *op,
+                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
+                     SmallVectorImpl<Value> &results) {
+  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
+  if (!yieldOp)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+  for (auto outputs : llvm::enumerate(yieldOp.values())) {
+    // TODO: Scan for an opportunity for reuse.
+    // TODO: use a map.
+    Value vectorValue = bvm.lookup(outputs.value());
+    Value result = buildVectorWrite(builder, vectorValue,
+                                    linalgOp.getOutput(outputs.index()));
+    if (result)
+      results.push_back(result);
+  }
+  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+};
+
+/// Generic vectorization for a single operation `op`, given already vectorized
+/// operands carried by `bvm`. Vectorization occurs as follows:
+///   1. Try to apply any of the `customVectorizationHooks` and return its
+///   result on success.
+///   2. Clone any constant in the current scope without vectorization: each
+///   consumer of the constant will later determine the shape to which the
+///   constant needs to be broadcast to.
+///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
+///   of the `customVectorizationHooks` to cover such cases.
+///   4. Clone `op` in vector form to a vector of shape prescribed by the first
+///   operand of maximal rank. Other operands have smaller rank and are
+///   broadcast accordingly. It is assumed this broadcast is always legal,
+///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
+///
+/// This function assumes all operands of `op` have been vectorized and are in
+/// the `bvm` mapping. As a consequence, this function is meant to be called on
+/// a topologically-sorted list of ops.
+/// This function does not update `bvm` but returns a VectorizationResult whose
+/// status instructs the caller what `bvm` update needs to occur.
+static VectorizationResult
+vectorizeOneOp(OpBuilder &builder, Operation *op,
+               const BlockAndValueMapping &bvm,
+               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
+
+  // 1. Try to apply any CustomVectorizationHook.
+  if (!customVectorizationHooks.empty()) {
+    for (auto &customFunc : customVectorizationHooks) {
+      VectorizationResult result = customFunc(op, bvm);
+      if (result.status == VectorizationStatus::Failure)
+        continue;
+      return result;
+    }
+  }
+
+  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
+  // Clone so that the constant is not confined to the linalgOp block.
+  if (isa<ConstantOp>(op))
+    return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
+
+  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+
+  // 4. Generic vectorization path for ElementwiseMappable ops.
+  //   a. first get the first max ranked shape.
+  SmallVector<int64_t, 4> firstMaxRankedShape;
+  for (Value operand : op->getOperands()) {
+    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
+    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
+      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
+  }
+  //   b. broadcast each operand if needed.
+  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
+    return firstMaxRankedShape.empty()
+               ? bvm.lookup(v)
+               : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
+  });
+  //   c. for elementwise, the result is the vector with the firstMaxRankedShape
+  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
+    return firstMaxRankedShape.empty()
+               ? t
+               : VectorType::get(firstMaxRankedShape, t);
+  });
+
+  // Build and return the new op.
+  OperationState state(op->getLoc(), op->getName());
+  state.addAttributes(op->getAttrs());
+  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
+  state.addTypes(llvm::to_vector<4>(returnTypes));
+  return VectorizationResult{VectorizationStatus::NewOp,
+                             builder.createOperation(state)};
+}
+
+/// Generic vectorization function that rewrites the body of a `linalgOp` into
+/// vector form. Generic vectorization proceeds as follows:
+///   1. The region for the linalg op is created if necessary.
+///   2. Values defined above the region are mapped to themselves and will be
+///   broadcasted on a per-need basis by their consumers.
+///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
+///   load).
+///   TODO: Reuse opportunities for RAR dependencies.
+///   4. Register CustomVectorizationHook for YieldOp to capture the results.
+///   5. Iteratively call vectorizeOneOp on the region operations.
+///   6. RAUW the linalg op by the results captured vectorizing the YieldOp.
+static LogicalResult vectorizeAsLinalgGeneric(
+    OpBuilder &builder, LinalgOp linalgOp,
+    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
+  // 1. Certain Linalg ops do not have a region but only a region builder.
+  // If so, build the region so we can vectorize.
+  std::unique_ptr<Region> owningRegion;
+  Region *region;
+  if (linalgOp->getNumRegions() > 0) {
+    region = &linalgOp->getRegion(0);
+  } else {
+    // RAII guard: restore the insertion point to avoid remaining in block.
+    OpBuilder::InsertionGuard g(builder);
+    owningRegion = std::make_unique<Region>();
+    region = owningRegion.get();
+    Block *block = builder.createBlock(region);
+    auto elementTypes = llvm::to_vector<4>(
+        llvm::map_range(linalgOp.getShapedOperandTypes(),
+                        [](ShapedType t) { return t.getElementType(); }));
+    block->addArguments(elementTypes);
+    linalgOp.getRegionBuilder()(*block);
+  }
+  Block *block = &region->front();
+
+  BlockAndValueMapping bvm;
+  // 2. Values defined above the region can only be broadcast for now. Make them
+  // map to themselves.
+  llvm::SetVector<Value> valuesSet;
+  mlir::getUsedValuesDefinedAbove(*region, valuesSet);
+  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
+
+  // 3. Turn all BBArgs into vector.transfer_read / load.
+  SmallVector<AffineMap> indexings;
+  for (auto bbarg : block->getArguments()) {
+    Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
+    Value vectorRead = buildVectorRead(builder, vectorArg);
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
+                      << bbarg.getArgNumber() << "): " << vectorRead);
+    bvm.map(bbarg, vectorRead);
+    bvm.map(vectorArg, vectorRead);
+  }
+
+  // 4. Register CustomVectorizationHook for yieldOp.
+  SmallVector<Value> results;
+  CustomVectorizationHook vectorizeYield =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
+  };
+  // Append the vectorizeYield hook.
+  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+  hooks.push_back(vectorizeYield);
+
+  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
+  for (Operation &op : block->getOperations()) {
+    VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
+    if (result.status == VectorizationStatus::Failure) {
+      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
+      return failure();
+    }
+    if (result.status == VectorizationStatus::NewOp) {
+      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
+                        << *result.newOp;);
+      bvm.map(op.getResults(), result.newOp->getResults());
+    }
+  }
+
+  // 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
+  if (!results.empty())
+    linalgOp->replaceAllUsesWith(results);
+  return success();
+}
+
+/// Detect whether `r` exactly computes a floating-point or integer
+/// multiply-accumulate.
 static bool hasMultiplyAddBody(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
@@ -65,6 +336,7 @@ static bool hasMultiplyAddBody(Region &r) {
          pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
 }
 
+/// Detect whether the LinalgOp `op` is a contraction.
 // TODO: Should be Tablegen'd from a single source that generates the op itself.
 static LogicalResult isContraction(Operation *op) {
   // TODO: interface for named ops.
@@ -84,6 +356,7 @@ static LogicalResult isContraction(Operation *op) {
       hasMultiplyAddBody(genericOp.region()));
 }
 
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
 static bool hasOnlyScalarElementwiseOp(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
@@ -119,171 +392,6 @@ static bool isElementwise(Operation *op) {
   return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }
 
-static VectorType extractVectorTypeFromShapedValue(Value v) {
-  auto st = v.getType().cast<ShapedType>();
-  if (st.isa<MemRefType>() && st.getShape().empty())
-    return VectorType();
-  return VectorType::get(st.getShape(), st.getElementType());
-}
-
-static Value transferReadVector(OpBuilder &builder, Value source) {
-  edsc::ScopedContext scope(builder);
-  auto shapedType = source.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
-    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
-    return vector_transfer_read(vectorType, source, indices);
-  }
-  return std_load(source);
-}
-
-static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
-  edsc::ScopedContext scope(builder);
-  Operation *write;
-  auto shapedType = dest.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
-    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
-    if (vectorType != value.getType())
-      value = vector_broadcast(vectorType, value);
-    write = vector_transfer_write(value, dest, indices);
-  } else {
-    write = std_store(value, dest);
-  }
-  if (!write->getResults().empty())
-    return write->getResult(0);
-  return Value();
-}
-
-namespace {
-// Transforms scalar operations into their vectorized counterparts,
-// while using the provided generic op to map:
-//   * Its arguments to transfer reads from the views of the generic op.
-//   * linalg.yield ops to transfer writes to the views of the generic op.
-class GenericVectorizer {
-public:
-  GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
-      : builder(builder), generic(generic) {}
-
-  // Takes a scalar operation and builds its vectorized counterpart or
-  // counterparts using the underlying builder.
-  // If operands of the scalar operation are referring to previously vectorized
-  // operations, then in their vectorized form these operands will be referring
-  // to previous vectorization results.
-  void vectorize(Operation &scalarOp) {
-    auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
-    if (yieldOp) {
-      for (auto outputs : llvm::enumerate(yieldOp.values())) {
-        Value vectorValue = vectorize(outputs.value());
-        Value result = transferWriteVector(builder, vectorValue,
-                                           generic.getOutput(outputs.index()));
-        if (result)
-          results.push_back(result);
-      }
-      return;
-    }
-    Operation *vectorOp = uncachedVectorize(scalarOp);
-    assert(scalarOp.getNumResults() == vectorOp->getNumResults());
-    for (auto result :
-         llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
-      valueCache[std::get<0>(result)] = std::get<1>(result);
-    }
-  }
-
-  llvm::ArrayRef<Value> getResults() { return results; }
-
-private:
-  // Transforms a scalar value into its vectorized counterpart, recursively
-  // vectorizing operations as necessary using the underlying builder.
-  // Keeps track of previously vectorized values and reuses vectorization
-  // results if these values come up again.
-  Value vectorize(Value scalarValue) {
-    // Don't vectorize values coming from outside the region.
-    if (scalarValue.getParentRegion() != &generic.region())
-      return scalarValue;
-    auto vectorValueIt = valueCache.find(scalarValue);
-    if (vectorValueIt != valueCache.end())
-      return vectorValueIt->second;
-
-    // If the value is from the region but not in the cache it means it is a
-    // block argument.
-    auto scalarArg = scalarValue.cast<BlockArgument>();
-    assert(scalarArg.getOwner() == &generic.region().front());
-    Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber());
-    Value vectorResult = transferReadVector(builder, vectorArg);
-    valueCache[scalarArg] = vectorResult;
-    return vectorResult;
-  }
-
-  // Return the largest shape of all the given values. Return an empty
-  // SmallVector if there are no vector value.
-  static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
-    SmallVector<int64_t, 4> largestShape;
-    int64_t maxSize = 1;
-    for (Value value : values) {
-      auto vecType = value.getType().dyn_cast<VectorType>();
-      if (!vecType)
-        continue;
-      if (maxSize < vecType.getNumElements()) {
-        maxSize = vecType.getNumElements();
-        largestShape.assign(vecType.getShape().begin(),
-                            vecType.getShape().end());
-      }
-    }
-    return largestShape;
-  }
-
-  // If the value's type doesn't have the given shape broadcast it.
-  Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
-    auto vecType = value.getType().dyn_cast<VectorType>();
-    if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
-      return value;
-    auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
-                                                     : value.getType());
-    return builder.create<vector::BroadcastOp>(
-        builder.getInsertionPoint()->getLoc(), newVecType, value);
-  }
-
-  // Takes a scalar operation and builds its vectorized counterpart or
-  // counterparts using underlying builder without involving any caches.
-  Operation *uncachedVectorize(Operation &base_scalarOp) {
-    SmallVector<Value, 4> vectorizedOperands;
-    for (Value operand : base_scalarOp.getOperands()) {
-      vectorizedOperands.push_back(vectorize(operand));
-    }
-    SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
-    for (Value &operand : vectorizedOperands)
-      operand = broadcastIfNeeded(operand, shape);
-    OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
-    state.addAttributes(base_scalarOp.getAttrs());
-    state.addOperands(vectorizedOperands);
-    if (shape.empty()) {
-      state.addTypes(base_scalarOp.getResultTypes());
-    } else {
-      SmallVector<VectorType, 4> vectorizedTypes;
-      for (auto Type : base_scalarOp.getResultTypes())
-        vectorizedTypes.push_back(VectorType::get(shape, Type));
-      state.addTypes(vectorizedTypes);
-    }
-    return builder.createOperation(state);
-  }
-
-  OpBuilder &builder;
-  linalg::GenericOp generic;
-  llvm::DenseMap<Value, Value> valueCache;
-  SmallVector<Value, 8> results;
-};
-} // namespace
-
-// Replaces elementwise linalg.generic ops with their bodies with scalar
-// operations from these bodies promoted to vector operations.
-static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
-  GenericVectorizer vectorizer(builder, op);
-  for (Operation &scalarOp : op.region().front()) {
-    vectorizer.vectorize(scalarOp);
-  }
-  if (!op->getResults().empty())
-    op->replaceAllUsesWith(vectorizer.getResults());
-}
-
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -313,7 +421,7 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     // Vectorize fill as a vector.broadcast.
     LLVM_DEBUG(dbgs() << dbgPref
                       << "Rewrite linalg.fill as vector.broadcast: " << *op);
-    transferWriteVector(builder, fillOp.value(), fillOp.output());
+    buildVectorWrite(builder, fillOp.value(), fillOp.output());
     return;
   }
   if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
@@ -322,17 +430,21 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                       << "Rewrite linalg.copy as vector.transfer_read + "
                          "vector.transfer_write: "
                       << *op);
-    Value vector = transferReadVector(builder, copyOp.input());
-    transferWriteVector(builder, vector, copyOp.output());
+    Value vector = buildVectorRead(builder, copyOp.input());
+    buildVectorWrite(builder, vector, copyOp.output());
     return;
   }
 
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  Location loc = linalgOp.getLoc();
+
   if (isElementwise(op)) {
     LLVM_DEBUG(dbgs() << dbgPref
-                      << "Rewrite linalg op as vector.transfer_read + "
-                         "vector_op + vector.transfer_write: "
-                      << *op);
-    return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
+                      << "Rewrite linalg op as vector.transfer_read + " << *op);
+    auto status = vectorizeAsLinalgGeneric(builder, linalgOp);
+    assert(succeeded(status) &&
+           "Unexpected vectorization failed despite preconditions");
+    return;
   }
 
   assert(succeeded(isContraction(op)) && "Expected contraction");
@@ -341,15 +453,28 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   // TODO: interface.
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
-  auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value a = transferReadVector(builder, linalgOp.getInput(0));
-  Value b = transferReadVector(builder, linalgOp.getInput(1));
-  Value c = transferReadVector(builder, linalgOp.getOutput(0));
-  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
-                              linalgOp.iterator_types());
-  Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
-  if (writeResult)
-    linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
+  // Special function that describes how to vectorize the multiplication op in a
+  // linalg contraction.
+  CustomVectorizationHook vectorizeContraction =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    if (!isa<MulIOp, MulFOp>(op))
+      return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    auto outShape = linalgOp.getOutputShapedType(0).getShape();
+    auto vType = outShape.empty()
+                     ? op->getResult(0).getType()
+                     : VectorType::get(outShape, op->getResult(0).getType());
+    auto zero =
+        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
+    Operation *contract = builder.create<vector::ContractionOp>(
+        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
+        linalgOp.indexing_maps(), linalgOp.iterator_types());
+    return VectorizationResult{VectorizationStatus::NewOp, contract};
+  };
+  auto status =
+      vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
+  assert(succeeded(status) &&
+         "Unexpected vectorization failed despite preconditions");
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 2133aad70bd0..aa249542a07d 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -183,13 +183,13 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
 //   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
 //   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
 //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
+//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
 //       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
 //       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
 //       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -267,13 +267,13 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
 //   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
 //   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
 //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
 //       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
 //       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
 //       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -307,10 +307,15 @@ func @matmul_tensors(
 // CHECK-LABEL: func @matmul_tensors
 //  CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
 //  CHECK-SAME:  %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
-//       CHECK:   %[[C0:.*]] = constant 0 : index
-//       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
-//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
-//       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
-//       CHECK:   %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
+//   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+//   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+//   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+//
+// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
+// a later canonicalization fuses the add into vector.contract.
+//       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+//       CHECK:   %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
+//       CHECK:   %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
 //       CHECK:   return %[[W]] : tensor<8x12xf32>


        


More information about the Mlir-commits mailing list