[Mlir-commits] [mlir] 7ddeffe - [mlir] Lower permutation maps on TransferWriteOps

Matthias Springer llvmlistbot at llvm.org
Sun May 16 23:39:09 PDT 2021


Author: Matthias Springer
Date: 2021-05-17T15:30:46+09:00
New Revision: 7ddeffee55766005327abbac85838225069cc164

URL: https://github.com/llvm/llvm-project/commit/7ddeffee55766005327abbac85838225069cc164
DIFF: https://github.com/llvm/llvm-project/commit/7ddeffee55766005327abbac85838225069cc164.diff

LOG: [mlir] Lower permutation maps on TransferWriteOps

Add TransferWritePermutationLowering, which replaces permutation maps of TransferWriteOps with vector.transpose.

Differential Revision: https://reviews.llvm.org/D102548

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1effce2f5679c..c7a0623f3b32a 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3137,6 +3137,70 @@ struct TransferReadPermutationLowering
   }
 };
 
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// The vector, mask and in_bounds are permuted by the inverse of the map.
+/// Ex:
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+///     %tmp = vector.transpose %v, [1, 2, 0]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+///     %tmp = vector.transpose %v, [1, 0]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+struct TransferWritePermutationLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.permutation_map();
+    if (map.isMinorIdentity())
+      return failure();
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+    // Remove unused dims and invert the map: indices[d] is the vector dim
+    // written along memref dim d. E.g.: (d0, .., d5) -> (d5, d3, d4) gives
+    // comp = (d0, d1, d2) -> (d2, d0, d1) and indices = [1, 2, 0].
+    AffineMap comp = compressUnusedDims(map);
+    SmallVector<int64_t> indices(comp.getNumResults(), 0);
+    for (auto en : llvm::enumerate(comp.getResults())) {
+      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
+      if (!dimExpr) // Broadcasts are not supported on transfer_write.
+        return failure();
+      indices[dimExpr.getPosition()] = en.index();
+    }
+    // Transpose mask operand.
+    Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
+                                    op.getLoc(), op.mask(), indices)
+                              : Value();
+    // Transpose in_bounds attribute with the same permutation as the vector.
+    ArrayAttr newInBounds;
+    if (auto inBounds = op.in_bounds()) {
+      SmallVector<bool> values;
+      for (int64_t vecDim : indices)
+        values.push_back((*inBounds)[vecDim].cast<BoolAttr>().getValue());
+      newInBounds = rewriter.getBoolArrayAttr(values);
+    }
+    // Generate new transfer_write operation.
+    Value newVec = rewriter.create<vector::TransposeOp>(
+        op.getLoc(), op.vector(), indices);
+    auto newMap = AffineMap::getMinorIdentityMap(
+        map.getNumDims(), map.getNumResults(), rewriter.getContext());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
+        newInBounds);
+    return success();
+  }
+};
+
 /// Lower transfer_read op with broadcast in the leading dimensions into
 /// transfer_read of lower rank + vector.broadcast.
 /// Ex: vector.transfer_read ...
@@ -4089,7 +4153,8 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
-           TransferReadPermutationLowering, TransferOpReduceRank>(
+           TransferReadPermutationLowering, TransferWritePermutationLowering,
+           TransferOpReduceRank>(
           patterns.getContext());
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 28a267967942c..60bbadf59874a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -267,3 +267,25 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
          vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
          vector<7x14x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_permutations
+func @transfer_write_permutations(%arg0 : memref<?x?x?x?xf32>,
+    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> () {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  %c0 = constant 0 : index
+  %m = constant 1 : i1
+
+  %mask0 = splat %m : vector<7x14x8x16xi1>
+  vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xi1> to vector<16x14x7x8xi1>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC0]], %arg0[%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, memref<?x?x?x?xf32>
+
+  vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC1]], %arg0[%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
+
+  return
+}


        


More information about the Mlir-commits mailing list