[PATCH] D152154: [mlir][Vector] Fix a propagation bug with broadcast
Quentin Colombet via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 6 07:53:51 PDT 2023
This revision was automatically updated to reflect the committed changes.
Closed by commit rG1dd00d39037b: [mlir][Vector] Fix a propagation bug with broadcast (authored by qcolombet).
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D152154/new/
https://reviews.llvm.org/D152154
Files:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1152,3 +1152,24 @@
}
return %18 : vector<2xf32>
}
+
+// -----
+
+// Check that we don't fold vector.broadcast when each thread doesn't get the
+// same value.
+
+// CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
+// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
+// CHECK-PROP: %[[some_def:.*]] = "some_def"
+// CHECK-PROP: %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+// CHECK-PROP: vector.yield %[[broadcast]] : vector<1x64xf32>
+// CHECK-PROP: vector.print %[[r]] : vector<1x2xf32>
+func.func @dont_fold_vector_broadcast(%laneid: index) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {
+ %0 = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.broadcast %0 : vector<64xf32> to vector<1x64xf32>
+ vector.yield %1 : vector<1x64xf32>
+ }
+ vector.print %r : vector<1x2xf32>
+ return
+}
Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -896,10 +896,19 @@
Location loc = broadcastOp.getLoc();
auto destVecType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+ Value broadcastSrc = broadcastOp.getSource();
+ Type broadcastSrcType = broadcastSrc.getType();
+
+ // Check that the broadcast actually spans a set of values uniformly across
+ // all threads. In other words, check that each thread can reconstruct
+ // their own broadcast.
+ // For that we simply check that the broadcast we want to build makes sense.
+ if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+ vector::BroadcastableToResult::Success)
+ return failure();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {broadcastOp.getSource()},
- {broadcastOp.getSource().getType()}, newRetIndices);
+ rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value broadcasted = rewriter.create<vector::BroadcastOp>(
loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D152154.528865.patch
Type: text/x-patch
Size: 2625 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20230606/1461d0b3/attachment.bin>
More information about the llvm-commits mailing list