[Mlir-commits] [mlir] 8e4c806 - [mlir][Linalg] NFC - Add additional control to lower vector.shape_cast ops

Nicolas Vasilache llvmlistbot at llvm.org
Wed Oct 27 01:14:20 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-27T08:12:57Z
New Revision: 8e4c806ed5a481e4d2163c8330f3c3c024d61a36

URL: https://github.com/llvm/llvm-project/commit/8e4c806ed5a481e4d2163c8330f3c3c024d61a36
DIFF: https://github.com/llvm/llvm-project/commit/8e4c806ed5a481e4d2163c8330f3c3c024d61a36.diff

LOG: [mlir][Linalg] NFC - Add additional control to lower vector.shape_cast ops

This also moves some code to a new patterns file.

Differential Revision: https://reviews.llvm.org/D112575

Added: 
    mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index cfa38d71c2ba3..27688b5492530 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -877,49 +877,63 @@ struct LinalgEnablingOptions {
 /// Vector lowering options control how ops are lowered down to 1-D and scf.for
 /// form.
 struct LinalgVectorLoweringOptions {
-  /// Maximal transfer rank under which we do not lower further.
-  int64_t maxTransferRank = 1;
-  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
-    maxTransferRank = val;
-    return *this;
-  }
-  /// Vector lowering operations may result in surprising behavior when
-  /// composing multiple codegen strategies and must be enabled explicitly.
-  bool transferLowering = true;
-  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
-    transferLowering = val;
-    return *this;
-  }
-  /// Enable lowering of vector.transpose.
-  bool transposeLowering = false;
-  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
-    transposeLowering = val;
+  /// Enable lowering of vector.contract.
+  /// In a progressive lowering of vectors, this would be the 1st step.
+  bool contractionLowering = false;
+  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
+    contractionLowering = val;
     return *this;
   }
   /// Enable lowering of vector.multi_reduce.
+  /// In a progressive lowering of vectors, this would be the 2nd step.
   bool multiReductionLowering = false;
   LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
     multiReductionLowering = val;
     return *this;
   }
-  /// Enable lowering of vector.contract.
-  bool contractionLowering = false;
-  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
-    contractionLowering = val;
-    return *this;
-  }
   /// Trigger full / partial vector.transfer splits.
+  /// In a progressive lowering of vectors, this would be the 3rd step.
   bool transferPartialRewrite = false;
   LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
     transferPartialRewrite = val;
     return *this;
   }
   /// Enable lowering of vector.transfer to scf.
+  /// In a progressive lowering of vectors, this would be the 4th step.
   bool transferToSCFConversion = false;
   LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
     transferToSCFConversion = val;
     return *this;
   }
+  /// Maximal transfer rank under which we do not lower further.
+  int64_t maxTransferRank = 1;
+  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
+    maxTransferRank = val;
+    return *this;
+  }
+  /// Vector lowering operations may result in surprising behavior when
+  /// composing multiple codegen strategies and must be enabled explicitly.
+  /// In a progressive lowering of vectors, this would be the 5th step.
+  bool transferLowering = true;
+  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
+    transferLowering = val;
+    return *this;
+  }
+  /// Enable lowering of vector.shape_cast to insert/extract.
+  /// In a progressive lowering of vectors, this would be the 6th step.
+  bool shapeCastLowering = true;
+  LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) {
+    shapeCastLowering = val;
+    return *this;
+  }
+  /// Enable lowering of vector.transpose.
+  /// In a progressive lowering of vectors, this would be the 7th step.
+  bool transposeLowering = false;
+  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
+    transposeLowering = val;
+    return *this;
+  }
+
   /// Configure the post staged-patterns late vector.transfer to scf
   /// conversion.
   VectorTransferToSCFOptions vectorTransferToSCFOptions;

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index c6f4ba4bc0e59..a29683431344f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -81,12 +81,6 @@ void populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns,
     llvm::Optional<unsigned> maxTransferRank = llvm::None);
 
-/// Collect a set of transfer read/write lowering patterns that simplify the
-/// permutation map (e.g., converting it to a minor identity map) by inserting
-/// broadcasts and transposes.
-void populateVectorTransferPermutationMapLoweringPatterns(
-    RewritePatternSet &patterns);
-
 /// These patterns materialize masks for various vector ops such as transfers.
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool enableIndexOptimizations);

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 47375c56673f8..587f334bc0473 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -169,6 +169,64 @@ void populateVectorContractLoweringPatterns(
 /// transpose/broadcast ops into the contract.
 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
 
+//===----------------------------------------------------------------------===//
+// Vector.transfer patterns.
+//===----------------------------------------------------------------------===//
+/// Collect a set of transfer read/write lowering patterns that simplify the
+/// permutation map (e.g., converting it to a minor identity map) by inserting
+/// broadcasts and transposes. More specifically:
+///
+/// [TransferReadPermutationLowering]
+/// Lower transfer_read op with permutation into a transfer_read with a
+/// permutation map composed of leading zeros followed by a minor identity +
+/// vector.transpose op.
+/// Ex:
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (0, d1)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (d1, 0)
+///     vector.transpose %v, [1, 0]
+///
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
+///     vector.transpose %v, [0, 1, 3, 2, 4]
+/// Note that an alternative is to transform it to linalg.transpose +
+/// vector.transfer_read to do the transpose in memory instead.
+///
+/// [TransferWritePermutationLowering]
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// Ex:
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+///     %tmp = vector.transpose %v, [2, 0, 1]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+///     %tmp = vector.transpose %v, [1, 0]
+///     %v = vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+///
+/// [TransferOpReduceRank]
+/// Lower transfer_read op with broadcast in the leading dimensions into
+/// transfer_read of lower rank + vector.broadcast.
+/// Ex: vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
+///     vector.broadcast %v
+void populateVectorTransferPermutationMapLoweringPatterns(
+    RewritePatternSet &patterns);
+
 /// Collect a set of patterns to reduce the rank of the operands of vector
 /// transfer ops to operate on the largest contiguous vector.
 /// These patterns are useful when lowering to dialects with 1d vector type

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 1aacd4919c8a3..2fd4959884981 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -20,8 +20,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Pass/Pass.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 97831866ffd08..4462bbe2147a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -264,34 +264,44 @@ struct LinalgStrategyLowerVectorsPass
     MLIRContext *context = funcOp.getContext();
     RewritePatternSet patterns(context);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
-    if (options.transferLowering) {
-      vector::populateVectorTransferLoweringPatterns(patterns,
-                                                     options.maxTransferRank);
-    }
-    if (options.transposeLowering) {
-      vector::populateVectorTransposeLoweringPatterns(
-          patterns, options.vectorTransformOptions);
-    }
-    if (options.multiReductionLowering) {
-      vector::populateVectorMultiReductionLoweringPatterns(
-          patterns,
-          options.vectorTransformOptions.vectorMultiReductionLowering);
-    }
+    // In a progressive lowering of vectors, this would be the 1st step.
     if (options.contractionLowering) {
       patterns.add<ContractionOpToOuterProductOpLowering,
                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
           options.vectorTransformOptions, context);
       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
     }
+    // In a progressive lowering of vectors, this would be the 2nd step.
+    if (options.multiReductionLowering) {
+      vector::populateVectorMultiReductionLoweringPatterns(
+          patterns,
+          options.vectorTransformOptions.vectorMultiReductionLowering);
+    }
+    // In a progressive lowering of vectors, this would be the 3rd step.
     if (options.transferPartialRewrite) {
       patterns.add<vector::VectorTransferFullPartialRewriter>(
           context, options.vectorTransformOptions);
     }
+    // In a progressive lowering of vectors, this would be the 4th step.
+    if (options.transferLowering) {
+      vector::populateVectorTransferLoweringPatterns(patterns,
+                                                     options.maxTransferRank);
+    }
+    // In a progressive lowering of vectors, this would be the 5th step.
     if (options.transferToSCFConversion) {
-      populateVectorToSCFConversionPatterns(patterns,
-                                            options.vectorTransferToSCFOptions);
+      populateVectorToSCFConversionPatterns(
+          patterns, options.vectorTransferToSCFOptions.setTargetRank(
+                        options.maxTransferRank));
+    }
+    // In a progressive lowering of vectors, this would be the 6th step.
+    if (options.shapeCastLowering) {
+      vector::populateVectorShapeCastLoweringPatterns(patterns);
+    }
+    // In a progressive lowering of vectors, this would be the 7th step.
+    if (options.transposeLowering) {
+      vector::populateVectorTransposeLoweringPatterns(
+          patterns, options.vectorTransformOptions);
     }
-    vector::populateVectorShapeCastLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   }
 

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index f620a370c8359..abd961626e599 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVector
   VectorMultiDimReductionTransforms.cpp
   VectorOps.cpp
   VectorTransferOpTransforms.cpp
+  VectorTransferPermutationMapRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUtils.cpp
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
new file mode 100644
index 0000000000000..3f5c3127a286c
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
@@ -0,0 +1,260 @@
+//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns for the permutation_map attribute of
+// vector.transfer operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Transpose a vector transfer op's `in_bounds` attribute according to given
+/// indices.
+static ArrayAttr
+transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+                      const SmallVector<unsigned> &permutation) {
+  SmallVector<bool> newInBoundsValues;
+  for (unsigned pos : permutation)
+    newInBoundsValues.push_back(
+        attr.getValue()[pos].cast<BoolAttr>().getValue());
+  return builder.getBoolArrayAttr(newInBoundsValues);
+}
+/// Lower transfer_read op with permutation into a transfer_read with a
+/// permutation map composed of leading zeros followed by a minor identity +
+/// vector.transpose op.
+/// Ex:
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (0, d1)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (d1, 0)
+///     vector.transpose %v, [1, 0]
+///
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
+///     vector.transpose %v, [0, 1, 3, 2, 4]
+/// Note that an alternative is to transform it to linalg.transpose +
+/// vector.transfer_read to do the transpose in memory instead.
+struct TransferReadPermutationLowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.permutation_map();
+    if (map.getNumResults() == 0)
+      return failure();
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+    AffineMap permutationMap =
+        map.getPermutationMap(permutation, op.getContext());
+    if (permutationMap.isIdentity())
+      return failure();
+
+    permutationMap = map.getPermutationMap(permutation, op.getContext());
+    // Calculate the map of the new read by applying the inverse permutation.
+    permutationMap = inversePermutation(permutationMap);
+    AffineMap newMap = permutationMap.compose(map);
+    // Apply the reverse transpose to deduce the type of the transfer_read.
+    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
+    SmallVector<int64_t> newVectorShape(originalShape.size());
+    for (auto pos : llvm::enumerate(permutation)) {
+      newVectorShape[pos.value()] = originalShape[pos.index()];
+    }
+
+    // Transpose mask operand.
+    Value newMask;
+    if (op.mask()) {
+      // Remove unused dims from the permutation map.
+      // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
+      // comp = (d0, d1, d2) -> (d2, 0, d1, 0, d0)
+      auto comp = compressUnusedDims(map);
+      // Get positions of remaining result dims.
+      // E.g.:  (d0, d1, d2) -> (d2, 0, d1, 0, d0)
+      // maskTransposeIndices = [ 2,     1,     0]
+      SmallVector<int64_t> maskTransposeIndices;
+      for (unsigned i = 0; i < comp.getNumResults(); ++i) {
+        if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
+          maskTransposeIndices.push_back(expr.getPosition());
+      }
+
+      newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
+                                                     maskTransposeIndices);
+    }
+
+    // Transpose in_bounds attribute.
+    ArrayAttr newInBounds =
+        op.in_bounds() ? transposeInBoundsAttr(
+                             rewriter, op.in_bounds().getValue(), permutation)
+                       : ArrayAttr();
+
+    // Generate new transfer_read operation.
+    VectorType newReadType =
+        VectorType::get(newVectorShape, op.getVectorType().getElementType());
+    Value newRead = rewriter.create<vector::TransferReadOp>(
+        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
+        op.padding(), newMask, newInBounds);
+
+    // Transpose result of transfer_read.
+    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
+                                                     transposePerm);
+    return success();
+  }
+};
+
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// Ex:
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+///     %tmp = vector.transpose %v, [2, 0, 1]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+///     %tmp = vector.transpose %v, [1, 0]
+///     %v = vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+struct TransferWritePermutationLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.isZeroD())
+      return failure();
+
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.permutation_map();
+    if (map.isMinorIdentity())
+      return failure();
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    // Remove unused dims from the permutation map.
+    // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
+    // comp = (d0, d1, d2) -> (d2, d0, d1)
+    auto comp = compressUnusedDims(map);
+    // Get positions of remaining result dims.
+    SmallVector<int64_t> indices;
+    llvm::transform(comp.getResults(), std::back_inserter(indices),
+                    [](AffineExpr expr) {
+                      return expr.dyn_cast<AffineDimExpr>().getPosition();
+                    });
+
+    // Transpose mask operand.
+    Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
+                                    op.getLoc(), op.mask(), indices)
+                              : Value();
+
+    // Transpose in_bounds attribute.
+    ArrayAttr newInBounds =
+        op.in_bounds() ? transposeInBoundsAttr(
+                             rewriter, op.in_bounds().getValue(), permutation)
+                       : ArrayAttr();
+
+    // Generate new transfer_write operation.
+    Value newVec =
+        rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices);
+    auto newMap = AffineMap::getMinorIdentityMap(
+        map.getNumDims(), map.getNumResults(), rewriter.getContext());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
+        newInBounds);
+
+    return success();
+  }
+};
+
+/// Lower transfer_read op with broadcast in the leading dimensions into
+/// transfer_read of lower rank + vector.broadcast.
+/// Ex: vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
+///     vector.broadcast %v
+struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.permutation_map();
+    unsigned numLeadingBroadcast = 0;
+    for (auto expr : map.getResults()) {
+      auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
+      if (!dimExpr || dimExpr.getValue() != 0)
+        break;
+      numLeadingBroadcast++;
+    }
+    // If there are no leading zeros in the map there is nothing to do.
+    if (numLeadingBroadcast == 0)
+      return failure();
+    VectorType originalVecType = op.getVectorType();
+    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
+    // Calculate new map, vector type and masks without the leading zeros.
+    AffineMap newMap = AffineMap::get(
+        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
+        op.getContext());
+    // Only remove the leading zeros if the rest of the map is a minor identity
+    // with broadcasting. Otherwise we first want to permute the map.
+    if (!newMap.isMinorIdentityWithBroadcasting())
+      return failure();
+
+    // TODO: support zero-dimension vectors natively.  See:
+    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
+    // In the meantime, lower these to a scalar load when they pop up.
+    if (reducedShapeRank == 0) {
+      Value newRead = rewriter.create<memref::LoadOp>(
+          op.getLoc(), originalVecType.getElementType(), op.source(),
+          op.indices());
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
+                                                       newRead);
+      return success();
+    }
+    SmallVector<int64_t> newShape = llvm::to_vector<4>(
+        originalVecType.getShape().take_back(reducedShapeRank));
+    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
+    if (newShape.empty())
+      return failure();
+    VectorType newReadType =
+        VectorType::get(newShape, originalVecType.getElementType());
+    ArrayAttr newInBounds =
+        op.in_bounds()
+            ? rewriter.getArrayAttr(
+                  op.in_boundsAttr().getValue().take_back(reducedShapeRank))
+            : ArrayAttr();
+    Value newRead = rewriter.create<vector::TransferReadOp>(
+        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
+        op.padding(), op.mask(), newInBounds);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
+                                                     newRead);
+    return success();
+  }
+};
+
+void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadPermutationLowering,
+               TransferWritePermutationLowering, TransferOpReduceRank>(
+      patterns.getContext());
+}

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 4c7ef516fd927..efb22b9ab6be6 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2882,240 +2882,6 @@ struct TransferWriteToVectorStoreLowering
   llvm::Optional<unsigned> maxTransferRank;
 };
 
-/// Transpose a vector transfer op's `in_bounds` attribute according to given
-/// indices.
-static ArrayAttr
-transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
-                      const SmallVector<unsigned> &permutation) {
-  SmallVector<bool> newInBoundsValues;
-  for (unsigned pos : permutation)
-    newInBoundsValues.push_back(
-        attr.getValue()[pos].cast<BoolAttr>().getValue());
-  return builder.getBoolArrayAttr(newInBoundsValues);
-}
-
-/// Lower transfer_read op with permutation into a transfer_read with a
-/// permutation map composed of leading zeros followed by a minor identiy +
-/// vector.transpose op.
-/// Ex:
-///     vector.transfer_read ...
-///         permutation_map: (d0, d1, d2) -> (0, d1)
-/// into:
-///     %v = vector.transfer_read ...
-///         permutation_map: (d0, d1, d2) -> (d1, 0)
-///     vector.transpose %v, [1, 0]
-///
-///     vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
-/// into:
-///     %v = vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
-///     vector.transpose %v, [0, 1, 3, 2, 4]
-/// Note that an alternative is to transform it to linalg.transpose +
-/// vector.transfer_read to do the transpose in memory instead.
-struct TransferReadPermutationLowering
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<unsigned> permutation;
-    AffineMap map = op.permutation_map();
-    if (map.getNumResults() == 0)
-      return failure();
-    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
-      return failure();
-    AffineMap permutationMap =
-        map.getPermutationMap(permutation, op.getContext());
-    if (permutationMap.isIdentity())
-      return failure();
-
-    permutationMap = map.getPermutationMap(permutation, op.getContext());
-    // Caluclate the map of the new read by applying the inverse permutation.
-    permutationMap = inversePermutation(permutationMap);
-    AffineMap newMap = permutationMap.compose(map);
-    // Apply the reverse transpose to deduce the type of the transfer_read.
-    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
-    SmallVector<int64_t> newVectorShape(originalShape.size());
-    for (auto pos : llvm::enumerate(permutation)) {
-      newVectorShape[pos.value()] = originalShape[pos.index()];
-    }
-
-    // Transpose mask operand.
-    Value newMask;
-    if (op.mask()) {
-      // Remove unused dims from the permutation map. E.g.:
-      // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
-      // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
-      auto comp = compressUnusedDims(map);
-      // Get positions of remaining result dims.
-      // E.g.:  (d0, d1, d2) -> (d2, 0, d1, 0 d0)
-      // maskTransposeIndices = [ 2,     1,    0]
-      SmallVector<int64_t> maskTransposeIndices;
-      for (unsigned i = 0; i < comp.getNumResults(); ++i) {
-        if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
-          maskTransposeIndices.push_back(expr.getPosition());
-      }
-
-      newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
-                                                     maskTransposeIndices);
-    }
-
-    // Transpose in_bounds attribute.
-    ArrayAttr newInBounds =
-        op.in_bounds() ? transposeInBoundsAttr(
-                             rewriter, op.in_bounds().getValue(), permutation)
-                       : ArrayAttr();
-
-    // Generate new transfer_read operation.
-    VectorType newReadType =
-        VectorType::get(newVectorShape, op.getVectorType().getElementType());
-    Value newRead = rewriter.create<vector::TransferReadOp>(
-        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), newMask, newInBounds);
-
-    // Transpose result of transfer_read.
-    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
-                                                     transposePerm);
-    return success();
-  }
-};
-
-/// Lower transfer_write op with permutation into a transfer_write with a
-/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
-/// Ex:
-///     vector.transfer_write %v ...
-///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
-/// into:
-///     %tmp = vector.transpose %v, [2, 0, 1]
-///     vector.transfer_write %tmp ...
-///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
-///
-///     vector.transfer_write %v ...
-///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
-/// into:
-///     %tmp = vector.transpose %v, [1, 0]
-///     %v = vector.transfer_write %tmp ...
-///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
-struct TransferWritePermutationLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.isZeroD())
-      return failure();
-
-    SmallVector<unsigned> permutation;
-    AffineMap map = op.permutation_map();
-    if (map.isMinorIdentity())
-      return failure();
-    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
-      return failure();
-
-    // Remove unused dims from the permutation map. E.g.:
-    // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
-    // comp = (d0, d1, d2) -> (d2, d0, d1)
-    auto comp = compressUnusedDims(map);
-    // Get positions of remaining result dims.
-    SmallVector<int64_t> indices;
-    llvm::transform(comp.getResults(), std::back_inserter(indices),
-                    [](AffineExpr expr) {
-                      return expr.dyn_cast<AffineDimExpr>().getPosition();
-                    });
-
-    // Transpose mask operand.
-    Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
-                                    op.getLoc(), op.mask(), indices)
-                              : Value();
-
-    // Transpose in_bounds attribute.
-    ArrayAttr newInBounds =
-        op.in_bounds() ? transposeInBoundsAttr(
-                             rewriter, op.in_bounds().getValue(), permutation)
-                       : ArrayAttr();
-
-    // Generate new transfer_write operation.
-    Value newVec =
-        rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices);
-    auto newMap = AffineMap::getMinorIdentityMap(
-        map.getNumDims(), map.getNumResults(), rewriter.getContext());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
-        newInBounds);
-
-    return success();
-  }
-};
-
-/// Lower transfer_read op with broadcast in the leading dimensions into
-/// transfer_read of lower rank + vector.broadcast.
-/// Ex: vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
-/// into:
-///     %v = vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
-///     vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
-    AffineMap map = op.permutation_map();
-    unsigned numLeadingBroadcast = 0;
-    for (auto expr : map.getResults()) {
-      auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
-      if (!dimExpr || dimExpr.getValue() != 0)
-        break;
-      numLeadingBroadcast++;
-    }
-    // If there are no leading zeros in the map there is nothing to do.
-    if (numLeadingBroadcast == 0)
-      return failure();
-    VectorType originalVecType = op.getVectorType();
-    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
-    // Calculate new map, vector type and masks without the leading zeros.
-    AffineMap newMap = AffineMap::get(
-        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
-        op.getContext());
-    // Only remove the leading zeros if the rest of the map is a minor identity
-    // with broadasting. Otherwise we first want to permute the map.
-    if (!newMap.isMinorIdentityWithBroadcasting())
-      return failure();
-
-    // TODO: support zero-dimension vectors natively.  See:
-    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
-    // In the meantime, lower these to a scalar load when they pop up.
-    if (reducedShapeRank == 0) {
-      Value newRead = rewriter.create<memref::LoadOp>(
-          op.getLoc(), originalVecType.getElementType(), op.source(),
-          op.indices());
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                       newRead);
-      return success();
-    }
-    SmallVector<int64_t> newShape = llvm::to_vector<4>(
-        originalVecType.getShape().take_back(reducedShapeRank));
-    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
-    if (newShape.empty())
-      return failure();
-    VectorType newReadType =
-        VectorType::get(newShape, originalVecType.getElementType());
-    ArrayAttr newInBounds =
-        op.in_bounds()
-            ? rewriter.getArrayAttr(
-                  op.in_boundsAttr().getValue().take_back(reducedShapeRank))
-            : ArrayAttr();
-    Value newRead = rewriter.create<vector::TransferReadOp>(
-        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), op.mask(), newInBounds);
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                     newRead);
-    return success();
-  }
-};
-
 // Trims leading one dimensions from `oldType` and returns the result type.
 // Returns `vector<1xT>` if `oldType` only has one element.
 static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3891,23 +3657,6 @@ void mlir::vector::populateVectorReductionToContractPatterns(
                CombineContractTranspose>(patterns.getContext());
 }
 
-void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<TransferReadPermutationLowering,
-               TransferWritePermutationLowering, TransferOpReduceRank>(
-      patterns.getContext());
-}
-
-void mlir::vector::populateVectorTransferLoweringPatterns(
-    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext(),
-                                                   maxTransferRank);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext());
-}
-
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
@@ -3920,3 +3669,13 @@ void mlir::vector::
         RewritePatternSet &patterns) {
   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
 }
+
+void mlir::vector::populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
+  patterns.add<TransferReadToVectorLoadLowering,
+               TransferWriteToVectorStoreLowering>(patterns.getContext(),
+                                                   maxTransferRank);
+  patterns
+      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+          patterns.getContext());
+}


        


More information about the Mlir-commits mailing list