[Mlir-commits] [mlir] [vector][linearize] Refactor code to push target bit width out of patterns (PR #136581)
James Newling
llvmlistbot at llvm.org
Tue Apr 22 10:00:57 PDT 2025
================
@@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type.");
- if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- loc, "Can't flatten since targetBitWidth <= OpSize");
-
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
return mlir::success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
} // namespace
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+/// Returns true if an op producing a value of type `type` is considered legal
+/// (i.e. is not linearized) on bit-width grounds. This is the case when:
+///   - `type` is not a VectorType;
+///   - the vector's element type is `index`, whose width is not known here and
+///     may therefore be at or above `targetBitWidth`; or
+///   - the trailing dimension of the vector is at least `targetBitWidth` bits
+///     wide. A 0-D vector is treated as having a trailing size of 0.
+static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
+  VectorType vecType = dyn_cast<VectorType>(type);
+  if (!vecType)
+    return true;
+
+  // The width of the type 'index' is unbounded (and therefore potentially above
+  // the target width).
+  if (vecType.getElementType().isIndex())
+    return true;
+
+  // Keep the arithmetic in 64 bits: the shape is int64_t, and multiplying a
+  // large trailing dimension by the element bit width can overflow `unsigned`.
+  int64_t finalDimSize =
+      vecType.getRank() == 0 ? 0 : vecType.getShape().back();
+  int64_t trailingVecDimBitWidth =
+      finalDimSize * static_cast<int64_t>(vecType.getElementTypeBitWidth());
+
+  return trailingVecDimBitWidth >= static_cast<int64_t>(targetBitWidth);
+}
+
+/// Returns the (type, bit-width threshold) pairs against which the bit-width
+/// legality of `op` is checked.
+///
+/// For `vector.insert`, only the type of the value being stored is checked, and
+/// its threshold is `targetBitWidth + 1`. The increment saturates at the
+/// maximum `unsigned` value so it cannot wrap around to 0. For every other op,
+/// each result type is paired with `targetBitWidth` unchanged.
+static SmallVector<std::pair<Type, unsigned>>
+getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
+  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+    // Presumably bumped by one so that a `>=` comparison against this threshold
+    // acts as a strict `>` comparison against `targetBitWidth`.
+    unsigned insertThreshold = targetBitWidth;
+    if (targetBitWidth != std::numeric_limits<unsigned>::max())
+      ++insertThreshold;
+    return {{insertOp.getValueToStoreType(), insertThreshold}};
+  }
+
+  SmallVector<std::pair<Type, unsigned>> checks;
+  checks.reserve(op->getNumResults());
+  for (Type resultType : op->getResultTypes())
+    checks.emplace_back(resultType, targetBitWidth);
+  return checks;
+}
+
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result.
+static bool legalBecauseScalable(Operation *op) {
+
+ bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
+ op->hasTrait<OpTrait::Vectorizable>() ||
+ isa<vector::BitCastOp>(op);
----------------
newling wrote:
I was just trying to do a pure NFC (no functional change) refactor, and these were the only ops that support scalable vectors. Flipping the logic is a good idea; I will make this clearer.
https://github.com/llvm/llvm-project/pull/136581
More information about the Mlir-commits
mailing list