[Mlir-commits] [mlir] [mlir][vector] VectorLinearize: `ub.poison` support (PR #128612)
Diego Caballero
llvmlistbot at llvm.org
Tue Feb 25 16:30:39 PST 2025
================
@@ -57,40 +58,67 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
}
namespace {
-struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
- using OpConversionPattern::OpConversionPattern;
- LinearizeConstant(
+struct LinearizeConstantLike final
+ : OpTraitConversionPattern<OpTrait::ConstantLike> {
+ using OpTraitConversionPattern::OpTraitConversionPattern;
+
+ LinearizeConstantLike(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
+ : OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
- matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Location loc = constOp.getLoc();
+ Location loc = op->getLoc();
+ if (op->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(loc, "expected 1 result");
+
auto resType =
- getTypeConverter()->convertType<VectorType>(constOp.getType());
+ getTypeConverter()->convertType<VectorType>(op->getResult(0).getType());
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
- return rewriter.notifyMatchFailure(
- loc,
- "Cannot linearize a constant scalable vector that's not a splat");
-
- if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
- auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
- if (!dstElementsAttr)
- return rewriter.notifyMatchFailure(loc, "unsupported attr type");
- dstElementsAttr = dstElementsAttr.reshape(resType);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
- dstElementsAttr);
- return success();
+ StringAttr attrName = rewriter.getStringAttr("value");
+ Attribute value = op->getAttr(attrName);
+ if (!value)
+ return rewriter.notifyMatchFailure(loc, "no 'value' attr");
+
+ if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
+ if (resType.isScalable() && !isa<SplatElementsAttr>(value))
+ return rewriter.notifyMatchFailure(
+ loc,
+ "Cannot linearize a constant scalable vector that's not a splat");
+
+ dstElementsAttr = dstElementsAttr.reshape(resType);
+ FailureOr<Operation *> newOp =
+ convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
----------------
dcaballe wrote:
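nit: annotate the bare `{}` argument with the name of the parameter it binds to, i.e. `{}` -> `/*paramName=*/{}`, so the call site is self-explanatory.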
https://github.com/llvm/llvm-project/pull/128612
More information about the Mlir-commits
mailing list