[Mlir-commits] [mlir] f56933b - [mlir][vector] NFC move vector unroll/distribute patterns to their own file

Thomas Raoux llvmlistbot at llvm.org
Fri Dec 10 14:00:30 PST 2021


Author: Thomas Raoux
Date: 2021-12-10T14:00:13-08:00
New Revision: f56933b2631c1258b1159eb25bf0dde82ce61c1b

URL: https://github.com/llvm/llvm-project/commit/f56933b2631c1258b1159eb25bf0dde82ce61c1b
DIFF: https://github.com/llvm/llvm-project/commit/f56933b2631c1258b1159eb25bf0dde82ce61c1b.diff

LOG: [mlir][vector] NFC move vector unroll/distribute patterns to their own file

Differential Revision: https://reviews.llvm.org/D115548

Added: 
    mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp

Modified: 
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 8f01eda3de4f9..143c6c7d688d1 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRVector
   VectorTransferOpTransforms.cpp
   VectorTransferPermutationMapRewritePatterns.cpp
   VectorTransforms.cpp
+  VectorUnrollDistribute.cpp
   VectorUtils.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 61364758c641c..30335b70f4a22 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -135,83 +135,6 @@ static Value reshapeStore(Location loc, Value val, Value result,
   return result;
 }
 
-// Clones `op` into a new operation that takes `operands` and returns
-// `resultTypes`.
-static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
-                                              Operation *op,
-                                              ArrayRef<Value> operands,
-                                              ArrayRef<Type> resultTypes) {
-  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
-                     op->getAttrs());
-  return builder.createOperation(res);
-}
-
-/// Return the target shape for unrolling for the given `op`. Return llvm::None
-/// if the op shouldn't be or cannot be unrolled.
-static Optional<SmallVector<int64_t, 4>>
-getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
-  if (options.filterConstraint && failed(options.filterConstraint(op)))
-    return llvm::None;
-  assert(options.nativeShape &&
-         "vector unrolling expects the native shape or native"
-         "shape call back function to be set");
-  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
-  if (!unrollableVectorOp)
-    return llvm::None;
-  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
-  if (!maybeUnrollShape)
-    return llvm::None;
-  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
-  if (!targetShape)
-    return llvm::None;
-  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
-  if (!maybeShapeRatio ||
-      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
-    return llvm::None;
-  return targetShape;
-}
-
-/// During unrolling from `originalShape` to `targetShape` return the offset for
-/// the slice `index`.
-static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
-                                               ArrayRef<int64_t> targetShape,
-                                               int64_t index) {
-  SmallVector<int64_t, 4> dstSliceStrides =
-      computeStrides(originalShape, targetShape);
-  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
-  SmallVector<int64_t, 4> elementOffsets =
-      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
-  return elementOffsets;
-}
-
-/// Compute the indices of the slice `index` for a transfer op.
-static SmallVector<Value>
-sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
-                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
-                     AffineMap permutationMap, Location loc,
-                     OpBuilder &builder) {
-  MLIRContext *ctx = builder.getContext();
-  auto isBroadcast = [](AffineExpr expr) {
-    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
-      return constExpr.getValue() == 0;
-    return false;
-  };
-  SmallVector<int64_t, 4> elementOffsets =
-      getVectorOffset(originalShape, targetShape, index);
-  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
-  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
-  for (auto dim : llvm::enumerate(permutationMap.getResults())) {
-    if (isBroadcast(dim.value()))
-      continue;
-    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
-    auto expr = getAffineDimExpr(0, builder.getContext()) +
-                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
-    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
-  }
-  return slicedIndices;
-}
-
 template <typename IntType>
 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(llvm::map_range(
@@ -221,275 +144,6 @@ static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
 
 namespace {
 
-struct UnrollTransferReadPattern
-    : public OpRewritePattern<vector::TransferReadOp> {
-  UnrollTransferReadPattern(MLIRContext *context,
-                            const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
-        options(options) {}
-  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (readOp.getTransferRank() == 0)
-      return failure();
-    if (readOp.mask())
-      return failure();
-    auto targetShape = getTargetShape(options, readOp);
-    if (!targetShape)
-      return failure();
-    auto sourceVectorType = readOp.getVectorType();
-    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
-    Location loc = readOp.getLoc();
-    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
-    // Prepare the result vector.
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
-    auto targetType =
-        VectorType::get(*targetShape, sourceVectorType.getElementType());
-    SmallVector<Value, 4> originalIndices(readOp.indices().begin(),
-                                          readOp.indices().end());
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
-                               readOp.permutation_map(), loc, rewriter);
-      auto slicedRead = rewriter.create<vector::TransferReadOp>(
-          loc, targetType, readOp.source(), indices,
-          readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
-          readOp.in_boundsAttr());
-
-      SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, slicedRead, result, elementOffsets, strides);
-    }
-    rewriter.replaceOp(readOp, result);
-    return success();
-  }
-
-private:
-  vector::UnrollVectorOptions options;
-};
-
-struct UnrollTransferWritePattern
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  UnrollTransferWritePattern(MLIRContext *context,
-                             const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
-        options(options) {}
-  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (writeOp.getTransferRank() == 0)
-      return failure();
-
-    if (writeOp.mask())
-      return failure();
-    auto targetShape = getTargetShape(options, writeOp);
-    if (!targetShape)
-      return failure();
-    auto sourceVectorType = writeOp.getVectorType();
-    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
-    Location loc = writeOp.getLoc();
-    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
-    SmallVector<Value, 4> originalIndices(writeOp.indices().begin(),
-                                          writeOp.indices().end());
-    Value resultTensor;
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
-      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, writeOp.vector(), elementOffsets, *targetShape, strides);
-
-      SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
-                               writeOp.permutation_map(), loc, rewriter);
-      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
-          loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
-          indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
-      // For the tensor case, update the destination for the next transfer write.
-      if (!slicedWrite->getResults().empty())
-        resultTensor = slicedWrite->getResult(0);
-    }
-    if (resultTensor)
-      rewriter.replaceOp(writeOp, resultTensor);
-    else
-      rewriter.eraseOp(writeOp);
-    return success();
-  }
-
-private:
-  vector::UnrollVectorOptions options;
-};
-
-struct UnrollContractionPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-  struct OffsetMapInfo {
-    static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
-
-    static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
-
-    static unsigned getHashValue(const SmallVector<int64_t> &v) {
-      return static_cast<unsigned>(
-          llvm::hash_combine_range(v.begin(), v.end()));
-    }
-
-    static bool isEqual(const SmallVector<int64_t> &lhs,
-                        const SmallVector<int64_t> &rhs) {
-      return lhs == rhs;
-    }
-  };
-  UnrollContractionPattern(MLIRContext *context,
-                           const vector::UnrollVectorOptions &options)
-      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
-        options(options) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    auto targetShape = getTargetShape(options, contractOp);
-    if (!targetShape)
-      return failure();
-    auto dstVecType = contractOp.getResultType().cast<VectorType>();
-    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
-    Location loc = contractOp.getLoc();
-    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
-    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
-    llvm::MapVector<
-        SmallVector<int64_t>, Value,
-        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
-        accCache;
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
-      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
-
-      // Helper to compute the new shape of each operand and extract the slice.
-      auto extractOperand = [&](unsigned index, Value operand,
-                                AffineMap permutationMap,
-                                ArrayRef<int64_t> operandOffets) {
-        SmallVector<int64_t> operandShape = applyPermutationMap(
-            permutationMap, ArrayRef<int64_t>(*targetShape));
-        SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
-        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, operand, operandOffets, operandShape, operandStrides);
-      };
-
-      // Extract the new lhs operand.
-      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
-      SmallVector<int64_t> lhsOffets =
-          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets);
-      // If there is a mask associated to lhs, extract it as well.
-      if (slicesOperands.size() > 3)
-        extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffets);
-
-      // Extract the new rhs operand.
-      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
-      SmallVector<int64_t> rhsOffets =
-          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets);
-      // If there is a mask associated to rhs, extract it as well.
-      if (slicesOperands.size() > 4)
-        extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffets);
-
-      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
-      SmallVector<int64_t> accOffets =
-          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
-      // If a version of the accumulator has already been computed, use it;
-      // otherwise extract the first version from the original operand.
-      auto accIt = accCache.find(accOffets);
-      if (accIt != accCache.end())
-        slicesOperands[2] = accIt->second;
-      else
-        extractOperand(2, contractOp.acc(), accPermutationMap, accOffets);
-
-      SmallVector<int64_t> dstShape =
-          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
-      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
-      Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, contractOp, slicesOperands, targetType);
-
-      SmallVector<int64_t> dstOffets =
-          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
-      // Save the accumulated value until all the loops are unrolled since the
-      // reduction loop keeps updating the accumulator.
-      accCache[dstOffets] = newOp->getResult(0);
-    }
-    // Assemble back the accumulator into a single vector.
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
-    for (const auto &it : accCache) {
-      SmallVector<int64_t> dstStrides(it.first.size(), 1);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, it.second, result, it.first, dstStrides);
-    }
-    rewriter.replaceOp(contractOp, result);
-    return success();
-  }
-
-private:
-  vector::UnrollVectorOptions options;
-};
-
-struct UnrollElementwisePattern : public RewritePattern {
-  UnrollElementwisePattern(MLIRContext *context,
-                           const vector::UnrollVectorOptions &options)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
-        options(options) {}
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
-      return failure();
-    auto targetShape = getTargetShape(options, op);
-    if (!targetShape)
-      return failure();
-    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
-    SmallVector<int64_t, 4> originalSize =
-        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
-    Location loc = op->getLoc();
-    // Prepare the result vector.
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
-    VectorType newVecType =
-        VectorType::get(*targetShape, dstVecType.getElementType());
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
-      SmallVector<Value, 4> extractOperands;
-      for (OpOperand &operand : op->getOpOperands()) {
-        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
-        if (!vecType) {
-          extractOperands.push_back(operand.get());
-          continue;
-        }
-        extractOperands.push_back(
-            rewriter.create<vector::ExtractStridedSliceOp>(
-                loc, operand.get(), offsets, *targetShape, strides));
-      }
-      Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, newVecType);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, newOp->getResult(0), result, offsets, strides);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
-private:
-  vector::UnrollVectorOptions options;
-};
 
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 //
@@ -2599,198 +2253,6 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
   return ops;
 }
 
-/// Canonicalize an extract_map using the result of a pointwise operation.
-/// Transforms:
-/// %v = arith.addf %a, %b : vector<32xf32>
-/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
-/// to:
-/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
-/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
-/// %dv = arith.addf %da, %db : vector<1xf32>
-struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
-  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
-                                PatternRewriter &rewriter) const override {
-    Operation *definedOp = extract.vector().getDefiningOp();
-    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
-        definedOp->getNumResults() != 1)
-      return failure();
-    Location loc = extract.getLoc();
-    SmallVector<Value, 4> extractOperands;
-    for (OpOperand &operand : definedOp->getOpOperands()) {
-      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
-      if (!vecType) {
-        extractOperands.push_back(operand.get());
-        continue;
-      }
-      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
-          loc,
-          VectorType::get(extract.getResultType().getShape(),
-                          vecType.getElementType()),
-          operand.get(), extract.ids()));
-    }
-    Operation *newOp = cloneOpWithOperandsAndTypes(
-        rewriter, loc, definedOp, extractOperands, extract.getResultType());
-    rewriter.replaceOp(extract, newOp->getResult(0));
-    return success();
-  }
-};
-
-/// Canonicalize an extract_map using the result of a contract operation.
-/// This propagates the extract_map to its operands.
-struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
-  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
-                                PatternRewriter &rewriter) const override {
-    Operation *definedOp = extract.vector().getDefiningOp();
-    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
-    if (!contract)
-      return failure();
-    Location loc = contract.getLoc();
-    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
-    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
-    // Create a map of the dimensions distributed based on the acc affine map.
-    // Only parallel dimensions are being distributed; reduction dimensions are
-    // untouched.
-    DenseMap<int64_t, int64_t> map;
-    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
-      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
-    SmallVector<Value, 4> extractOperands;
-    for (auto it : llvm::enumerate(contract.getIndexingMaps())) {
-      // For each operand, calculate the new vector type after distribution.
-      Value operand = contract->getOperand(it.index());
-      auto vecType = operand.getType().cast<VectorType>();
-      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
-                                        vecType.getShape().end());
-      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
-        unsigned dim = it.value().getDimPosition(i);
-        auto distributedDim = map.find(dim);
-        // If the dimension is not in the map it means it is a reduction and
-        // doesn't get distributed.
-        if (distributedDim == map.end())
-          continue;
-        operandShape[i] = distributedDim->second;
-      }
-      VectorType newVecType =
-          VectorType::get(operandShape, vecType.getElementType());
-      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
-          loc, newVecType, operand, extract.ids()));
-    }
-    Operation *newOp =
-        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
-                                    extract.getResult().getType());
-    rewriter.replaceOp(extract, newOp->getResult(0));
-    return success();
-  }
-};
-
-/// Converts TransferRead op used by ExtractMap op into a smaller dimension
-/// TransferRead.
-/// Example:
-/// ```
-/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
-///   memref<64x64x64xf32>, vector<64x4x32xf32>
-/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
-/// ```
-/// to:
-/// ```
-/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
-/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
-///   memref<64x64x64xf32>, vector<2x4x1xf32>
-/// ```
-struct TransferReadExtractPattern
-    : public OpRewritePattern<vector::TransferReadOp> {
-  TransferReadExtractPattern(MLIRContext *context)
-      : OpRewritePattern<vector::TransferReadOp>(context) {}
-  LogicalResult matchAndRewrite(vector::TransferReadOp read,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (read.getTransferRank() == 0)
-      return failure();
-
-    if (!read.getResult().hasOneUse())
-      return failure();
-    auto extract =
-        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
-    if (!extract)
-      return failure();
-    if (read.mask())
-      return failure();
-
-    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
-    AffineMap indexMap = extract.map().compose(read.permutation_map());
-    unsigned idCount = 0;
-    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
-    for (auto it :
-         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
-      AffineExpr d0, d1;
-      bindDims(read.getContext(), d0, d1);
-      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
-      if (!indexExpr)
-        continue;
-      unsigned indexPos = indexExpr.getPosition();
-      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      auto scale = getAffineConstantExpr(
-          extract.getResultType().getDimSize(vectorPos), read.getContext());
-      indices[indexPos] = makeComposedAffineApply(
-          rewriter, read.getLoc(), d0 + scale * d1,
-          {indices[indexPos], extract.ids()[idCount++]});
-    }
-    Value newRead = lb.create<vector::TransferReadOp>(
-        extract.getType(), read.source(), indices, read.permutation_mapAttr(),
-        read.padding(), read.mask(), read.in_boundsAttr());
-    Value dest = lb.create<arith::ConstantOp>(
-        read.getType(), rewriter.getZeroAttr(read.getType()));
-    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
-    rewriter.replaceOp(read, newRead);
-    return success();
-  }
-};
-
-struct TransferWriteInsertPattern
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  TransferWriteInsertPattern(MLIRContext *context)
-      : OpRewritePattern<vector::TransferWriteOp>(context) {}
-  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (write.getTransferRank() == 0)
-      return failure();
-
-    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
-    if (!insert)
-      return failure();
-    if (write.mask())
-      return failure();
-    SmallVector<Value, 4> indices(write.indices().begin(),
-                                  write.indices().end());
-    AffineMap indexMap = insert.map().compose(write.permutation_map());
-    unsigned idCount = 0;
-    Location loc = write.getLoc();
-    for (auto it :
-         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
-      AffineExpr d0, d1;
-      bindDims(write.getContext(), d0, d1);
-      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
-      if (!indexExpr)
-        continue;
-      unsigned indexPos = indexExpr.getPosition();
-      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      auto scale = getAffineConstantExpr(
-          insert.getSourceVectorType().getDimSize(vectorPos),
-          write.getContext());
-      indices[indexPos] =
-          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
-                                  {indices[indexPos], insert.ids()[idCount++]});
-    }
-    rewriter.create<vector::TransferWriteOp>(
-        loc, insert.vector(), write.source(), indices,
-        write.permutation_mapAttr(), write.in_boundsAttr());
-    rewriter.eraseOp(write);
-    return success();
-  }
-};
-
 /// Progressive lowering of transfer_read. This pattern supports lowering of
 /// `vector.transfer_read` to a combination of `vector.load` and
 /// `vector.broadcast` if all of the following hold:
@@ -3470,13 +2932,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
       patterns.getContext(), indexOptimizations);
 }
 
-void mlir::vector::populatePropagateVectorDistributionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
-               TransferReadExtractPattern, TransferWriteInsertPattern>(
-      patterns.getContext());
-}
-
 void mlir::vector::populateShapeCastFoldingPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ShapeCastOpFolder>(patterns.getContext());
@@ -3527,13 +2982,6 @@ void mlir::vector::populateVectorReductionToContractPatterns(
                CombineContractTranspose>(patterns.getContext());
 }
 
-void mlir::vector::populateVectorUnrollPatterns(
-    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
-  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-               UnrollContractionPattern, UnrollElementwisePattern>(
-      patterns.getContext(), options);
-}
-
 void mlir::vector::
     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
         RewritePatternSet &patterns) {

diff  --git a/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
new file mode 100644
index 0000000000000..4c31164b433e2
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
@@ -0,0 +1,581 @@
+//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to do vector unrolling and vector distribution.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "vector-unrolling"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// During unrolling from `originalShape` to `targetShape` return the offset for
+/// the slice `index`.
+static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
+                                               ArrayRef<int64_t> targetShape,
+                                               int64_t index) {
+  SmallVector<int64_t, 4> dstSliceStrides =
+      computeStrides(originalShape, targetShape);
+  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
+  SmallVector<int64_t, 4> elementOffsets =
+      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
+  return elementOffsets;
+}
+
+/// Compute the indices of the slice `index` for a transfer op.
+static SmallVector<Value>
+sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
+                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
+                     AffineMap permutationMap, Location loc,
+                     OpBuilder &builder) {
+  MLIRContext *ctx = builder.getContext();
+  auto isBroadcast = [](AffineExpr expr) {
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+      return constExpr.getValue() == 0;
+    return false;
+  };
+  SmallVector<int64_t, 4> elementOffsets =
+      getVectorOffset(originalShape, targetShape, index);
+  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
+  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
+  for (auto dim : llvm::enumerate(permutationMap.getResults())) {
+    if (isBroadcast(dim.value()))
+      continue;
+    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
+    auto expr = getAffineDimExpr(0, builder.getContext()) +
+                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
+    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
+  }
+  return slicedIndices;
+}
+
+// Clones `op` into a new operation that takes `operands` and returns
+// `resultTypes`.
+static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
+                                              Operation *op,
+                                              ArrayRef<Value> operands,
+                                              ArrayRef<Type> resultTypes) {
+  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
+                     op->getAttrs());
+  return builder.createOperation(res);
+}
+
+/// Return the target shape for unrolling for the given `op`. Return llvm::None
+/// if the op shouldn't be or cannot be unrolled.
+static Optional<SmallVector<int64_t, 4>>
+getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
+  if (options.filterConstraint && failed(options.filterConstraint(op)))
+    return llvm::None;
+  assert(options.nativeShape &&
+         "vector unrolling expects the native shape or native"
+         "shape call back function to be set");
+  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
+  if (!unrollableVectorOp)
+    return llvm::None;
+  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+  if (!maybeUnrollShape)
+    return llvm::None;
+  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+  if (!targetShape)
+    return llvm::None;
+  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
+  if (!maybeShapeRatio ||
+      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
+    return llvm::None;
+  return targetShape;
+}
+
+namespace {
+
+struct UnrollTransferReadPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  UnrollTransferReadPattern(MLIRContext *context,
+                            const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
+        options(options) {}
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (readOp.getTransferRank() == 0)
+      return failure();
+    if (readOp.mask())
+      return failure();
+    auto targetShape = getTargetShape(options, readOp);
+    if (!targetShape)
+      return failure();
+    auto sourceVectorType = readOp.getVectorType();
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    Location loc = readOp.getLoc();
+    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    // Prepare the result vector.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
+    auto targetType =
+        VectorType::get(*targetShape, sourceVectorType.getElementType());
+    SmallVector<Value, 4> originalIndices(readOp.indices().begin(),
+                                          readOp.indices().end());
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<Value, 4> indices =
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+                               readOp.permutation_map(), loc, rewriter);
+      auto slicedRead = rewriter.create<vector::TransferReadOp>(
+          loc, targetType, readOp.source(), indices,
+          readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
+          readOp.in_boundsAttr());
+
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, slicedRead, result, elementOffsets, strides);
+    }
+    rewriter.replaceOp(readOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollTransferWritePattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  UnrollTransferWritePattern(MLIRContext *context,
+                             const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
+        options(options) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
+    if (writeOp.mask())
+      return failure();
+    auto targetShape = getTargetShape(options, writeOp);
+    if (!targetShape)
+      return failure();
+    auto sourceVectorType = writeOp.getVectorType();
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    Location loc = writeOp.getLoc();
+    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    SmallVector<Value, 4> originalIndices(writeOp.indices().begin(),
+                                          writeOp.indices().end());
+    Value resultTensor;
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, writeOp.vector(), elementOffsets, *targetShape, strides);
+
+      SmallVector<Value, 4> indices =
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+                               writeOp.permutation_map(), loc, rewriter);
+      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
+          indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
+      // For the tensor case, update the destination for the next transfer write.
+      if (!slicedWrite->getResults().empty())
+        resultTensor = slicedWrite->getResult(0);
+    }
+    if (resultTensor)
+      rewriter.replaceOp(writeOp, resultTensor);
+    else
+      rewriter.eraseOp(writeOp);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollContractionPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+  struct OffsetMapInfo {
+    static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
+
+    static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
+
+    static unsigned getHashValue(const SmallVector<int64_t> &v) {
+      return static_cast<unsigned>(
+          llvm::hash_combine_range(v.begin(), v.end()));
+    }
+
+    static bool isEqual(const SmallVector<int64_t> &lhs,
+                        const SmallVector<int64_t> &rhs) {
+      return lhs == rhs;
+    }
+  };
+  UnrollContractionPattern(MLIRContext *context,
+                           const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, contractOp);
+    if (!targetShape)
+      return failure();
+    auto dstVecType = contractOp.getResultType().cast<VectorType>();
+    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    Location loc = contractOp.getLoc();
+    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
+    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
+    llvm::MapVector<
+        SmallVector<int64_t>, Value,
+        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
+        accCache;
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
+
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](unsigned index, Value operand,
+                                AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffets) {
+        SmallVector<int64_t> operandShape = applyPermutationMap(
+            permutationMap, ArrayRef<int64_t>(*targetShape));
+        SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
+        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffets, operandShape, operandStrides);
+      };
+
+      // Extract the new lhs operand.
+      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
+      SmallVector<int64_t> lhsOffets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets);
+      // If there is a mask associated to lhs, extract it as well.
+      if (slicesOperands.size() > 3)
+        extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffets);
+
+      // Extract the new rhs operand.
+      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
+      SmallVector<int64_t> rhsOffets =
+          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets);
+      // If there is a mask associated to rhs, extract it as well.
+      if (slicesOperands.size() > 4)
+        extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffets);
+
+      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
+      SmallVector<int64_t> accOffets =
+          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+      // If a version of the accumulator has already been computed, use it;
+      // otherwise extract the first version from the original operand.
+      auto accIt = accCache.find(accOffets);
+      if (accIt != accCache.end())
+        slicesOperands[2] = accIt->second;
+      else
+        extractOperand(2, contractOp.acc(), accPermutationMap, accOffets);
+
+      SmallVector<int64_t> dstShape =
+          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
+      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
+      Operation *newOp = cloneOpWithOperandsAndTypes(
+          rewriter, loc, contractOp, slicesOperands, targetType);
+
+      SmallVector<int64_t> dstOffets =
+          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
+      // Save the accumulated value until all the loops are unrolled since the
+      // reduction loop keeps updating the accumulator.
+      accCache[dstOffets] = newOp->getResult(0);
+    }
+    // Assemble back the accumulator into a single vector.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
+    for (const auto &it : accCache) {
+      SmallVector<int64_t> dstStrides(it.first.size(), 1);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, it.second, result, it.first, dstStrides);
+    }
+    rewriter.replaceOp(contractOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollElementwisePattern : public RewritePattern {
+  UnrollElementwisePattern(MLIRContext *context,
+                           const vector::UnrollVectorOptions &options)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
+        options(options) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+    auto targetShape = getTargetShape(options, op);
+    if (!targetShape)
+      return failure();
+    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
+    SmallVector<int64_t, 4> originalSize =
+        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    Location loc = op->getLoc();
+    // Prepare the result vector.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
+    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    VectorType newVecType =
+        VectorType::get(*targetShape, dstVecType.getElementType());
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<Value, 4> extractOperands;
+      for (OpOperand &operand : op->getOpOperands()) {
+        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
+        if (!vecType) {
+          extractOperands.push_back(operand.get());
+          continue;
+        }
+        extractOperands.push_back(
+            rewriter.create<vector::ExtractStridedSliceOp>(
+                loc, operand.get(), offsets, *targetShape, strides));
+      }
+      Operation *newOp = cloneOpWithOperandsAndTypes(
+          rewriter, loc, op, extractOperands, newVecType);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Canonicalize an extract_map using the result of a pointwise operation.
+/// Transforms:
+/// %v = arith.addf %a, %b : vector<32xf32>
+/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
+/// to:
+/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
+/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
+/// %dv = arith.addf %da, %db : vector<1xf32>
+struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    Operation *definedOp = extract.vector().getDefiningOp();
+    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
+        definedOp->getNumResults() != 1)
+      return failure();
+    Location loc = extract.getLoc();
+    SmallVector<Value, 4> extractOperands;
+    for (OpOperand &operand : definedOp->getOpOperands()) {
+      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
+      if (!vecType) {
+        extractOperands.push_back(operand.get());
+        continue;
+      }
+      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+          loc,
+          VectorType::get(extract.getResultType().getShape(),
+                          vecType.getElementType()),
+          operand.get(), extract.ids()));
+    }
+    Operation *newOp = cloneOpWithOperandsAndTypes(
+        rewriter, loc, definedOp, extractOperands, extract.getResultType());
+    rewriter.replaceOp(extract, newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Canonicalize an extract_map using the result of a contract operation.
+/// This propagates the extract_map to its operands.
+struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    Operation *definedOp = extract.vector().getDefiningOp();
+    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
+    if (!contract)
+      return failure();
+    Location loc = contract.getLoc();
+    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
+    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
+    // Create a map of the dimensions distributed based on the acc affine map.
+    // Only parallel dimensions are being distributed; reduction dimensions are
+    // untouched.
+    DenseMap<int64_t, int64_t> map;
+    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
+      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
+    SmallVector<Value, 4> extractOperands;
+    for (auto it : llvm::enumerate(contract.getIndexingMaps())) {
+      // For each operand, calculate the new vector type after distribution.
+      Value operand = contract->getOperand(it.index());
+      auto vecType = operand.getType().cast<VectorType>();
+      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
+                                        vecType.getShape().end());
+      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
+        unsigned dim = it.value().getDimPosition(i);
+        auto distributedDim = map.find(dim);
+        // If the dimension is not in the map it means it is a reduction and
+        // doesn't get distributed.
+        if (distributedDim == map.end())
+          continue;
+        operandShape[i] = distributedDim->second;
+      }
+      VectorType newVecType =
+          VectorType::get(operandShape, vecType.getElementType());
+      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+          loc, newVecType, operand, extract.ids()));
+    }
+    Operation *newOp =
+        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
+                                    extract.getResult().getType());
+    rewriter.replaceOp(extract, newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Converts TransferRead op used by ExtractMap op into a smaller dimension
+/// TransferRead.
+/// Example:
+/// ```
+/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
+///   memref<64x64x64xf32>, vector<64x4x32xf32>
+/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
+/// ```
+/// to:
+/// ```
+/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
+/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
+///   memref<64x64x64xf32>, vector<2x4x1xf32>
+/// ```
+struct TransferReadExtractPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  TransferReadExtractPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (read.getTransferRank() == 0)
+      return failure();
+
+    if (!read.getResult().hasOneUse())
+      return failure();
+    auto extract =
+        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
+    if (!extract)
+      return failure();
+    if (read.mask())
+      return failure();
+
+    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
+    AffineMap indexMap = extract.map().compose(read.permutation_map());
+    unsigned idCount = 0;
+    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
+    for (auto it :
+         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
+      AffineExpr d0, d1;
+      bindDims(read.getContext(), d0, d1);
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+      auto scale = getAffineConstantExpr(
+          extract.getResultType().getDimSize(vectorPos), read.getContext());
+      indices[indexPos] = makeComposedAffineApply(
+          rewriter, read.getLoc(), d0 + scale * d1,
+          {indices[indexPos], extract.ids()[idCount++]});
+    }
+    Value newRead = lb.create<vector::TransferReadOp>(
+        extract.getType(), read.source(), indices, read.permutation_mapAttr(),
+        read.padding(), read.mask(), read.in_boundsAttr());
+    Value dest = lb.create<arith::ConstantOp>(
+        read.getType(), rewriter.getZeroAttr(read.getType()));
+    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
+    rewriter.replaceOp(read, newRead);
+    return success();
+  }
+};
+
+struct TransferWriteInsertPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteInsertPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (write.getTransferRank() == 0)
+      return failure();
+
+    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
+    if (!insert)
+      return failure();
+    if (write.mask())
+      return failure();
+    SmallVector<Value, 4> indices(write.indices().begin(),
+                                  write.indices().end());
+    AffineMap indexMap = insert.map().compose(write.permutation_map());
+    unsigned idCount = 0;
+    Location loc = write.getLoc();
+    for (auto it :
+         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
+      AffineExpr d0, d1;
+      bindDims(write.getContext(), d0, d1);
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+      auto scale = getAffineConstantExpr(
+          insert.getSourceVectorType().getDimSize(vectorPos),
+          write.getContext());
+      indices[indexPos] =
+          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
+                                  {indices[indexPos], insert.ids()[idCount++]});
+    }
+    rewriter.create<vector::TransferWriteOp>(
+        loc, insert.vector(), write.source(), indices,
+        write.permutation_mapAttr(), write.in_boundsAttr());
+    rewriter.eraseOp(write);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorUnrollPatterns(
+    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
+  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+               UnrollContractionPattern, UnrollElementwisePattern>(
+      patterns.getContext(), options);
+}
+
+void mlir::vector::populatePropagateVectorDistributionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
+               TransferReadExtractPattern, TransferWriteInsertPattern>(
+      patterns.getContext());
+}


        

