[Mlir-commits] [mlir] 60da33c - [mlir] Support masks in TransferOpReduceRank and TransferReadPermutationLowering

Matthias Springer llvmlistbot at llvm.org
Wed May 12 23:16:37 PDT 2021


Author: Matthias Springer
Date: 2021-05-13T15:08:08+09:00
New Revision: 60da33c2d4b2cd744c70259088b2ab89eb858f33

URL: https://github.com/llvm/llvm-project/commit/60da33c2d4b2cd744c70259088b2ab89eb858f33
DIFF: https://github.com/llvm/llvm-project/commit/60da33c2d4b2cd744c70259088b2ab89eb858f33.diff

LOG: [mlir] Support masks in TransferOpReduceRank and TransferReadPermutationLowering

These two patterns allow for more efficient codegen in VectorToSCF.

Differential Revision: https://reviews.llvm.org/D102222

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 30b3c4a32c20..1dd04a2e4b21 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3069,13 +3069,11 @@ struct TransferReadPermutationLowering
     AffineMap map = op.permutation_map();
     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
       return failure();
-
     AffineMap permutationMap =
         map.getPermutationMap(permutation, op.getContext());
     if (permutationMap.isIdentity())
       return failure();
-    if (op.mask())
-      return failure();
+
     // Calculate the map of the new read by applying the inverse permutation.
     permutationMap = inversePermutation(permutationMap);
     AffineMap newMap = permutationMap.compose(map);
@@ -3085,11 +3083,32 @@ struct TransferReadPermutationLowering
     for (auto pos : llvm::enumerate(permutation)) {
       newVectorShape[pos.value()] = originalShape[pos.index()];
     }
+
+    Value newMask;
+    if (op.mask()) {
+      // Remove unused dims from the permutation map. E.g.:
+      // map  = (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
+      // comp = (d0, d1, d2) -> (d2, 0, d1, 0, d0)
+      auto comp = compressUnusedDims(map);
+      // Get positions of remaining result dims.
+      // E.g.:  (d0, d1, d2) -> (d2, 0, d1, 0, d0)
+      // maskTransposeIndices = [ 2,     1,     0]
+      SmallVector<int64_t> maskTransposeIndices;
+      for (unsigned i = 0; i < comp.getNumResults(); ++i) {
+        if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
+          maskTransposeIndices.push_back(expr.getPosition());
+      }
+
+      newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
+                                                     maskTransposeIndices);
+    }
+
     VectorType newReadType =
         VectorType::get(newVectorShape, op.getVectorType().getElementType());
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+        op.padding(), newMask, op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                      transposePerm);
@@ -3110,8 +3129,6 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.mask())
-      return failure();
     AffineMap map = op.permutation_map();
     unsigned numLeadingBroadcast = 0;
     for (auto expr : map.getResults()) {
@@ -3147,7 +3164,7 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
             : ArrayAttr();
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), newInBounds);
+        op.padding(), op.mask(), newInBounds);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                      newRead);
     return success();

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 22f375c60744..98683de90990 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -228,17 +228,24 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
   %cst = constant 0.000000e+00 : f32
   %c0 = constant 0 : index
+  %m = constant 1 : i1
 
-  %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
-// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
+  %mask0 = splat %m : vector<7x14xi1>
+  %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
+// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
 // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
 
-  %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
-// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
+  %mask1 = splat %m : vector<14x16xi1>
+  %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1>
+// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
-  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, false, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
-// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
+  %mask2 = splat %m : vector<7x14xi1>
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, true, false, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
+// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
 // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
 // CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>
 


        


More information about the Mlir-commits mailing list