[PATCH] D152154: [mlir][Vector] Fix a propagation bug with broadcast

Quentin Colombet via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 5 06:37:38 PDT 2023


qcolombet created this revision.
qcolombet added reviewers: ThomasRaoux, nicolasvasilache, springerm.
qcolombet added a project: MLIR.
Herald added subscribers: bviyer, Moerafaat, zero9178, bzcheeseman, sdasgup3, wenzhicui, wrengr, jsetoain, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini.
Herald added a reviewer: aartbik.
Herald added a project: All.
qcolombet requested review of this revision.
Herald added a subscriber: stephenneuendorffer.
Herald added a reviewer: dcaballe.

In the vector distribute patterns, we used to move
`vector.broadcast`s out of `vector.warp_execute_on_lane_0`s
regardless of how they were defined.

This could create broadcast operations with invalid semantics.
E.g.,

  %r = warpop ...[32] ... -> vector<1x2xf32> {
    %val = broadcast %in : vector<64xf32> to vector<1x64xf32>
    vector.yield %val : vector<1x64xf32>
  }

=>

  %r = warpop ...[32] ... -> vector<64xf32> {
    vector.yield %in : vector<64xf32>
  }
  // Broadcasting to a narrower type!
  broadcast %r : vector<64xf32> to vector<1x2xf32>

The root issue is that we are trying to broadcast something that is not the
same for each thread, so there is actually nothing to propagate here.

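The fix checks that the broadcast we want to create is valid, i.e., that the
broadcast source type is broadcastable to the distributed result type of the
warp op. If it is not, the pattern bails out and leaves the IR unchanged.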
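
To illustrate the kind of check involved, here is a minimal, self-contained
sketch of a shape-only broadcastability test. It is not the MLIR
implementation: `Shape`, `isShapeBroadcastableTo`, and `runBroadcastExamples`
are hypothetical names used only for this illustration, and the sketch assumes
the usual trailing-dimension rule while ignoring element types and scalable
dimensions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for a static vector shape,
// e.g. {1, 64} models vector<1x64xf32>. Hypothetical type for this sketch.
using Shape = std::vector<std::int64_t>;

// Returns true if a value of shape `src` can be broadcast to shape `dst`.
// Assumed rule: `src` must not have a higher rank than `dst`, and each
// dimension of `src`, aligned on the trailing dimensions of `dst`, must
// either be 1 or match the corresponding dimension of `dst`.
bool isShapeBroadcastableTo(const Shape &src, const Shape &dst) {
  if (src.size() > dst.size())
    return false;
  // Number of leading dimensions of `dst` that have no counterpart in `src`.
  const std::size_t lead = dst.size() - src.size();
  for (std::size_t i = 0; i < src.size(); ++i) {
    if (src[i] != 1 && src[i] != dst[lead + i])
      return false;
  }
  return true;
}

// The two broadcasts from the example in this description.
void runBroadcastExamples() {
  // Original broadcast inside the warp op: vector<64xf32> -> vector<1x64xf32>.
  assert(isShapeBroadcastableTo({64}, {1, 64}));
  // Broadcast the pattern used to create: vector<64xf32> -> vector<1x2xf32>.
  assert(!isShapeBroadcastableTo({64}, {1, 2}));
}
```

Under these assumptions, the broadcast inside the region is accepted while the
hoisted one to `vector<1x2xf32>` is rejected, which is the case the pattern
now refuses to rewrite.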


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D152154

Files:
  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
  mlir/test/Dialect/Vector/vector-warp-distribute.mlir


Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1152,3 +1152,24 @@
   }
   return %18 : vector<2xf32>
 }
+
+// -----
+
+// Check that we don't fold vector.broadcast when each thread doesn't get the
+// same value.
+
+// CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
+//       CHECK-PROP:   %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
+//       CHECK-PROP:     %[[some_def:.*]] = "some_def"
+//       CHECK-PROP:     %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+//       CHECK-PROP:     vector.yield %[[broadcast]] : vector<1x64xf32>
+//       CHECK-PROP:   vector.print %[[r]] : vector<1x2xf32>
+func.func @dont_fold_vector_broadcast(%laneid: index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.broadcast %0 : vector<64xf32> to vector<1x64xf32>
+    vector.yield %1 : vector<1x64xf32>
+  }
+  vector.print %r : vector<1x2xf32>
+  return
+}
Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -896,10 +896,19 @@
     Location loc = broadcastOp.getLoc();
     auto destVecType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+    Value broadcastSrc = broadcastOp.getSource();
+    Type broadcastSrcType = broadcastSrc.getType();
+
+    // Check that the broadcast actually spans a set of values uniformly across
+    // all threads. In other words, check that each thread can reconstruct
+    // their own broadcast.
+    // For that we simply check that the broadcast we want to build makes sense.
+    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+        vector::BroadcastableToResult::Success)
+      return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastOp.getSource()},
-        {broadcastOp.getSource().getType()}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = rewriter.create<vector::BroadcastOp>(
         loc, destVecType, newWarpOp->getResult(newRetIndices[0]));

