[Mlir-commits] [mlir] [mlir][vector] Add distribution pattern for vector.create_mask (PR #71619)

Quinn Dawkins llvmlistbot at llvm.org
Wed Nov 8 07:47:38 PST 2023


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/71619

>From b847b4ebe1c40eb842ac24ac84735b75bc3751ad Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 21:22:07 -0500
Subject: [PATCH 1/2] [MLIR][Vector] Add distribution pattern for
 vector.create_mask

This is the last step needed for basic support for distributing masked
vector code. The lane ID is delinearized based on the distributed mask
shape and then compared against the original mask sizes to compute the
bounds of the distributed mask.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 84 ++++++++++++++++++-
 .../Vector/vector-warp-distribute.mlir        | 60 +++++++++++++
 2 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..9d5354300efe31a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1047,6 +1047,89 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Sink out vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xi1>) {
+///   ...
+///   %mask = vector.create_mask %0 : vector<32xi1>
+///   vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0)[32] {
+///   ...
+/// }
+/// %cmp = arith.cmpi ult, %arg0, %0
+/// %ub = arith.select %cmp, %c1, %c0
+/// %1 = vector.create_mask %ub : vector<1xi1>
+/// ```
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+    if (!yieldOperand)
+      return failure();
+
+    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    Location loc = mask.getLoc();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+
+    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+    VectorType seqType = mask.getVectorType();
+    auto seqShape = seqType.getShape();
+    auto distShape = distType.getShape();
+
+    rewriter.setInsertionPointAfter(warpOp);
+
+    // Delinearize the lane ID for constructing the distributed mask sizes.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+                           warpOp.getWarpSize(), warpOp.getLaneid(),
+                           delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          mask, "cannot delinearize lane ID for distribution");
+    assert(!delinearizedIds.empty());
+
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value prevIsZero;
+    SmallVector<Value> newOperands;
+    for (int i = 0, e = distShape.size(); i < e; ++i) {
+      // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
+      // the distance from the largest mask index owned by this lane to the
+      // original mask size.
+      Value maskDimIdx = affine::makeComposedAffineApply(
+          rewriter, loc, s1 - s0 * distShape[i],
+          {delinearizedIds[i], mask.getOperand(i)});
+      // Clamp to the range [0, dist_sizes[i]].
+      Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
+      Value vecSizeVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(distShape[i]));
+      Value clampVecSize =
+          rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
+      // If a previous mask size is zero, all trailing sizes must also be zero.
+      if (prevIsZero) {
+        clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
+                                                        clampVecSize);
+      }
+      prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                  clampVecSize, zero);
+      newOperands.push_back(clampVecSize);
+    }
+
+    auto newMask =
+        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1701,7 +1783,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
                WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
-               WarpOpInsert>(patterns.getContext(), benefit);
+               WarpOpCreateMask, WarpOpInsert>(patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..b8c6b7fa99418d0 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1281,3 +1281,63 @@ func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>)
 //       CHECK-DIST-AND-PROP:   }
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
+
+// -----
+
+func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.create_mask %m0 : vector<32xi1>
+    vector.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
+// CHECK-PROP-LABEL: func @warp_propagate_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
+//       CHECK-PROP:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK-PROP:   %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
+//       CHECK-PROP:   %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
+//       CHECK-PROP:   %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+//       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
+
+// -----
+
+func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1: index, %m2: index) -> vector<1x2x4xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.create_mask %m0, %m1, %m2 : vector<16x4x4xi1>
+    vector.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0, s1] -> (s0 - s1 floordiv 2)>
+//   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
+//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-PROP-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-PROP-DAG:   %[[C4:.+]] = arith.constant 4 : index
+
+// Compute distributed m0 based on the first (outermost) delinearized lane id.
+//       CHECK-PROP:   affine.apply #map()[%[[M0]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+
+// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 must
+// also be zero for the mask to be valid.
+//       CHECK-PROP:   affine.apply #map1()[%[[M1]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C2]] : index
+//       CHECK-PROP:   %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+// Compute m3 and propagate zeros.
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.maxsi %[[M2]], %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C4]] : index
+//       CHECK-PROP:   %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>

>From f72afa3a0852c21dde6cd1ca7ed2f2af6aa1ff75 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 8 Nov 2023 10:30:57 -0500
Subject: [PATCH 2/2] Simplify create_mask distribution pattern based on
 create_mask semantics
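
vector.create_mask already clamps each operand to [0, dim_size[i]] and
yields an all-false mask when any size is zero, so the explicit clamping
and zero propagation are unnecessary.
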
---
 .../Vector/Transforms/VectorDistribute.cpp    | 28 ++++++---------
 .../Vector/vector-warp-distribute.mlir        | 34 +++----------------
 2 files changed, 15 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 9d5354300efe31a..e7c4be6325d87ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1075,6 +1075,14 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
 
     auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+
+    // Early exit if any values needed for calculating the new mask indices
+    // are defined inside the warp op.
+    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+          return warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
     Location loc = mask.getLoc();
     unsigned operandIndex = yieldOperand->getOperandNumber();
 
@@ -1096,30 +1104,16 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
     AffineExpr s0, s1;
     bindSymbols(rewriter.getContext(), s0, s1);
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value prevIsZero;
     SmallVector<Value> newOperands;
     for (int i = 0, e = distShape.size(); i < e; ++i) {
-      // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
-      // the distance from the largest mask index owned by this lane to the
-      // original mask size.
+      // Get `mask_size[i] - lane_id[i] * dist_sizes[i]` to find the distance
+      // from the smallest mask index owned by this lane to the original mask
+      // size. vector.create_mask implicitly clamps mask sizes to the range
+      // [0, mask_vector_size[i]].
       Value maskDimIdx = affine::makeComposedAffineApply(
           rewriter, loc, s1 - s0 * distShape[i],
           {delinearizedIds[i], mask.getOperand(i)});
-      // Clamp to the range [0, dist_sizes[i]].
-      Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
-      Value vecSizeVal = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(distShape[i]));
-      Value clampVecSize =
-          rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
-      // If a previous mask size is zero, all trailing sizes must also be zero.
-      if (prevIsZero) {
-        clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
-                                                        clampVecSize);
-      }
-      prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                                  clampVecSize, zero);
-      newOperands.push_back(clampVecSize);
+      newOperands.push_back(maskDimIdx);
     }
 
     auto newMask =
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index b8c6b7fa99418d0..c5592c649e8ce93 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1295,11 +1295,7 @@ func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1
 //   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
 // CHECK-PROP-LABEL: func @warp_propagate_create_mask
 //  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
-//       CHECK-PROP:   %[[C0:.+]] = arith.constant 0 : index
-//       CHECK-PROP:   %[[C1:.+]] = arith.constant 1 : index
-//       CHECK-PROP:   %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
-//       CHECK-PROP:   %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
-//       CHECK-PROP:   %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+//       CHECK-PROP:   %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
 //       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
 
 // -----
@@ -1316,28 +1312,6 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 //   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
 // CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
 //  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
-//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-PROP-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//   CHECK-PROP-DAG:   %[[C4:.+]] = arith.constant 4 : index
-
-// Compute distributed m0 based on the first (outermost) delinearized lane id.
-//       CHECK-PROP:   affine.apply #map()[%[[M0]], %[[LANEID]]]
-//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
-//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
-
-// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 must
-// also be zero for the mask to be valid.
-//       CHECK-PROP:   affine.apply #map1()[%[[M1]], %[[LANEID]]]
-//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   arith.minsi {{.*}}, %[[C2]] : index
-//       CHECK-PROP:   %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
-
-// Compute m3 and propagate zeros.
-//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   arith.maxsi %[[M2]], %[[C0]] : index
-//       CHECK-PROP:   arith.minsi {{.*}}, %[[C4]] : index
-//       CHECK-PROP:   %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
-
-//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>
+//       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
+//       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>



More information about the Mlir-commits mailing list