[Mlir-commits] [mlir] 5288c25 - [mlir][vector] Add lowering of Transfer_read with broadcast and permutation map

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 29 08:40:59 PDT 2021


Author: thomasraoux
Date: 2021-03-29T08:38:43-07:00
New Revision: 5288c25c7008debf2a1401cd288fda1179d00484

URL: https://github.com/llvm/llvm-project/commit/5288c25c7008debf2a1401cd288fda1179d00484
DIFF: https://github.com/llvm/llvm-project/commit/5288c25c7008debf2a1401cd288fda1179d00484.diff

LOG: [mlir][vector] Add lowering of Transfer_read with broadcast and permutation map

Convert transfer_read ops with permutation maps into simpler
transfer_read ops with a minor identity map + vector.broadcast and
vector.transpose. Also convert transfer_read ops with leading broadcast
dimensions into a transfer_read of lower rank.

Differential Revision: https://reviews.llvm.org/D99019

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index e837fc070ab0..abc3e1b4a6fe 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -113,6 +113,22 @@ class AffineMap {
   bool isMinorIdentityWithBroadcasting(
       SmallVectorImpl<unsigned> *broadcastedDims = nullptr) const;
 
+  /// Return true if this affine map can be converted to a minor identity with
+  /// broadcast by doing a permute. Return a permutation (there may be
+  /// several) to apply to get to a minor identity with broadcasts.
+  /// Ex:
+  ///  * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
+  ///  perm = [1, 0] and broadcast d2
+  ///  * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
+  ///  permutation + broadcast
+  ///  * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
+  ///  with perm = [1, 0, 2] and broadcast d2
+  ///  * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
+  ///  leading broadcast dimensions. The map returned would be (0, 0, d0, d1)
+  ///  with perm = [3, 0, 1, 2]
+  bool isPermutationOfMinorIdentityWithBroadcasting(
+      SmallVectorImpl<unsigned> &permutedDims) const;
+
   /// Returns true if this affine map is an empty map, i.e., () -> ().
   bool isEmpty() const;
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 8766efa406c2..50014e874274 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2842,6 +2842,131 @@ struct TransferWriteToVectorStoreLowering
   }
 };
 
+/// Lower transfer_read op with permutation into a transfer_read with a
+/// permutation map composed of leading zeros followed by a minor identity +
+/// vector.transpose op.
+/// Ex:
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (0, d1)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (d1, 0)
+///     vector.transpose %v, [1, 0]
+///
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
+///     vector.transpose %v, [0, 1, 3, 2, 4]
+/// The optional `masked` attribute is permuted together with the vector
+/// dimensions so that each entry keeps describing the same dimension.
+/// Note that an alternative is to transform it to linalg.transpose +
+/// vector.transfer_read to do the transpose in memory instead.
+struct TransferReadPermutationLowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.permutation_map();
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    AffineMap permutationMap =
+        map.getPermutationMap(permutation, op.getContext());
+    if (permutationMap.isIdentity())
+      return failure();
+    // Calculate the map of the new read by applying the inverse permutation.
+    permutationMap = inversePermutation(permutationMap);
+    AffineMap newMap = permutationMap.compose(map);
+    // Apply the reverse transpose to deduce the type of the transfer_read.
+    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
+    SmallVector<int64_t> newVectorShape(originalShape.size());
+    for (auto pos : llvm::enumerate(permutation)) {
+      newVectorShape[pos.value()] = originalShape[pos.index()];
+    }
+    VectorType newReadType =
+        VectorType::get(newVectorShape, op.getVectorType().getElementType());
+    // Entry `i` of `masked` describes vector dimension `i`, which becomes
+    // dimension `permutation[i]` of the new read: permute it accordingly.
+    ArrayAttr newMasked;
+    if (op.masked()) {
+      ArrayRef<Attribute> originalMasked = op.maskedAttr().getValue();
+      SmallVector<Attribute> permutedMasked(permutation.size());
+      for (auto pos : llvm::enumerate(permutation))
+        permutedMasked[pos.value()] = originalMasked[pos.index()];
+      newMasked = rewriter.getArrayAttr(permutedMasked);
+    }
+    Value newRead = rewriter.create<vector::TransferReadOp>(
+        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
+        op.padding(), newMasked);
+    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
+                                                     transposePerm);
+    return success();
+  }
+};
+
+/// Lower transfer_read op with broadcast in the leading dimensions into
+/// transfer_read of lower rank + vector.broadcast.
+/// Ex: vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
+///     vector.broadcast %v
+/// Reads whose map contains only broadcast dimensions are left untouched, as
+/// the reduced read would need a 0-D vector type.
+struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.permutation_map();
+    unsigned numLeadingBroadcast = 0;
+    for (auto expr : map.getResults()) {
+      auto constExpr = expr.dyn_cast<AffineConstantExpr>();
+      if (!constExpr || constExpr.getValue() != 0)
+        break;
+      numLeadingBroadcast++;
+    }
+    // If there are no leading zeros in the map there is nothing to do.
+    if (numLeadingBroadcast == 0)
+      return failure();
+    VectorType originalVecType = op.getVectorType();
+    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
+    // If the map only contains zeros the reduced read would be 0-D, which
+    // vector types do not support.
+    if (reducedShapeRank == 0)
+      return failure();
+    // Calculate new map, vector type and masks without the leading zeros.
+    AffineMap newMap = AffineMap::get(
+        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
+        op.getContext());
+    // Only remove the leading zeros if the rest of the map is a minor identity
+    // with broadcasting. Otherwise we first want to permute the map.
+    if (!newMap.isMinorIdentityWithBroadcasting())
+      return failure();
+    SmallVector<int64_t> newShape = llvm::to_vector<4>(
+        originalVecType.getShape().take_back(reducedShapeRank));
+    VectorType newReadType =
+        VectorType::get(newShape, originalVecType.getElementType());
+    ArrayAttr newMask =
+        op.masked()
+            ? rewriter.getArrayAttr(
+                  op.maskedAttr().getValue().take_back(reducedShapeRank))
+            : ArrayAttr();
+    Value newRead = rewriter.create<vector::TransferReadOp>(
+        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
+        op.padding(), newMask);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
+                                                     newRead);
+    return success();
+  }
+};
+
 // Trims leading one dimensions from `oldType` and returns the result type.
 // Returns `vector<1xT>` if `oldType` only has one element.
 static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3317,6 +3442,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
+           TransferReadPermutationLowering, TransferOpReduceRank>(
+          patterns.getContext());
 }

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 98ca45bbb6f6..dc9a5c54c7ff 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
@@ -140,6 +141,72 @@ bool AffineMap::isMinorIdentityWithBroadcasting(
   return true;
 }
 
+/// Return true if this affine map can be converted to a minor identity with
+/// broadcast by doing a permute. Return a permutation (there may be
+/// several) to apply to get to a minor identity with broadcasts.
+/// Return false if a result is neither a dimension nor the constant 0, if a
+/// dimension is used more than once, or if a result uses one of the leading
+/// `getNumInputs() - getNumResults()` dimensions.
+/// Ex:
+///  * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
+///  perm = [1, 0] and broadcast d2
+///  * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
+///  permutation + broadcast
+///  * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
+///  with perm = [1, 0, 2] and broadcast d2
+///  * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
+///  leading broadcast dimensions. The map returned would be (0, 0, d0, d1) with
+///  perm = [3, 0, 1, 2]
+bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting(
+    SmallVectorImpl<unsigned> &permutedDims) const {
+  unsigned projectionStart =
+      getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0;
+  permutedDims.clear();
+  SmallVector<unsigned> broadcastDims;
+  permutedDims.resize(getNumResults(), 0);
+  // If there are more results than input dimensions we want the new map to
+  // start with broadcast dimensions in order to be a minor identity with
+  // broadcasting.
+  unsigned leadingBroadcast =
+      getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0;
+  llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()),
+                                false);
+  for (auto idxAndExpr : llvm::enumerate(getResults())) {
+    unsigned resIdx = idxAndExpr.index();
+    AffineExpr expr = idxAndExpr.value();
+    // Each result may be either a constant 0 (broadcast dimension) or a
+    // dimension.
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      if (constExpr.getValue() != 0)
+        return false;
+      broadcastDims.push_back(resIdx);
+    } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+      if (dimExpr.getPosition() < projectionStart)
+        return false;
+      unsigned newPosition =
+          dimExpr.getPosition() - projectionStart + leadingBroadcast;
+      // A dimension used twice would not yield a permutation.
+      if (dimFound[newPosition])
+        return false;
+      permutedDims[resIdx] = newPosition;
+      dimFound[newPosition] = true;
+    } else {
+      return false;
+    }
+  }
+  // Find a permutation for the broadcast dimensions. Since they are
+  // broadcasted any valid permutation is acceptable. We just permute each dim
+  // into a slot without an existing dimension.
+  unsigned pos = 0;
+  for (auto dim : broadcastDims) {
+    while (pos < dimFound.size() && dimFound[pos]) {
+      pos++;
+    }
+    permutedDims[dim] = pos++;
+  }
+  return true;
+}
+
 /// Returns an AffineMap representing a permutation.
 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
                                        MLIRContext *context) {

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index ff32f7d5c823..10f32edb019e 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -206,3 +206,56 @@ func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index)
   %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
   return %res : vector<3x2x4x5xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, 0, 0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d1, 0, 0)>
+#map3 = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
+#map4 = affine_map<(d0, d1) -> (0, d1, 0, d0)>
+#map5 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, 0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
+
+// CHECK-LABEL: func @transfer_read_permutations
+func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?xf32>)
+    -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
+       vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>) {
+// CHECK-DAG: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %c0 = constant 0 : index
+
+  %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
+// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
+
+  %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
+// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
+
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {masked = [true, false, false, false], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} {masked = [false, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
+// CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>
+
+  %3 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map3} : memref<?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref<?x?xf32>, vector<14x7xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<14x7xf32> to vector<8x16x14x7xf32>
+// CHECK: vector.transpose %{{.*}}, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+
+  %4 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map4} : memref<?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref<?x?xf32>, vector<16x14xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<16x14xf32> to vector<7x8x16x14xf32>
+// CHECK: vector.transpose %{{.*}}, [0, 3, 1, 2] : vector<7x8x16x14xf32> to vector<7x14x8x16xf32>
+
+  %5 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map5} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[CF0]] : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
+// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
+
+  return %0, %1, %2, %3, %4, %5 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
+         vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
+         vector<7x14x8x16xf32>
+}


        


More information about the Mlir-commits mailing list