[Mlir-commits] [mlir] [mlir][vector] VectorLinearize: `ub.poison` support (PR #128612)

Diego Caballero llvmlistbot at llvm.org
Tue Feb 25 16:30:39 PST 2025


================
@@ -57,40 +58,67 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
 }
 
 namespace {
-struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LinearizeConstant(
+struct LinearizeConstantLike final
+    : OpTraitConversionPattern<OpTrait::ConstantLike> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+
+  LinearizeConstantLike(
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
+      : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
-  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = constOp.getLoc();
+    Location loc = op->getLoc();
+    if (op->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(loc, "expected 1 result");
+
     auto resType =
-        getTypeConverter()->convertType<VectorType>(constOp.getType());
+        getTypeConverter()->convertType<VectorType>(op->getResult(0).getType());
 
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
 
-    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
-      return rewriter.notifyMatchFailure(
-          loc,
-          "Cannot linearize a constant scalable vector that's not a splat");
-
-    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
-    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
-    if (!dstElementsAttr)
-      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
 
-    dstElementsAttr = dstElementsAttr.reshape(resType);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
-                                                   dstElementsAttr);
-    return success();
+    StringAttr attrName = rewriter.getStringAttr("value");
+    Attribute value = op->getAttr(attrName);
+    if (!value)
+      return rewriter.notifyMatchFailure(loc, "no 'value' attr");
+
+    if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
+      if (resType.isScalable() && !isa<SplatElementsAttr>(value))
+        return rewriter.notifyMatchFailure(
+            loc,
+            "Cannot linearize a constant scalable vector that's not a splat");
+
+      dstElementsAttr = dstElementsAttr.reshape(resType);
+      FailureOr<Operation *> newOp =
+          convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
----------------
dcaballe wrote:

nit: `{}` -> `/*paramName=*/{}`

https://github.com/llvm/llvm-project/pull/128612


More information about the Mlir-commits mailing list