[Mlir-commits] [mlir] caf89c0 - [mlir][Vector] Support 0-D vectors in `ConstantMaskOp`
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Dec 6 00:05:31 PST 2021
Author: Michal Terepeta
Date: 2021-12-06T08:03:04Z
New Revision: caf89c0db679f79ca6c9a75c5acc6151dd380f26
URL: https://github.com/llvm/llvm-project/commit/caf89c0db679f79ca6c9a75c5acc6151dd380f26
DIFF: https://github.com/llvm/llvm-project/commit/caf89c0db679f79ca6c9a75c5acc6151dd380f26.diff
LOG: [mlir][Vector] Support 0-D vectors in `ConstantMaskOp`
To support creating a 0-D mask holding either a single `true` or a single `false` value,
I had to relax the verifier restriction that the rank must always equal
the length of the attribute array. In other words, we now allow:
- `vector.constant_mask [0] : vector<i1>` which gets lowered to
`arith.constant dense<false> : vector<i1>`
- `vector.constant_mask [1] : vector<i1>` which gets lowered to
`arith.constant dense<true> : vector<i1>`
(the attribute list for the 0-D case must be a singleton containing
either `0` or `1`)
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D115023
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 74edc5fe5f9b9..14afb47004f43 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2111,7 +2111,7 @@ def Vector_TypeCastOp :
def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [NoSideEffect]>,
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
- Results<(outs VectorOf<[I1]>)> {
+ Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
let description = [{
Creates and returns a vector mask where elements of the result vector
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 76fcb97898e54..1b18b19df7e82 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3924,8 +3924,19 @@ void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
//===----------------------------------------------------------------------===//
static LogicalResult verify(ConstantMaskOp &op) {
- // Verify that array attr size matches the rank of the vector result.
auto resultType = op.getResult().getType().cast<VectorType>();
+ // Check the corner case of 0-D vectors first.
+ if (resultType.getRank() == 0) {
+ if (op.mask_dim_sizes().size() != 1)
+ return op->emitError("array attr must have length 1 for 0-D vectors");
+ auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
+ if (dim != 0 && dim != 1)
+ return op->emitError(
+ "mask dim size must be either 0 or 1 for 0-D vectors");
+ return success();
+ }
+
+ // Verify that array attr size matches the rank of the vector result.
if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
return op.emitOpError(
"must specify array attr of size equal vector result rank");
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 876f8aeb219cb..6d50838cc9ead 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -960,7 +960,20 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
auto dstType = op.getType();
auto eltType = dstType.getElementType();
auto dimSizes = op.mask_dim_sizes();
- int64_t rank = dimSizes.size();
+ int64_t rank = dstType.getRank();
+
+ if (rank == 0) {
+ assert(dimSizes.size() == 1 &&
+ "Expected exactly one dim size for a 0-D vector");
+ bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, dstType,
+ DenseIntElementsAttr::get(
+ VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
+ ArrayRef<bool>{value}));
+ return success();
+ }
+
int64_t trueDim = std::min(dstType.getDimSize(0),
dimSizes[0].cast<IntegerAttr>().getInt());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 0c21cf9eecaa1..ee3c7567dd3aa 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1396,6 +1396,26 @@ func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32>
// -----
+func @genbool_0d_f() -> vector<i1> {
+ %0 = vector.constant_mask [0] : vector<i1>
+ return %0 : vector<i1>
+}
+// CHECK-LABEL: func @genbool_0d_f
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<i1>
+// CHECK: return %[[VAL_0]] : vector<i1>
+
+// -----
+
+func @genbool_0d_t() -> vector<i1> {
+ %0 = vector.constant_mask [1] : vector<i1>
+ return %0 : vector<i1>
+}
+// CHECK-LABEL: func @genbool_0d_t
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<i1>
+// CHECK: return %[[VAL_0]] : vector<i1>
+
+// -----
+
func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index fb69798ee1054..63e30cf8912e7 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -882,6 +882,20 @@ func @create_mask() {
}
+// -----
+
+func @constant_mask_0d_no_attr() {
+ // expected-error at +1 {{array attr must have length 1 for 0-D vectors}}
+ %0 = vector.constant_mask [] : vector<i1>
+}
+
+// -----
+
+func @constant_mask_0d_bad_attr() {
+ // expected-error at +1 {{mask dim size must be either 0 or 1 for 0-D vectors}}
+ %0 = vector.constant_mask [2] : vector<i1>
+}
+
// -----
func @constant_mask() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2bd0e13f05e4e..43c5abdd9ef8c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -376,6 +376,15 @@ func @create_vector_mask() {
return
}
+// CHECK-LABEL: @constant_vector_mask_0d
+func @constant_vector_mask_0d() {
+ // CHECK: vector.constant_mask [0] : vector<i1>
+ %0 = vector.constant_mask [0] : vector<i1>
+ // CHECK: vector.constant_mask [1] : vector<i1>
+ %1 = vector.constant_mask [1] : vector<i1>
+ return
+}
+
// CHECK-LABEL: @constant_vector_mask
func @constant_vector_mask() {
// CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 74bbadaac520b..a0d4c3d82974c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -68,6 +68,16 @@ func @bitcast_0d() {
}
+func @constant_mask_0d() {
+ %1 = vector.constant_mask [0] : vector<i1>
+ // CHECK: ( 0 )
+ vector.print %1: vector<i1>
+ %2 = vector.constant_mask [1] : vector<i1>
+ // CHECK: ( 1 )
+ vector.print %2: vector<i1>
+ return
+}
+
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@@ -78,10 +88,13 @@ func @entry() {
call @print_vector_0d(%3) : (vector<f32>) -> ()
%4 = arith.constant 42.0 : f32
+
+ // Warning: these must be called in their textual order of definition in the
+ // file to not mess up FileCheck.
call @splat_0d(%4) : (f32) -> ()
call @broadcast_0d(%4) : (f32) -> ()
-
call @bitcast_0d() : () -> ()
+ call @constant_mask_0d() : () -> ()
return
}
More information about the Mlir-commits
mailing list