[Mlir-commits] [mlir] 2fc3c5c - [mlir][vector] Prevent duplicating operations during vector distribute

Thomas Raoux llvmlistbot at llvm.org
Thu Feb 9 00:26:48 PST 2023


Author: Thomas Raoux
Date: 2023-02-09T08:26:35Z
New Revision: 2fc3c5c34c4c0ce94a217717a469620e06325fb0

URL: https://github.com/llvm/llvm-project/commit/2fc3c5c34c4c0ce94a217717a469620e06325fb0
DIFF: https://github.com/llvm/llvm-project/commit/2fc3c5c34c4c0ce94a217717a469620e06325fb0.diff

LOG: [mlir][vector] Prevent duplicating operations during vector distribute

We should not distribute ops that have uses other than the yield op, as
this would duplicate those ops.

Differential Revision: https://reviews.llvm.org/D143629

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 48995afa6876d..6005f377c9adf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -228,8 +228,8 @@ static bool canBeHoisted(Operation *op,
          isMemoryEffectFree(op) && op->getNumRegions() == 0;
 }
 
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
+/// Return a value yielded by `warpOp` with no other uses which satisfies the
+/// filter lambda condition and is not dead.
 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                 const std::function<bool(Operation *)> &fn) {
   auto yield = cast<vector::YieldOp>(
@@ -237,7 +237,7 @@ static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
   for (OpOperand &yieldOperand : yield->getOpOperands()) {
     Value yieldValues = yieldOperand.get();
     Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
+    if (definedOp && definedOp->hasOneUse() && fn(definedOp)) {
       if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
         return &yieldOperand;
     }

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index b5087feaed028..3ca585b4506b5 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1109,3 +1109,22 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
   }
   return %r : vector<4x96xf32>
 }
+// -----
+
+// Verify that we don't duplicate the reduction.
+// CHECK-PROP-LABEL: func @vector_reduction_no_duplicate(
+//  CHECK-PROP-SAME:     %[[laneid:.*]]: index)
+//       CHECK-PROP:   %[[warp_op:.*]] = vector.warp_execute_on_lane_0(%[[laneid]])[32] -> (f32) {
+//       CHECK-PROP:     vector.reduction
+//       CHECK-PROP:     vector.yield %{{.*}} : f32
+//       CHECK-PROP:   }
+//  CHECK-PROP-NEXT:   return %{{.*}} : f32
+func.func @vector_reduction_no_duplicate(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<32xf32>)
+    %1 = vector.reduction <add>, %0 : vector<32xf32> into f32
+    "some_blocking_use"(%1) : (f32) -> ()
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list