[Mlir-commits] [mlir] [mlir][vector] VectorLinearize: `ub.poison` support (PR #128612)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 24 17:12:25 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/128612.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+34-4) 
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+16) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..65bd982319e45 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/Attributes.h"
@@ -97,6 +98,35 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   unsigned targetVectorBitWidth;
 };
 
+struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizePoison(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+    // Bail out unless the converter maps the result to a VectorType.
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+    // Honor the same bit-width limit used by the legality callback.
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+    // Only the type changes: build a new poison of the converted type.
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
+    return success();
+  }
+  // Threshold passed to isLessThanTargetBitWidth in matchAndRewrite.
+private:
+  unsigned targetVectorBitWidth;
+};
+
 struct LinearizeVectorizable final
     : OpTraitConversionPattern<OpTrait::Vectorizable> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
@@ -525,7 +555,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+        if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -534,9 +564,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns
-      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
-          typeConverter, patterns.getContext(), targetBitWidth);
+  patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+                                       targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..f859ffd0e19d7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 
 // -----
 
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison() -> vector<2x2xf32> {
+  // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+  // vector<2x2xf32> is 128 bits; the BW-128 run still linearizes it.
+  // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+  // The BW-0 run expects the original n-D poison to be kept as-is.
+  // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+  %0 = ub.poison : vector<2x2xf32>
+  // ALL: return %[[RES]] : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+
 // ALL-LABEL: test_partial_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
 func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/128612


More information about the Mlir-commits mailing list