[Mlir-commits] [mlir] 1757164 - [mlir][vector] Add distribution for extract from 0d vector

Thomas Raoux llvmlistbot at llvm.org
Fri Oct 14 16:17:26 PDT 2022


Author: Thomas Raoux
Date: 2022-10-14T23:06:42Z
New Revision: 1757164eed244b221c6c078baa7c836e4809e133

URL: https://github.com/llvm/llvm-project/commit/1757164eed244b221c6c078baa7c836e4809e133
DIFF: https://github.com/llvm/llvm-project/commit/1757164eed244b221c6c078baa7c836e4809e133.diff

LOG: [mlir][vector] Add distribution for extract from 0d vector

Differential Revision: https://reviews.llvm.org/D135994

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3c4f20fd5f9c2..f730044abcf85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -895,6 +895,34 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Pattern to move out vector.extractelement of 0-D vectors. Those don't
+/// need to be distributed and can just be propagated outside of the region.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+      return isa<vector::ExtractElementOp>(op);
+    });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    if (extractOp.getVectorType().getRank() != 0)
+      return failure();
+    Location loc = extractOp.getLoc();
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newExtract = rewriter.create<vector::ExtractElementOp>(
+        loc, newWarpOp->getResult(newRetIndices[0]));
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -1093,8 +1121,9 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpScfForOp, WarpOpConstant>(patterns.getContext(), benefit);
+               WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
+               WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateDistributeReduction(

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3978d94f377b4..49c36fe18c90d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -632,6 +632,24 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
 
 // -----
 
+// CHECK-PROP-LABEL: func.func @vector_extractelement_simple(
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<f32>
+//       CHECK-PROP:     vector.yield %[[V]] : vector<f32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
+//       CHECK-PROP:   return %[[E]] : f32
+func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<f32>)
+    %1 = vector.extractelement %0[] : vector<f32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
 // CHECK-PROP:   func @lane_dependent_warp_propagate_read
 //  CHECK-PROP-SAME:   %[[ID:.*]]: index
 func.func @lane_dependent_warp_propagate_read(


        


More information about the Mlir-commits mailing list