[Mlir-commits] [mlir] [mlir][AMDGPU] Improve masked_load(..., broadcast(...), ...) handling (PR #159635)
Krzysztof Drewniak
llvmlistbot at llvm.org
Fri Sep 19 00:19:47 PDT 2025
================
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
return load;
}
-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
+ while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+ val = shapeCast.getSource();
+ auto splatOp = val.getDefiningOp<vector::SplatOp>();
+ if (splatOp)
+ return splatOp.getInput();
auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
if (!broadcastOp)
return failure();
- if (isa<VectorType>(broadcastOp.getSourceType()))
- return failure();
+ if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+ if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+ return failure();
+ SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+ Value scalarSource = vector::ExtractOp::create(
+ rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+ return scalarSource;
+ }
----------------
krzysz00 wrote:
But after unrolling or flattening, there aren't any broadcasts?
We might need to move broadcast lowering much, much later in our downstream pipeline if you want to keep this as is?
https://github.com/llvm/llvm-project/pull/159635
More information about the Mlir-commits
mailing list