[Mlir-commits] [mlir] [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (PR #66638)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Sep 20 01:46:11 PDT 2023
================
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
- DenseIntElementsAttr::get(
- VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
- ArrayRef<bool>{value}));
+ DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+ value));
return success();
}
- // Scalable constant masks can only be lowered for the "none set" case.
- if (cast<VectorType>(dstType).isScalable()) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, DenseElementsAttr::get(dstType, false));
- return success();
- }
-
- int64_t trueDim = std::min(dstType.getDimSize(0),
- cast<IntegerAttr>(dimSizes[0]).getInt());
+ int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
if (rank == 1) {
- // Express constant 1-D case in explicit vector form:
- // [T,..,T,F,..,F].
- SmallVector<bool> values(dstType.getDimSize(0));
- for (int64_t d = 0; d < trueDim; d++)
- values[d] = true;
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, dstType, rewriter.getBoolVectorAttr(values));
+ if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+ // Use constant splat for 'all set' or 'none set' dims.
+ // This produces correct code for scalable dimensions.
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+ } else {
+ // Express constant 1-D case in explicit vector form:
+ // [T,..,T,F,..,F].
+ SmallVector<bool> values(dstType.getDimSize(0));
+ for (int64_t d = 0; d < trueDimSize; d++)
+ values[d] = true;
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, dstType, rewriter.getBoolVectorAttr(values));
+ }
return success();
}
- VectorType lowType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- SmallVector<int64_t> newDimSizes;
- for (int64_t r = 1; r < rank; r++)
- newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
+ if (dstType.getScalableDims().front())
+ return rewriter.notifyMatchFailure(
+ op, "Cannot unroll leading scalable dim in dstType");
----------------
banach-space wrote:
We should be able to test for this in invalid.mlir, right?
https://github.com/llvm/llvm-project/pull/66638
More information about the Mlir-commits
mailing list