[Mlir-commits] [mlir] [mlir][vector] Add distribution pattern for vector.create_mask (PR #71619)

Quinn Dawkins llvmlistbot at llvm.org
Fri Nov 10 06:07:01 PST 2023


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/71619

>From 00646830e848f219ef5189f1433919ce676d5b87 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 21:22:07 -0500
Subject: [PATCH 1/3] [MLIR][Vector] Add distribution pattern for
 vector.create_mask

This is the last step needed for basic support for distributing masked
vector code. The lane id gets delinearized based on the distributed mask
shape and then compared against the original mask sizes to compute the
bounds for the distributed mask.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 84 ++++++++++++++++++-
 .../Vector/vector-warp-distribute.mlir        | 60 +++++++++++++
 2 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 334d23e08419cea..355de7558fd60f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1078,6 +1078,88 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Sink out vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
+///   ...
+///   %mask = vector.create_mask %0 : vector<32xi1>
+///   vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %cmp = arith.cmpi ult, %laneid, %0
+/// %ub = arith.select %cmp, %c1, %c0
+/// %1 = vector.create_mask %ub : vector<1xi1>
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+    if (!yieldOperand)
+      return failure();
+
+    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    Location loc = mask.getLoc();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+
+    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+    VectorType seqType = mask.getVectorType();
+    auto seqShape = seqType.getShape();
+    auto distShape = distType.getShape();
+
+    rewriter.setInsertionPointAfter(warpOp);
+
+    // Delinearize the lane ID for constructing the distributed mask sizes.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+                           warpOp.getWarpSize(), warpOp.getLaneid(),
+                           delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          mask, "cannot delinearize lane ID for distribution");
+    assert(!delinearizedIds.empty());
+
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value prevIsZero;
+    SmallVector<Value> newOperands;
+    for (int i = 0, e = distShape.size(); i < e; ++i) {
+      // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
+      // the distance from the largest mask index owned by this lane to the
+      // original mask size.
+      Value maskDimIdx = affine::makeComposedAffineApply(
+          rewriter, loc, s1 - s0 * distShape[i],
+          {delinearizedIds[i], mask.getOperand(i)});
+      // Clamp to the range [0, dist_sizes[i]].
+      Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
+      Value vecSizeVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(distShape[i]));
+      Value clampVecSize =
+          rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
+      // If a previous mask size is zero, all trailing sizes must also be zero.
+      if (prevIsZero) {
+        clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
+                                                        clampVecSize);
+      }
+      prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                  clampVecSize, zero);
+      newOperands.push_back(clampVecSize);
+    }
+
+    auto newMask =
+        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1733,7 +1815,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
   patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
                WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
+               WarpOpConstant, WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
       patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 6d8ad5a0e88c2bd..3320d01bfe0b912 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1369,3 +1369,63 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
 //       CHECK-DIST-AND-PROP:   %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
+
+// -----
+
+func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.create_mask %m0 : vector<32xi1>
+    vector.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
+// CHECK-PROP-LABEL: func @warp_propagate_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
+//       CHECK-PROP:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK-PROP:   %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
+//       CHECK-PROP:   %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
+//       CHECK-PROP:   %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+//       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
+
+// -----
+
+func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1: index, %m2: index) -> vector<1x2x4xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.create_mask %m0, %m1, %m2 : vector<16x4x4xi1>
+    vector.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0, s1] -> (s0 - s1 floordiv 2)>
+//   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
+//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-PROP-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-PROP-DAG:   %[[C4:.+]] = arith.constant 4 : index
+
+// Compute distributed m0 based on the first (outermost) delinearized lane id.
+//       CHECK-PROP:   affine.apply #map()[%[[M0]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+
+// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 must
+// also be zero for the mask to be valid.
+//       CHECK-PROP:   affine.apply #map1()[%[[M1]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C2]] : index
+//       CHECK-PROP:   %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+// Compute m3 and propagate zeros.
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.maxsi %[[M2]], %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C4]] : index
+//       CHECK-PROP:   %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>

>From 7c3c682ad55bf938b5e5603bd820b84d728b4e82 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 8 Nov 2023 10:30:57 -0500
Subject: [PATCH 2/3] Simplify create_mask distribution pattern based on
 create_mask semantics

---
 .../Vector/Transforms/VectorDistribute.cpp    | 28 ++++++---------
 .../Vector/vector-warp-distribute.mlir        | 34 +++----------------
 2 files changed, 15 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 355de7558fd60f5..d43d2442053fd46 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1106,6 +1106,14 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
 
     auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+
+    // Early exit if any values needed for calculating the new mask indices
+    // are defined inside the warp op.
+    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+          return warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
     Location loc = mask.getLoc();
     unsigned operandIndex = yieldOperand->getOperandNumber();
 
@@ -1127,30 +1135,16 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
     AffineExpr s0, s1;
     bindSymbols(rewriter.getContext(), s0, s1);
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value prevIsZero;
     SmallVector<Value> newOperands;
     for (int i = 0, e = distShape.size(); i < e; ++i) {
       // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
       // the distance from the largest mask index owned by this lane to the
-      // original mask size.
+      // original mask size. vector.create_mask implicitly clamps mask sizes to
+      // the range [0, mask_vector_size[i]].
       Value maskDimIdx = affine::makeComposedAffineApply(
           rewriter, loc, s1 - s0 * distShape[i],
           {delinearizedIds[i], mask.getOperand(i)});
-      // Clamp to the range [0, dist_sizes[i]].
-      Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
-      Value vecSizeVal = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(distShape[i]));
-      Value clampVecSize =
-          rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
-      // If a previous mask size is zero, all trailing sizes must also be zero.
-      if (prevIsZero) {
-        clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
-                                                        clampVecSize);
-      }
-      prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                                  clampVecSize, zero);
-      newOperands.push_back(clampVecSize);
+      newOperands.push_back(maskDimIdx);
     }
 
     auto newMask =
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3320d01bfe0b912..1821190c44e3af4 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1383,11 +1383,7 @@ func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1
 //   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
 // CHECK-PROP-LABEL: func @warp_propagate_create_mask
 //  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
-//       CHECK-PROP:   %[[C0:.+]] = arith.constant 0 : index
-//       CHECK-PROP:   %[[C1:.+]] = arith.constant 1 : index
-//       CHECK-PROP:   %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
-//       CHECK-PROP:   %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
-//       CHECK-PROP:   %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+//       CHECK-PROP:   %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
 //       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
 
 // -----
@@ -1404,28 +1400,6 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 //   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
 // CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
 //  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
-//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-PROP-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//   CHECK-PROP-DAG:   %[[C4:.+]] = arith.constant 4 : index
-
-// Compute distributed m0 based on the first (outermost) delinearized lane id.
-//       CHECK-PROP:   affine.apply #map()[%[[M0]], %[[LANEID]]]
-//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
-//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
-
-// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 must
-// also be zero for the mask to be valid.
-//       CHECK-PROP:   affine.apply #map1()[%[[M1]], %[[LANEID]]]
-//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   arith.minsi {{.*}}, %[[C2]] : index
-//       CHECK-PROP:   %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
-
-// Compute m3 and propagate zeros.
-//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   arith.maxsi %[[M2]], %[[C0]] : index
-//       CHECK-PROP:   arith.minsi {{.*}}, %[[C4]] : index
-//       CHECK-PROP:   %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
-
-//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>
+//       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
+//       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>

>From c2e4030266190e2194e4509df91f56cea6afaaa1 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Fri, 10 Nov 2023 09:06:23 -0500
Subject: [PATCH 3/3] Address comments

---
 .../Vector/Transforms/VectorDistribute.cpp     | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index d43d2442053fd46..2a1d5354732aff6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1137,10 +1137,11 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
     bindSymbols(rewriter.getContext(), s0, s1);
     SmallVector<Value> newOperands;
     for (int i = 0, e = distShape.size(); i < e; ++i) {
-      // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
-      // the distance from the largest mask index owned by this lane to the
-      // original mask size. vector.create_mask implicitly clamps mask sizes to
-      // the range [0, mask_vector_size[i]].
+      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
+      // find the distance from the first mask index owned by this lane to the
+      // original mask size. `vector.create_mask` implicitly clamps mask
+      // operands to the range [0, mask_vector_size[i]], or in other words, the
+      // mask sizes are always in the range [0, mask_vector_size[i]].
       Value maskDimIdx = affine::makeComposedAffineApply(
           rewriter, loc, s1 - s0 * distShape[i],
           {delinearizedIds[i], mask.getOperand(i)});
@@ -1807,10 +1808,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+           WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,



More information about the Mlir-commits mailing list