[Mlir-commits] [mlir] [mlir][vector] Clarify the semantics of BroadcastOp (PR #101928)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Aug 6 07:35:44 PDT 2024
================
@@ -2390,13 +2390,29 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
// Source has an exact match or singleton value for all trailing dimensions
// (all leading dimensions are simply duplicated).
int64_t lead = dstRank - srcRank;
- for (int64_t r = 0; r < srcRank; ++r) {
- int64_t srcDim = srcVectorType.getDimSize(r);
- int64_t dstDim = dstVectorType.getDimSize(lead + r);
- if (srcDim != 1 && srcDim != dstDim) {
- if (mismatchingDims) {
- mismatchingDims->first = srcDim;
- mismatchingDims->second = dstDim;
+ for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
+ bool mismatch = false;
+
+ // Check fixed-width dims.
+ int64_t srcDim = srcVectorType.getDimSize(dimIdx);
+ int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
+ if (srcDim != 1 && srcDim != dstDim)
+ mismatch = true;
+
+ // Check scalable flags.
+ bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
+ bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
+ if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+ (srcDimScalableFlag != dstDimScalableFlag))
+ mismatch = true;
+
+ if (mismatch) {
+ if (mismatchingDims != nullptr) {
----------------
kuhar wrote:
This helps, thanks!
https://github.com/llvm/llvm-project/pull/101928
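The per-dimension rule in the hunk above can be easier to follow in isolation. Below is a minimal standalone C++ sketch of it, not the MLIR implementation. `SimpleVecType` and `isBroadcastableSketch` are hypothetical names. The sketch returns a plain `bool`, and the scalable-flag condition is copied from this revision of the diff, not from whatever was eventually merged.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: a shape plus one
// "is scalable" flag per dimension. For example, vector<2x[4]xf32> is
// modeled as shape {2, 4} with flags {false, true}.
struct SimpleVecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
};

// Sketch of the per-dimension check in the hunk (assumption: simplified to
// return bool). On failure, optionally reports the first mismatching
// (srcDim, dstDim) pair through `mismatchingDims`.
bool isBroadcastableSketch(const SimpleVecType &src, const SimpleVecType &dst,
                           std::pair<int64_t, int64_t> *mismatchingDims = nullptr) {
  assert(src.shape.size() == src.scalableDims.size());
  assert(dst.shape.size() == dst.scalableDims.size());
  int64_t srcRank = static_cast<int64_t>(src.shape.size());
  int64_t dstRank = static_cast<int64_t>(dst.shape.size());
  if (srcRank > dstRank)
    return false;

  // Leading dst dims have no src counterpart and are simply duplicated.
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool mismatch = false;

    // Fixed-width check: each trailing src dim must be 1 or equal the dst dim.
    int64_t srcDim = src.shape[dimIdx];
    int64_t dstDim = dst.shape[lead + dimIdx];
    if (srcDim != 1 && srcDim != dstDim)
      mismatch = true;

    // Scalable check, exactly as written in this revision of the diff:
    // a scalable unit dim cannot be stretched, and the flags must agree.
    bool srcScalable = src.scalableDims[dimIdx];
    bool dstScalable = dst.scalableDims[lead + dimIdx];
    if ((srcDim == 1 && srcScalable && dstDim != 1) ||
        (srcScalable != dstScalable))
      mismatch = true;

    if (mismatch) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first = srcDim;
        mismatchingDims->second = dstDim;
      }
      return false;
    }
  }
  return true;
}
```

Under these rules, `vector<[4]xf32>` to `vector<2x[4]xf32>` passes, while `vector<[1]xf32>` to `vector<[4]xf32>` is rejected by the first scalable clause. Any pair whose scalable flags differ is rejected by the second clause.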