[Mlir-commits] [mlir] [mlir][vector] Add distribution pattern for vector.create_mask (PR #71619)

Quinn Dawkins llvmlistbot at llvm.org
Thu Nov 9 06:41:56 PST 2023


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/71619

>From 6934a8489909e7e39993e6c00dd2986528f0dcf6 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 21:22:07 -0500
Subject: [PATCH 1/2] [MLIR][Vector] Add distribution pattern for
 vector.create_mask

This is the last step needed for basic support for distributing masked
vector code. The lane id gets delinearized based on the distributed mask
shape and then compared against the original mask sizes to compute the
bounds for the distributed mask.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 84 ++++++++++++++++++-
 .../Vector/vector-warp-distribute.mlir        | 60 +++++++++++++
 2 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 1975ba9c92d9988..5ae54443be67c05 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1071,6 +1071,88 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Sink out vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
+///   ...
+///   %mask = vector.create_mask %0 : vector<32xi1>
+///   vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %cmp = arith.cmpi ult, %laneid, %0
+/// %ub = arith.select %cmp, %c1, %c0
+/// %1 = vector.create_mask %ub : vector<1xi1>
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+    if (!yieldOperand)
+      return failure();
+
+    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    Location loc = mask.getLoc();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+
+    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+    VectorType seqType = mask.getVectorType();
+    auto seqShape = seqType.getShape();
+    auto distShape = distType.getShape();
+
+    rewriter.setInsertionPointAfter(warpOp);
+
+    // Delinearize the lane ID for constructing the distributed mask sizes.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+                           warpOp.getWarpSize(), warpOp.getLaneid(),
+                           delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          mask, "cannot delinearize lane ID for distribution");
+    assert(!delinearizedIds.empty());
+
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value prevIsZero;
+    SmallVector<Value> newOperands;
+    for (int i = 0, e = distShape.size(); i < e; ++i) {
+      // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
+      // the distance from the largest mask index owned by this lane to the
+      // original mask size.
+      Value maskDimIdx = affine::makeComposedAffineApply(
+          rewriter, loc, s1 - s0 * distShape[i],
+          {delinearizedIds[i], mask.getOperand(i)});
+      // Clamp to the range [0, dist_sizes[i]].
+      Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
+      Value vecSizeVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(distShape[i]));
+      Value clampVecSize =
+          rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
+      // If a previous mask size is zero, all trailing sizes must also be zero.
+      if (prevIsZero) {
+        clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
+                                                        clampVecSize);
+      }
+      prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                  clampVecSize, zero);
+      newOperands.push_back(clampVecSize);
+    }
+
+    auto newMask =
+        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1725,7 +1807,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
                WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
-               WarpOpInsert>(patterns.getContext(), benefit);
+               WarpOpCreateMask, WarpOpInsert>(patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 2a1007fbbe86435..1f96fa1545d8fa7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1311,3 +1311,63 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
 //       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[ARG2]]], {{.*}}, %[[R]]#1 {{.*}} vector<2x2xf32>
 //       CHECK-PROP:   %[[DIST_READ_IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG0]]]
 //       CHECK-PROP:   vector.transfer_read {{.*}}[%[[C0]], %[[DIST_READ_IDX1]]], {{.*}}, %[[R]]#0 {{.*}} vector<2xf32>
+
+// -----
+
+func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.create_mask %m0 : vector<32xi1>
+    vector.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
+// CHECK-PROP-LABEL: func @warp_propagate_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
+//       CHECK-PROP:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK-PROP:   %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
+//       CHECK-PROP:   %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
+//       CHECK-PROP:   %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+//       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
+
+// -----
+
+func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1: index, %m2: index) -> vector<1x2x4xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.create_mask %m0, %m1, %m2 : vector<16x4x4xi1>
+    vector.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0, s1] -> (s0 - s1 floordiv 2)>
+//   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
+//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-PROP-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-PROP-DAG:   %[[C4:.+]] = arith.constant 4 : index
+
+// Compute distributed m0 based on the first (outermost) delinearized lane id.
+//       CHECK-PROP:   affine.apply #map()[%[[M0]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+
+// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 must
+// also be zero for the mask to be valid.
+//       CHECK-PROP:   affine.apply #map1()[%[[M1]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C2]] : index
+//       CHECK-PROP:   %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+// Compute m3 and propagate zeros.
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.maxsi %[[M2]], %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C4]] : index
+//       CHECK-PROP:   %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>

>From c90b93a9a4e2e7abc9afe36886e1e474c7107160 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 8 Nov 2023 10:30:57 -0500
Subject: [PATCH 2/2] Simplify create_mask distribution pattern based on
 create_mask semantics

---
 .../Vector/Transforms/VectorDistribute.cpp    | 28 ++++++---------
 .../Vector/vector-warp-distribute.mlir        | 34 +++----------------
 2 files changed, 15 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 5ae54443be67c05..145dc4cebcc34b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1099,6 +1099,14 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
 
     auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+
+    // Early exit if any values needed for calculating the new mask indices
+    // are defined inside the warp op.
+    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+          return warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
     Location loc = mask.getLoc();
     unsigned operandIndex = yieldOperand->getOperandNumber();
 
@@ -1120,30 +1128,16 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
     AffineExpr s0, s1;
     bindSymbols(rewriter.getContext(), s0, s1);
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value prevIsZero;
     SmallVector<Value> newOperands;
     for (int i = 0, e = distShape.size(); i < e; ++i) {
       // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
       // the distance from the largest mask index owned by this lane to the
-      // original mask size.
+      // original mask size. vector.create_mask implicitly clamps mask sizes to
+      // the range [0, mask_vector_size[i]].
       Value maskDimIdx = affine::makeComposedAffineApply(
           rewriter, loc, s1 - s0 * distShape[i],
           {delinearizedIds[i], mask.getOperand(i)});
-      // Clamp to the range [0, dist_sizes[i]].
-      Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
-      Value vecSizeVal = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(distShape[i]));
-      Value clampVecSize =
-          rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
-      // If a previous mask size is zero, all trailing sizes must also be zero.
-      if (prevIsZero) {
-        clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
-                                                        clampVecSize);
-      }
-      prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                                  clampVecSize, zero);
-      newOperands.push_back(clampVecSize);
+      newOperands.push_back(maskDimIdx);
     }
 
     auto newMask =
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 1f96fa1545d8fa7..1a9245aad064e8d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1325,11 +1325,7 @@ func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1
 //   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
 // CHECK-PROP-LABEL: func @warp_propagate_create_mask
 //  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
-//       CHECK-PROP:   %[[C0:.+]] = arith.constant 0 : index
-//       CHECK-PROP:   %[[C1:.+]] = arith.constant 1 : index
-//       CHECK-PROP:   %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
-//       CHECK-PROP:   %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
-//       CHECK-PROP:   %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+//       CHECK-PROP:   %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
 //       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
 
 // -----
@@ -1346,28 +1342,6 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 //   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
 // CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
 //  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
-//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-PROP-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//   CHECK-PROP-DAG:   %[[C4:.+]] = arith.constant 4 : index
-
-// Compute distributed m0 based on the first (outermost) delinearized lane id.
-//       CHECK-PROP:   affine.apply #map()[%[[M0]], %[[LANEID]]]
-//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
-//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
-
-// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 must
-// also be zero for the mask to be valid.
-//       CHECK-PROP:   affine.apply #map1()[%[[M1]], %[[LANEID]]]
-//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   arith.minsi {{.*}}, %[[C2]] : index
-//       CHECK-PROP:   %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
-
-// Compute m3 and propagate zeros.
-//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
-//       CHECK-PROP:   arith.maxsi %[[M2]], %[[C0]] : index
-//       CHECK-PROP:   arith.minsi {{.*}}, %[[C4]] : index
-//       CHECK-PROP:   %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
-
-//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>
+//       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
+//       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>



More information about the Mlir-commits mailing list