[Mlir-commits] [mlir] [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (PR #66638)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 18 05:37:01 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This extends `vector.constant_mask` so that mask dim sizes that correspond to a scalable dimension are treated as if they are implicitly multiplied by vscale. Currently, for scalable dimensions, the mask dim size is limited to either 0 (none set) or the base size of the dimension, i.e. the dim size divided by vscale (all set). This allows constant masks to represent all-true and all-false scalable masks (and some variations):
```
// All true scalable mask
%mask = vector.constant_mask [8] : vector<[8]xi1>
// All false scalable mask
%mask = vector.constant_mask [0] : vector<[8]xi1>
// First two scalable rows
%mask = vector.constant_mask [2,4] : vector<4x[4]xi1>
```
---
Full diff: https://github.com/llvm/llvm-project/pull/66638.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+4-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+9-12)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+24-27)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+41-2)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1)
- (modified) mlir/test/Dialect/Vector/ops.mlir (+5-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 28b5864914f6920..64fbd722a4f02c3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp :
define a hyper-rectangular region within which elements values are set to 1
(otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
be non-negative and not greater than the size of the corresponding vector
- dimension (as opposed to vector.create_mask which allows this).
+ dimension (as opposed to vector.create_mask which allows this). Sizes that
+ correspond to scalable dimensions are implicitly multiplied by vscale,
+ though currently only zero (none set) or the size of the dim/vscale
+ (all set) are supported.
Example:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a8ad05f7bc1cabf..3c68cb26fb55a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() {
// Verify that each array attr element is in bounds of corresponding vector
// result dimension size.
auto resultShape = resultType.getShape();
+ auto resultScalableDims = resultType.getScalableDims();
SmallVector<int64_t, 4> maskDimSizes;
- for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
- int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
- if (attrValue < 0 || attrValue > resultShape[it.index()])
+ for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
+ int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+ if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
- maskDimSizes.push_back(attrValue);
+ if (resultScalableDims[index] && maskDimSize != 0 &&
+ maskDimSize != resultShape[index])
+ return emitOpError(
+ "only supports 'none set' or 'all set' scalable dimensions");
+ maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
@@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() {
if (anyZeros && !allZeros)
return emitOpError("expected all mask dim sizes to be zeros, "
"as a result of conjunction with zero mask dim");
- // Verify that if the mask type is scalable, dimensions should be zero because
- // constant scalable masks can only be defined for the "none set" or "all set"
- // cases, and there is no VLA way to define an "all set" case for
- // `vector.constant_mask`. In the future, a convention could be established
- // to decide if a specific dimension value could be considered as "all set".
- if (resultType.isScalable() &&
- llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
- return emitOpError("expected mask dim sizes for scalable masks to be 0");
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 9a828ec0b845e4a..418dc6786a76ed4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto dstType = op.getType();
- auto eltType = dstType.getElementType();
auto dimSizes = op.getMaskDimSizes();
int64_t rank = dstType.getRank();
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
- DenseIntElementsAttr::get(
- VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
- ArrayRef<bool>{value}));
+ DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+ value));
return success();
}
- // Scalable constant masks can only be lowered for the "none set" case.
- if (cast<VectorType>(dstType).isScalable()) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, DenseElementsAttr::get(dstType, false));
- return success();
- }
-
- int64_t trueDim = std::min(dstType.getDimSize(0),
- cast<IntegerAttr>(dimSizes[0]).getInt());
+ int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
if (rank == 1) {
- // Express constant 1-D case in explicit vector form:
- // [T,..,T,F,..,F].
- SmallVector<bool> values(dstType.getDimSize(0));
- for (int64_t d = 0; d < trueDim; d++)
- values[d] = true;
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, dstType, rewriter.getBoolVectorAttr(values));
+ if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+ // Use constant splat for 'all set' or 'none set' dims.
+ // This produces correct code for scalable dimensions.
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+ } else {
+ // Express constant 1-D case in explicit vector form:
+ // [T,..,T,F,..,F].
+ SmallVector<bool> values(dstType.getDimSize(0));
+ for (int64_t d = 0; d < trueDimSize; d++)
+ values[d] = true;
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, dstType, rewriter.getBoolVectorAttr(values));
+ }
return success();
}
- VectorType lowType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- SmallVector<int64_t> newDimSizes;
- for (int64_t r = 1; r < rank; r++)
- newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
+ if (dstType.getScalableDims().front())
+ return rewriter.notifyMatchFailure(
+ op, "Cannot unroll leading scalable dim in dstType");
+
+ VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
- loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+ loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
- for (int64_t d = 0; d < trueDim; d++)
+ for (int64_t d = 0; d < trueDimSize; d++)
result =
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7b29ef44c1f2f2e..27bd5b5ea0eed7b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1819,16 +1819,55 @@ func.func @genbool_1d() -> vector<8xi1> {
// -----
-func.func @genbool_1d_scalable() -> vector<[8]xi1> {
+func.func @genbool_1d_scalable_pfalse() -> vector<[8]xi1> {
%0 = vector.constant_mask [0] : vector<[8]xi1>
return %0 : vector<[8]xi1>
}
-// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK-LABEL: func @genbool_1d_scalable_pfalse
// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
// CHECK: return %[[VAL_0]] : vector<[8]xi1>
// -----
+func.func @genbool_1d_scalable_ptrue() -> vector<[8]xi1> {
+ %0 = vector.constant_mask [8] : vector<[8]xi1>
+ return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable_ptrue
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
+func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+ %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
+ return %0 : vector<4x[4]xi1>
+}
+// CHECK-LABEL: func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
+// CHECK: return %[[VAL_5]] : vector<4x[4]xi1>
+// CHECK: }
+
+// -----
+
+/// Currently, this is not supported as generating the mask would require
+/// unrolling the leading scalable dimension at compile time.
+func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+ %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+ return %0 : vector<[4]x4xi1>
+}
+// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+// CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+// CHECK: return %[[VAL_0]] : vector<[4]x4xi1>
+// CHECK: }
+
+// -----
+
func.func @genbool_2d() -> vector<4x4xi1> {
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
return %v: vector<4x4xi1>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 50119c2b4a36261..26772b929493585 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() {
// -----
func.func @constant_mask_scalable_non_zero_dim_size() {
- // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
+ // expected-error@+1 {{only supports 'none set' or 'all set' scalable dimensions}}
%0 = vector.constant_mask [2] : vector<[8]xi1>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4ea4379372e8380..96c56946cd1cfff 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -448,6 +448,10 @@ func.func @constant_vector_mask() {
%0 = vector.constant_mask [3, 2] : vector<4x3xi1>
// CHECK: vector.constant_mask [0] : vector<[4]xi1>
%1 = vector.constant_mask [0] : vector<[4]xi1>
+ // CHECK: vector.constant_mask [4] : vector<[4]xi1>
+ %2 = vector.constant_mask [4] : vector<[4]xi1>
+ // CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1>
+ %3 = vector.constant_mask [1, 4] : vector<2x[4]xi1>
return
}
@@ -1003,7 +1007,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
%C: vector<3x[8]xf32>,
%M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
// CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
- %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
+ %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
: vector<3x[8]x4xi1> -> vector<3x[8]xf32>
return %0 : vector<3x[8]xf32>
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/66638
More information about the Mlir-commits
mailing list