[Mlir-commits] [mlir] 1523b72 - [mlir][vector] Distribute vector.insert op

Matthias Springer llvmlistbot at llvm.org
Mon Jan 9 07:50:42 PST 2023


Author: Matthias Springer
Date: 2023-01-09T16:50:28+01:00
New Revision: 1523b72946c330070c5bc55d21a58800c845db50

URL: https://github.com/llvm/llvm-project/commit/1523b72946c330070c5bc55d21a58800c845db50
DIFF: https://github.com/llvm/llvm-project/commit/1523b72946c330070c5bc55d21a58800c845db50.diff

LOG: [mlir][vector] Distribute vector.insert op

In case the distributed dim of the dest vector is also a dim of the src vector, each lane inserts a smaller part of the source vector. Otherwise, one lane inserts the entire src vector and the other lanes do nothing.

Differential Revision: https://reviews.llvm.org/D137953

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index df7b240a4d5c..f60bae61792a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1133,6 +1133,131 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Pattern to sink a vector.insert op that is yielded as a result of a
+/// WarpExecuteOnLane0Op. Depending on how the insert's destination vector is
+/// distributed across lanes, either every lane inserts a smaller slice of the
+/// source vector, or a single lane performs the entire insertion inside an
+/// scf.if while all other lanes forward their destination vector unchanged.
+struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    // Look for a yielded warp op result that is produced by a vector.insert.
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
+    if (!operand)
+      return failure();
+    // The warp op result index corresponds to the yield operand number.
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
+    Location loc = insertOp.getLoc();
+
+    // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
+    if (insertOp.getPosition().empty())
+      return failure();
+
+    // Rewrite vector.insert with 1d dest to vector.insertelement. The
+    // resulting op can then be distributed by the insertelement pattern.
+    if (insertOp.getDestVectorType().getRank() == 1) {
+      assert(insertOp.getPosition().size() == 1 && "expected 1 index");
+      int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt();
+      rewriter.setInsertionPoint(insertOp);
+      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(),
+          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      return success();
+    }
+
+    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
+      // There is no distribution, this is a broadcast. Simply move the insert
+      // out of the warp op.
+      SmallVector<size_t> newRetIndices;
+      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
+          {insertOp.getSourceType(), insertOp.getDestVectorType()},
+          newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
+      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
+      Value newResult = rewriter.create<vector::InsertOp>(
+          loc, distributedSrc, distributedDest, insertOp.getPosition());
+      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+      return success();
+    }
+
+    // Find the distributed dimension. There should be exactly one.
+    auto distrDestType =
+        warpOp.getResult(operandNumber).getType().cast<VectorType>();
+    auto yieldedType = operand->get().getType().cast<VectorType>();
+    int64_t distrDestDim = -1;
+    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
+      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
+        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+        // support distributing multiple dimensions in the future.
+        assert(distrDestDim == -1 && "found multiple distributed dims");
+        distrDestDim = i;
+      }
+    }
+    assert(distrDestDim != -1 && "could not find distributed dimension");
+
+    // Compute the distributed source vector type.
+    VectorType srcVecType = insertOp.getSourceType().cast<VectorType>();
+    SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
+                                       srcVecType.getShape().end());
+    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
+    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
+    //         insert a smaller vector<3xf32>.
+    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
+    //         case, one lane will insert the source vector<96xf32>. The other
+    //         lanes will not do anything.
+    int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size();
+    if (distrSrcDim >= 0)
+      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
+    auto distrSrcType =
+        VectorType::get(distrSrcShape, distrDestType.getElementType());
+
+    // Yield source and dest vectors from warp op.
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
+        {distrSrcType, distrDestType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
+    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
+
+    // Insert into the distributed vector.
+    Value newResult;
+    if (distrSrcDim >= 0) {
+      // Every lane inserts a small piece. The original position can be reused
+      // because all position indices refer to dims before the distributed one.
+      newResult = rewriter.create<vector::InsertOp>(
+          loc, distributedSrc, distributedDest, insertOp.getPosition());
+    } else {
+      // One lane inserts the entire source vector.
+      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
+      SmallVector<int64_t> newPos = llvm::to_vector(
+          llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
+            return attr.cast<IntegerAttr>().getInt();
+          }));
+      // tid of inserting lane: pos / elementsPerLane
+      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
+          loc, newPos[distrDestDim] / elementsPerLane);
+      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
+      // Insert position: pos % elementsPerLane
+      newPos[distrDestDim] %= elementsPerLane;
+      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
+        Value newInsert = builder.create<vector::InsertOp>(
+            loc, distributedSrc, distributedDest, newPos);
+        builder.create<scf::YieldOp>(loc, newInsert);
+      };
+      // Non-inserting lanes forward their distributed dest unchanged.
+      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
+        builder.create<scf::YieldOp>(loc, distributedDest);
+      };
+      newResult = rewriter
+                      .create<scf::IfOp>(loc, distrDestType, isInsertingLane,
+                                         /*thenBuilder=*/insertingBuilder,
+                                         /*elseBuilder=*/nonInsertingBuilder)
+                      .getResult(0);
+    }
+
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -1387,8 +1512,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement>(patterns.getContext(),
-                                                    benefit);
+               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index b19c3cd96a26..2dd54771d897 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -994,3 +994,98 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
   }
   return %r : vector<f32>
 }
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_1d(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index
+//   CHECK-PROP-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-PROP-DAG:   %[[C26:.*]] = arith.constant 26 : index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
+//       CHECK-PROP:   %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C26]]
+//       CHECK-PROP:   %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
+//       CHECK-PROP:     %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[C1]] : index]
+//       CHECK-PROP:     scf.yield %[[INSERT]]
+//       CHECK-PROP:   } else {
+//       CHECK-PROP:     scf.yield %[[W]]#0
+//       CHECK-PROP:   }
+//       CHECK-PROP:   return %[[R]]
+// With 96 elements distributed over 32 lanes (3 per lane), position 79 maps
+// to inserting lane 26 (= 79 / 3) and intra-lane position 1 (= 79 % 3),
+// matching the CHECK lines above. (Position 76 would map to lane 25, not 26.)
+func.func @vector_insert_1d(%laneid: index) -> (vector<3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insert %f, %0[79] : f32 into vector<96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_2d_distr_src(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, vector<4x3xf32>)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VAL]], %[[VEC]]
+//       CHECK-PROP:   %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [2] : vector<3xf32> into vector<4x3xf32>
+//       CHECK-PROP:   return %[[INSERT]]
+// The distributed dim (96 -> 3 per lane) is also a dim of the source vector,
+// so every lane inserts its own vector<3xf32> slice at row [2].
+func.func @vector_insert_2d_distr_src(%laneid: index) -> (vector<4x3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x3xf32>) {
+    %0 = "some_def"() : () -> (vector<4x96xf32>)
+    %s = "another_def"() : () -> (vector<96xf32>)
+    %1 = vector.insert %s, %0[2] : vector<96xf32> into vector<4x96xf32>
+    vector.yield %1 : vector<4x96xf32>
+  }
+  return %r : vector<4x3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_2d_distr_pos(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index
+//       CHECK-PROP:   %[[C19:.*]] = arith.constant 19 : index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, vector<4x96xf32>)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VAL]], %[[VEC]]
+//       CHECK-PROP:   %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C19]]
+//       CHECK-PROP:   %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<4x96xf32>) {
+//       CHECK-PROP:     %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [3] : vector<96xf32> into vector<4x96xf32>
+//       CHECK-PROP:     scf.yield %[[INSERT]]
+//       CHECK-PROP:   } else {
+//       CHECK-PROP:     scf.yield %[[W]]#1
+//       CHECK-PROP:   }
+//       CHECK-PROP:   return %[[R]]
+// The distributed dim (128 -> 4 per lane) is consumed by the insertion
+// position, so only lane 19 (= 79 / 4) inserts, at intra-lane row 3
+// (= 79 % 4); all other lanes forward their vector unchanged.
+func.func @vector_insert_2d_distr_pos(%laneid: index) -> (vector<4x96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
+    %0 = "some_def"() : () -> (vector<128x96xf32>)
+    %s = "another_def"() : () -> (vector<96xf32>)
+    %1 = vector.insert %s, %0[79] : vector<96xf32> into vector<128x96xf32>
+    vector.yield %1 : vector<128x96xf32>
+  }
+  return %r : vector<4x96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_2d_broadcast(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, vector<4x96xf32>)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VAL]], %[[VEC]]
+//       CHECK-PROP:   %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [2] : vector<96xf32> into vector<4x96xf32>
+//       CHECK-PROP:   return %[[INSERT]]
+// The warp op result type equals the yielded type (vector<4x96xf32>), so the
+// insert is simply hoisted out of the warp op without any distribution.
+func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
+    %0 = "some_def"() : () -> (vector<4x96xf32>)
+    %s = "another_def"() : () -> (vector<96xf32>)
+    %1 = vector.insert %s, %0[2] : vector<96xf32> into vector<4x96xf32>
+    vector.yield %1 : vector<4x96xf32>
+  }
+  return %r : vector<4x96xf32>
+}


        


More information about the Mlir-commits mailing list