[Mlir-commits] [mlir] [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (PR #66638)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Sep 20 04:11:53 PDT 2023
================
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
- DenseIntElementsAttr::get(
- VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
- ArrayRef<bool>{value}));
+ DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+ value));
return success();
}
- // Scalable constant masks can only be lowered for the "none set" case.
- if (cast<VectorType>(dstType).isScalable()) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, DenseElementsAttr::get(dstType, false));
- return success();
- }
-
- int64_t trueDim = std::min(dstType.getDimSize(0),
- cast<IntegerAttr>(dimSizes[0]).getInt());
+ int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
if (rank == 1) {
- // Express constant 1-D case in explicit vector form:
- // [T,..,T,F,..,F].
- SmallVector<bool> values(dstType.getDimSize(0));
- for (int64_t d = 0; d < trueDim; d++)
- values[d] = true;
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, dstType, rewriter.getBoolVectorAttr(values));
+ if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+ // Use constant splat for 'all set' or 'none set' dims.
+ // This produces correct code for scalable dimensions.
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+ } else {
+ // Express constant 1-D case in explicit vector form:
+ // [T,..,T,F,..,F].
----------------
MacDue wrote:
Done
https://github.com/llvm/llvm-project/pull/66638
More information about the Mlir-commits mailing list