[Mlir-commits] [mlir] [mlir][AMDGPU] Improve masked_load(..., broadcast(...), ...) handling (PR #159635)
    Kunwar Grover 
    llvmlistbot at llvm.org
       
    Fri Sep 19 01:10:03 PDT 2025
    
    
  
================
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return load;
 }
 
-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
+  while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+    val = shapeCast.getSource();
+  auto splatOp = val.getDefiningOp<vector::SplatOp>();
+  if (splatOp)
+    return splatOp.getInput();
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (isa<VectorType>(broadcastOp.getSourceType()))
-    return failure();
+  if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+    if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+      return failure();
+    SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+    Value scalarSource = vector::ExtractOp::create(
+        rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+    return scalarSource;
+  }
----------------
Groverkss wrote:
Unrolling/Flattening preserve broadcasts from scalar -> vector. If they don't, we should fix that.
https://github.com/llvm/llvm-project/pull/159635
    
    
More information about the Mlir-commits
mailing list