[Mlir-commits] [mlir] 1dd00d3 - [mlir][Vector] Fix a propagation bug with broadcast

Quentin Colombet llvmlistbot at llvm.org
Tue Jun 6 07:53:41 PDT 2023


Author: Quentin Colombet
Date: 2023-06-06T16:40:15+02:00
New Revision: 1dd00d39037b14e06555a79a397ee1e85d787db9

URL: https://github.com/llvm/llvm-project/commit/1dd00d39037b14e06555a79a397ee1e85d787db9
DIFF: https://github.com/llvm/llvm-project/commit/1dd00d39037b14e06555a79a397ee1e85d787db9.diff

LOG: [mlir][Vector] Fix a propagation bug with broadcast

In the vector distribute patterns, we used to move
`vector.broadcast`s out of `vector.warp_execute_on_lane0`s
irrespective of how they were defined.

This could create broadcast operations with invalid semantics.
E.g.,
```
%r = warop ...[32] ... -> vector<1x2xf32> {
  %val = broadcast %in : vector<64xf32> to vector<1x64xf32>
  vector.yield %val : vector<1x64xf32>
}
```
=>
```
%r = warop ...[32] ... -> vector<64xf32> {
  vector.yield %in : vector<64xf32>
}
// Broadcasting to a narrower type!
broadcast %r : vector<64xf32> to vector<1x2xf32>
```

The root issue is that we are trying to broadcast a value that is not the same
for each thread, so there is nothing to propagate here.
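
To see where the narrower type comes from, it helps to compute the per-lane result type. The following is a minimal C++ sketch, assuming (as in the example above) that distribution splits only the innermost dimension evenly across the 32 lanes; `distributeInnermost` is a hypothetical helper, not an MLIR API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: per-lane shape obtained when the innermost dimension
// of a warp-wide vector shape is split evenly across `warpSize` lanes.
std::vector<int64_t> distributeInnermost(std::vector<int64_t> shape,
                                         int64_t warpSize) {
  // Assumes a non-empty shape whose innermost size is a multiple of warpSize.
  assert(!shape.empty() && shape.back() % warpSize == 0);
  shape.back() /= warpSize;
  return shape;
}
```

Under that assumption, the warp-wide `vector<1x64xf32>` becomes the per-lane `vector<1x2xf32>` result, while the yielded source keeps its full `vector<64xf32>` type, so each lane would have to "broadcast" 64 elements down to 2.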

The fix checks that the broadcast we want to create actually makes sense.
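
A simplified model of the static-shape rule that `vector::isBroadcastableTo` enforces can be sketched in plain C++. This is an illustrative assumption, not MLIR's implementation: `checkBroadcastable` and `BroadcastCheck` are hypothetical names, and scalar sources and scalable dimensions are ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified result kinds for a broadcastability check.
enum class BroadcastCheck { Success, SourceRankHigher, DimensionMismatch };

// Shapes are lists of static dimension sizes, outermost first.
BroadcastCheck checkBroadcastable(const std::vector<int64_t> &src,
                                  const std::vector<int64_t> &dst) {
  // A broadcast can only add leading dimensions, never drop them.
  if (src.size() > dst.size())
    return BroadcastCheck::SourceRankHigher;
  // Align trailing dimensions: src[i] pairs with dst[i + rankDiff].
  std::size_t rankDiff = dst.size() - src.size();
  for (std::size_t i = 0; i < src.size(); ++i) {
    int64_t s = src[i], d = dst[i + rankDiff];
    // Each source dim must equal the destination dim or be 1 (stretched).
    if (s != 1 && s != d)
      return BroadcastCheck::DimensionMismatch;
  }
  return BroadcastCheck::Success;
}
```

With this rule, the warp-level broadcast from `vector<64xf32>` to `vector<1x64xf32>` is accepted, but the per-lane one from `vector<64xf32>` to `vector<1x2xf32>` is rejected with a dimension mismatch (64 vs. 2), which corresponds to the case where the fixed pattern now returns `failure()`.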

Differential Revision: https://reviews.llvm.org/D152154

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 584b57bb24dd6..72aae4956b3e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -896,10 +896,19 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Location loc = broadcastOp.getLoc();
     auto destVecType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+    Value broadcastSrc = broadcastOp.getSource();
+    Type broadcastSrcType = broadcastSrc.getType();
+
+    // Check that the broadcast actually spans a set of values uniformly across
+    // all threads. In other words, check that each thread can reconstruct
+    // their own broadcast.
+    // For that we simply check that the broadcast we want to build makes sense.
+    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+        vector::BroadcastableToResult::Success)
+      return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastOp.getSource()},
-        {broadcastOp.getSource().getType()}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = rewriter.create<vector::BroadcastOp>(
         loc, destVecType, newWarpOp->getResult(newRetIndices[0]));

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f376977adf4da..28efd5721524e 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1152,3 +1152,24 @@ func.func @transfer_read_no_prop(%in2: vector<1x2xindex>, %ar1 :  memref<1x4x2xi
   }
   return %18 : vector<2xf32>
 }
+
+// -----
+
+// Check that we don't fold vector.broadcast when each thread doesn't get the
+// same value.
+
+// CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
+//       CHECK-PROP:   %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
+//       CHECK-PROP:     %[[some_def:.*]] = "some_def"
+//       CHECK-PROP:     %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+//       CHECK-PROP:     vector.yield %[[broadcast]] : vector<1x64xf32>
+//       CHECK-PROP:   vector.print %[[r]] : vector<1x2xf32>
+func.func @dont_fold_vector_broadcast(%laneid: index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.broadcast %0 : vector<64xf32> to vector<1x64xf32>
+    vector.yield %1 : vector<1x64xf32>
+  }
+  vector.print %r : vector<1x2xf32>
+  return
+}
