[Mlir-commits] [mlir] f48ce52 - [mlir][vector] Pattern to clean up vector.extract during distribution

Thomas Raoux llvmlistbot at llvm.org
Thu Jul 14 10:07:48 PDT 2022


Author: Thomas Raoux
Date: 2022-07-14T17:07:32Z
New Revision: f48ce52c4c2de8dc80e1bdd5caebbdb9f9db00ce

URL: https://github.com/llvm/llvm-project/commit/f48ce52c4c2de8dc80e1bdd5caebbdb9f9db00ce
DIFF: https://github.com/llvm/llvm-project/commit/f48ce52c4c2de8dc80e1bdd5caebbdb9f9db00ce.diff

LOG: [mlir][vector] Pattern to clean up vector.extract during distribution

This prevents vector.extract from blocking propagation when converting between
a scalar and a single-element vector (e.g. vector<1xf32>).

Differential Revision: https://reviews.llvm.org/D129782

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3fcd42e981dbd..8ecb8986fd9fc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -719,6 +719,33 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Pattern to move vector.extract of single-element vectors out of the region:
+/// no distribution needed. Only these are matched, so others can't block them.
+struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+      auto extract = dyn_cast<vector::ExtractOp>(op);
+      return extract && extract.getVectorType().getNumElements() == 1;
+    });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
+    Location loc = extractOp.getLoc();
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newExtract = rewriter.create<vector::ExtractOp>(
+        loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -915,8 +942,8 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns) {
+  // WarpOpExtract hoists single-element vector.extract ops out of the region.
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp,
-               WarpOpConstant>(patterns.getContext());
+               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
 }
 
 void mlir::vector::populateDistributeReduction(

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 6d377920eeca5..1b87094c39819 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -611,3 +611,21 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
   }
   return %r : vector<1xf32>
 }
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_simple(
+//       CHECK-PROP:   %[[WARP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+//       CHECK-PROP:     %[[SRC:.*]] = "some_def"() : () -> vector<1xf32>
+//       CHECK-PROP:     vector.yield %[[SRC]] : vector<1xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[ELT:.*]] = vector.extract %[[WARP]][0] : vector<1xf32>
+//       CHECK-PROP:   return %[[ELT]] : f32
+func.func @vector_extract_simple(%lane_id: index) -> (f32) {
+  %res = vector.warp_execute_on_lane_0(%lane_id)[32] -> (f32) {
+    %vec = "some_def"() : () -> (vector<1xf32>)
+    %elt = vector.extract %vec[0] : vector<1xf32>
+    vector.yield %elt : f32
+  }
+  return %res : f32
+}


        


More information about the Mlir-commits mailing list