[PATCH] D152154: [mlir][Vector] Fix a propagation bug with broadcast

Quentin Colombet via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 6 06:11:58 PDT 2023


qcolombet updated this revision to Diff 528824.
qcolombet added a comment.

Rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D152154/new/

https://reviews.llvm.org/D152154

Files:
  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
  mlir/test/Dialect/Vector/vector-warp-distribute.mlir


Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1152,3 +1152,24 @@
   }
   return %18 : vector<2xf32>
 }
+
+// -----
+
+// Check that we don't propagate vector.broadcast out of the warp op when each
+// lane needs a different slice of the source: vector<64xf32> cannot be
+// broadcast to the per-lane result type vector<1x2xf32>.
+// CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
+//       CHECK-PROP:   %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
+//       CHECK-PROP:     %[[some_def:.*]] = "some_def"
+//       CHECK-PROP:     %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+//       CHECK-PROP:     vector.yield %[[broadcast]] : vector<1x64xf32>
+//       CHECK-PROP:   vector.print %[[r]] : vector<1x2xf32>
+func.func @dont_fold_vector_broadcast(%laneid: index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.broadcast %0 : vector<64xf32> to vector<1x64xf32>
+    vector.yield %1 : vector<1x64xf32>
+  }
+  vector.print %r : vector<1x2xf32>
+  return
+}
Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -896,10 +896,19 @@
     Location loc = broadcastOp.getLoc();
     auto destVecType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+    Value broadcastSrc = broadcastOp.getSource();
+    Type broadcastSrcType = broadcastSrc.getType();
+
+    // Bail out unless every lane can rebuild its own slice of the result by
+    // broadcasting the whole (undistributed) source. We check this by
+    // verifying that the source type is broadcastable to the distributed
+    // per-lane result type, i.e. the broadcast created after the warp op.
+    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+        vector::BroadcastableToResult::Success)
+      return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastOp.getSource()},
-        {broadcastOp.getSource().getType()}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = rewriter.create<vector::BroadcastOp>(
         loc, destVecType, newWarpOp->getResult(newRetIndices[0]));


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D152154.528824.patch
Type: text/x-patch
Size: 2625 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20230606/79090f3d/attachment.bin>


More information about the llvm-commits mailing list