[Mlir-commits] [mlir] [mlir][Vector] Fold vector.constant_mask to splat (PR #146724)
Fabian Mora
llvmlistbot at llvm.org
Wed Jul 2 08:36:42 PDT 2025
================
@@ -6594,6 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
   return true;
 }
+static Attribute createBoolSplat(ShapedType ty, bool x) {
+  return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
+}
+
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+  ArrayRef<int64_t> bounds = getMaskDimSizes();
+  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+  // Check the corner case of 0-D vectors first.
+  if (vectorSizes.size() == 0) {
+    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+    return createBoolSplat(getVectorType(), bounds[0] == 1);
+  }
+  // Fold vector.constant_mask to splat if possible.
+  if (bounds == vectorSizes)
+    return createBoolSplat(getVectorType(), true);
+  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+    return createBoolSplat(getVectorType(), false);
+  return {};
----------------
fabianmcg wrote:
```suggestion
return OpFoldResult();
```
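For readers following the review, the decision logic of the proposed fold can be modeled outside MLIR. This is a minimal, hypothetical sketch, not the PR's code: `foldConstantMaskToSplat` is an invented name, shapes are plain `std::vector<int64_t>` (scalable dimensions are ignored), and `std::nullopt` stands in for the null `OpFoldResult` that both `return {};` and the suggested `return OpFoldResult();` produce.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the decision made in ConstantMaskOp::fold.
// Returns the splat value when the mask is uniform (all-true or all-false),
// and std::nullopt when no fold applies (the analogue of a null OpFoldResult).
std::optional<bool> foldConstantMaskToSplat(const std::vector<int64_t> &bounds,
                                            const std::vector<int64_t> &shape) {
  // 0-D vector: a single mask size of 0 (all false) or 1 (all true).
  if (shape.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return bounds[0] == 1;
  }
  // Every dimension fully covered: all lanes are set.
  if (bounds == shape)
    return true;
  // Every mask size is zero: no lanes are set.
  if (std::all_of(bounds.begin(), bounds.end(),
                  [](int64_t b) { return b == 0; }))
    return false;
  // Partial mask: leave the op as is.
  return std::nullopt;
}
```

A partial mask such as `[2, 4]` on a `4x4` vector yields no splat in this model, so the op would be left in place.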
https://github.com/llvm/llvm-project/pull/146724