[Mlir-commits] [mlir] [mlir][vector] VectorLinearize: `ub.poison` support (PR #128612)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 1 09:41:09 PST 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/128612

>From 489bcfadfa410f3d1eb5e9189e655d5db5686b7f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 25 Feb 2025 02:07:30 +0100
Subject: [PATCH 1/4] [mlir][vector] VectorLinearize: ub.poison support

---
 .../Vector/Transforms/VectorLinearize.cpp     | 38 +++++++++++++++++--
 mlir/test/Dialect/Vector/linearize.mlir       | 16 ++++++++
 2 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..65bd982319e45 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/Attributes.h"
@@ -97,6 +98,35 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   unsigned targetVectorBitWidth;
 };
 
+struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizePoison(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 struct LinearizeVectorizable final
     : OpTraitConversionPattern<OpTrait::Vectorizable> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
@@ -525,7 +555,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+        if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -534,9 +564,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns
-      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
-          typeConverter, patterns.getContext(), targetBitWidth);
+  patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+                                       targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..22d2cd452166b 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 
 // -----
 
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+  // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+  // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+  // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+  %0 = ub.poison : vector<2x2xf32>
+  // ALL: return %[[RES]] : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+
 // ALL-LABEL: test_partial_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
 func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {

>From df42371ba5d10921e017a490b05fbd8f4465a701 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 25 Feb 2025 02:10:57 +0100
Subject: [PATCH 2/4] Fix test: drop unused argument from test_linearize_poison

---
 mlir/test/Dialect/Vector/linearize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 22d2cd452166b..f859ffd0e19d7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -33,7 +33,7 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 // -----
 
 // ALL-LABEL: test_linearize_poison
-func.func @test_linearize_poison(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+func.func @test_linearize_poison() -> vector<2x2xf32> {
   // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
 

>From 1008cce565af069a315ef42cbb845e9db1e91257 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 25 Feb 2025 03:13:20 +0100
Subject: [PATCH 3/4] Replace LinearizeConstant/LinearizePoison with a generic LinearizeConstantLike pattern

---
 .../Vector/Transforms/VectorLinearize.cpp     | 95 +++++++++----------
 1 file changed, 47 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 65bd982319e45..0caa2cb01e019 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -58,69 +58,67 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
 }
 
 namespace {
-struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LinearizeConstant(
+struct LinearizeConstantLike final
+    : OpTraitConversionPattern<OpTrait::ConstantLike> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+
+  LinearizeConstantLike(
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
+      : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
-  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = constOp.getLoc();
+    Location loc = op->getLoc();
+    if (op->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(loc, "expected 1 result");
+
     auto resType =
-        getTypeConverter()->convertType<VectorType>(constOp.getType());
+        getTypeConverter()->convertType<VectorType>(op->getResult(0).getType());
 
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
 
-    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
-      return rewriter.notifyMatchFailure(
-          loc,
-          "Cannot linearize a constant scalable vector that's not a splat");
-
-    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
-    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
-    if (!dstElementsAttr)
-      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
-
-    dstElementsAttr = dstElementsAttr.reshape(resType);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
-                                                   dstElementsAttr);
-    return success();
-  }
-
-private:
-  unsigned targetVectorBitWidth;
-};
 
-struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LinearizePoison(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
-  LogicalResult
-  matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+    StringAttr attrName = rewriter.getStringAttr("value");
+    Attribute value = op->getAttr(attrName);
+    if (!value)
+      return rewriter.notifyMatchFailure(loc, "no 'value' attr");
+
+    if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
+      if (resType.isScalable() && !isa<SplatElementsAttr>(value))
+        return rewriter.notifyMatchFailure(
+            loc,
+            "Cannot linearize a constant scalable vector that's not a splat");
+
+      dstElementsAttr = dstElementsAttr.reshape(resType);
+      FailureOr<Operation *> newOp =
+          convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
+      if (failed(newOp))
+        return failure();
+
+      (*newOp)->setAttr(attrName, dstElementsAttr);
+      rewriter.replaceOp(op, *newOp);
+      return success();
+    }
 
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+    if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value)) {
+      FailureOr<Operation *> newOp =
+          convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
+      if (failed(newOp))
+        return failure();
 
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
+      (*newOp)->setAttr(attrName, poisonAttr);
+      rewriter.replaceOp(op, *newOp);
+      return success();
+    }
 
-    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
-    return success();
+    return rewriter.notifyMatchFailure(loc, "unsupported attr type");
   }
 
 private:
@@ -555,7 +553,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp>(op) ||
+             op->hasTrait<OpTrait::ConstantLike>() ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -564,7 +563,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
                LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
                                        targetBitWidth);
 }

>From 4072713fee8c80de2c07a6bb1742d8699cdad28d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 1 Mar 2025 18:38:11 +0100
Subject: [PATCH 4/4] Address review comments: factor out linearizeConstAttr helper and rename test captures

---
 .../Vector/Transforms/VectorLinearize.cpp     | 60 ++++++++++---------
 mlir/test/Dialect/Vector/linearize.mlir       |  8 +--
 2 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 0caa2cb01e019..9dccc005322eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -57,6 +57,24 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
   return trailingVecDimBitWidth <= targetBitWidth;
 }
 
+static FailureOr<Attribute>
+linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
+                   VectorType resType, Attribute value) {
+  if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
+    if (resType.isScalable() && !isa<SplatElementsAttr>(value))
+      return rewriter.notifyMatchFailure(
+          loc,
+          "Cannot linearize a constant scalable vector that's not a splat");
+
+    return dstElementsAttr.reshape(resType);
+  }
+
+  if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
+    return poisonAttr;
+
+  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+}
+
 namespace {
 struct LinearizeConstantLike final
     : OpTraitConversionPattern<OpTrait::ConstantLike> {
@@ -75,8 +93,9 @@ struct LinearizeConstantLike final
     if (op->getNumResults() != 1)
       return rewriter.notifyMatchFailure(loc, "expected 1 result");
 
+    const TypeConverter &converter = *getTypeConverter();
     auto resType =
-        getTypeConverter()->convertType<VectorType>(op->getResult(0).getType());
+        converter.convertType<VectorType>(op->getResult(0).getType());
 
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
@@ -90,35 +109,20 @@ struct LinearizeConstantLike final
     if (!value)
       return rewriter.notifyMatchFailure(loc, "no 'value' attr");
 
-    if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
-      if (resType.isScalable() && !isa<SplatElementsAttr>(value))
-        return rewriter.notifyMatchFailure(
-            loc,
-            "Cannot linearize a constant scalable vector that's not a splat");
-
-      dstElementsAttr = dstElementsAttr.reshape(resType);
-      FailureOr<Operation *> newOp =
-          convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
-      if (failed(newOp))
-        return failure();
-
-      (*newOp)->setAttr(attrName, dstElementsAttr);
-      rewriter.replaceOp(op, *newOp);
-      return success();
-    }
-
-    if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value)) {
-      FailureOr<Operation *> newOp =
-          convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
-      if (failed(newOp))
-        return failure();
+    FailureOr<Attribute> newValue =
+        linearizeConstAttr(loc, rewriter, resType, value);
+    if (failed(newValue))
+      return failure();
 
-      (*newOp)->setAttr(attrName, poisonAttr);
-      rewriter.replaceOp(op, *newOp);
-      return success();
-    }
+    FailureOr<Operation *> convertResult =
+        convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
+    if (failed(convertResult))
+      return failure();
 
-    return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+    Operation *newOp = *convertResult;
+    newOp->setAttr(attrName, *newValue);
+    rewriter.replaceOp(op, newOp);
+    return success();
   }
 
 private:
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index f859ffd0e19d7..0c7d2b124f621 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -34,11 +34,11 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 
 // ALL-LABEL: test_linearize_poison
 func.func @test_linearize_poison() -> vector<2x2xf32> {
-  // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
-  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+  // DEFAULT: %[[POISON:.*]] = ub.poison : vector<4xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
 
-  // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
-  // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+  // BW-128: %[[POISON:.*]] = ub.poison : vector<4xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
   %0 = ub.poison : vector<2x2xf32>



More information about the Mlir-commits mailing list