[Mlir-commits] [mlir] 02fae68 - [mlir][vector] VectorLinearize: `ub.poison` support (#128612)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 13 04:18:29 PDT 2025
Author: Ivan Butygin
Date: 2025-03-13T14:18:21+03:00
New Revision: 02fae68a45fdd752e3ad5ee767c378a45b77744d
URL: https://github.com/llvm/llvm-project/commit/02fae68a45fdd752e3ad5ee767c378a45b77744d
DIFF: https://github.com/llvm/llvm-project/commit/02fae68a45fdd752e3ad5ee767c378a45b77744d.diff
LOG: [mlir][vector] VectorLinearize: `ub.poison` support (#128612)
Unify `arith.constant` and `ub.poison` linearization using a single
`OpTraitConversionPattern<OpTrait::ConstantLike>`.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..9dccc005322eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
@@ -56,40 +57,71 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
return trailingVecDimBitWidth <= targetBitWidth;
}
+static FailureOr<Attribute>
+linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
+ VectorType resType, Attribute value) {
+ if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
+ if (resType.isScalable() && !isa<SplatElementsAttr>(value))
+ return rewriter.notifyMatchFailure(
+ loc,
+ "Cannot linearize a constant scalable vector that's not a splat");
+ // Dense values keep their elements and are reshaped to `resType`.
+ return dstElementsAttr.reshape(resType);
+ }
+ // `ub.poison` attributes are returned unchanged for the new result type.
+ if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
+ return poisonAttr;
+ // Any other constant attribute kind is reported as a match failure.
+ return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+}
+
namespace {
-struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
- using OpConversionPattern::OpConversionPattern;
- LinearizeConstant(
+struct LinearizeConstantLike final
+ : OpTraitConversionPattern<OpTrait::ConstantLike> {
+ using OpTraitConversionPattern::OpTraitConversionPattern;
+
+ LinearizeConstantLike(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
+ : OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
- matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Location loc = constOp.getLoc();
+ Location loc = op->getLoc();
+ if (op->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(loc, "expected 1 result");
+
+ const TypeConverter &converter = *getTypeConverter();
auto resType =
- getTypeConverter()->convertType<VectorType>(constOp.getType());
+ converter.convertType<VectorType>(op->getResult(0).getType());
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
- return rewriter.notifyMatchFailure(
- loc,
- "Cannot linearize a constant scalable vector that's not a splat");
-
- if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
- auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
- if (!dstElementsAttr)
- return rewriter.notifyMatchFailure(loc, "unsupported attr type");
- dstElementsAttr = dstElementsAttr.reshape(resType);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
- dstElementsAttr);
+ StringAttr attrName = rewriter.getStringAttr("value");
+ Attribute value = op->getAttr(attrName);
+ if (!value)
+ return rewriter.notifyMatchFailure(loc, "no 'value' attr");
+
+ FailureOr<Attribute> newValue =
+ linearizeConstAttr(loc, rewriter, resType, value);
+ if (failed(newValue))
+ return failure();
+
+ FailureOr<Operation *> convertResult =
+ convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
+ if (failed(convertResult))
+ return failure();
+
+ Operation *newOp = *convertResult;
+ newOp->setAttr(attrName, *newValue);
+ rewriter.replaceOp(op, newOp);
return success();
}
@@ -525,7 +557,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+ if ((isa<vector::BitCastOp>(op) ||
+ op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -534,9 +567,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns
- .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ patterns.add<LinearizeConstantLike, LinearizeVectorizable,
+ LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+ targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 8e5ddbfffcdd9..9052c6440e6ac 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// -----
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison() -> vector<2x2xf32> {
+ // DEFAULT: %[[POISON:.*]] = ub.poison : vector<4xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
+ // With a 128-bit target width the poison is still linearized to vector<4xf32>.
+ // BW-128: %[[POISON:.*]] = ub.poison : vector<4xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
+ // With a 0-bit target width the poison keeps its original 2-D type.
+ // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+ %0 = ub.poison : vector<2x2xf32>
+ // ALL: return %[[RES]] : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+
// ALL-LABEL: test_partial_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
More information about the Mlir-commits mailing list