[Mlir-commits] [mlir] caf89c0 - [mlir][Vector] Support 0-D vectors in `ConstantMaskOp`

Nicolas Vasilache llvmlistbot at llvm.org
Mon Dec 6 00:05:31 PST 2021


Author: Michal Terepeta
Date: 2021-12-06T08:03:04Z
New Revision: caf89c0db679f79ca6c9a75c5acc6151dd380f26

URL: https://github.com/llvm/llvm-project/commit/caf89c0db679f79ca6c9a75c5acc6151dd380f26
DIFF: https://github.com/llvm/llvm-project/commit/caf89c0db679f79ca6c9a75c5acc6151dd380f26.diff

LOG: [mlir][Vector] Support 0-D vectors in `ConstantMaskOp`

To support creating a mask that holds a single `true` or a single `false`
value, I had to relax the verifier restriction that the length of the attribute
array always equals the rank of the result vector. In other words, we now allow:

- `vector.constant_mask [0] : vector<i1>` which gets lowered to
  `arith.constant dense<false> : vector<i1>`
- `vector.constant_mask [1] : vector<i1>` which gets lowered to
  `arith.constant dense<true> : vector<i1>`

(the attribute list for the 0-D case must be a singleton containing
either `0` or `1`)

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D115023

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 74edc5fe5f9b9..14afb47004f43 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2111,7 +2111,7 @@ def Vector_TypeCastOp :
 def Vector_ConstantMaskOp :
   Vector_Op<"constant_mask", [NoSideEffect]>,
     Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
-    Results<(outs VectorOf<[I1]>)> {
+    Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a constant vector mask";
   let description = [{
     Creates and returns a vector mask where elements of the result vector

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 76fcb97898e54..1b18b19df7e82 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3924,8 +3924,19 @@ void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(ConstantMaskOp &op) {
-  // Verify that array attr size matches the rank of the vector result.
   auto resultType = op.getResult().getType().cast<VectorType>();
+  // Check the corner case of 0-D vectors first.
+  if (resultType.getRank() == 0) {
+    if (op.mask_dim_sizes().size() != 1)
+      return op->emitError("array attr must have length 1 for 0-D vectors");
+    auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
+    if (dim != 0 && dim != 1)
+      return op->emitError(
+          "mask dim size must be either 0 or 1 for 0-D vectors");
+    return success();
+  }
+
+  // Verify that array attr size matches the rank of the vector result.
   if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
     return op.emitOpError(
         "must specify array attr of size equal vector result rank");

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 876f8aeb219cb..6d50838cc9ead 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -960,7 +960,20 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
     auto dstType = op.getType();
     auto eltType = dstType.getElementType();
     auto dimSizes = op.mask_dim_sizes();
-    int64_t rank = dimSizes.size();
+    int64_t rank = dstType.getRank();
+
+    if (rank == 0) {
+      assert(dimSizes.size() == 1 &&
+             "Expected exactly one dim size for a 0-D vector");
+      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, dstType,
+          DenseIntElementsAttr::get(
+              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
+              ArrayRef<bool>{value}));
+      return success();
+    }
+
     int64_t trueDim = std::min(dstType.getDimSize(0),
                                dimSizes[0].cast<IntegerAttr>().getInt());
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 0c21cf9eecaa1..ee3c7567dd3aa 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1396,6 +1396,26 @@ func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32>
 
 // -----
 
+func @genbool_0d_f() -> vector<i1> {
+  %0 = vector.constant_mask [0] : vector<i1>
+  return %0 : vector<i1>
+}
+// CHECK-LABEL: func @genbool_0d_f
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<i1>
+// CHECK: return %[[VAL_0]] : vector<i1>
+
+// -----
+
+func @genbool_0d_t() -> vector<i1> {
+  %0 = vector.constant_mask [1] : vector<i1>
+  return %0 : vector<i1>
+}
+// CHECK-LABEL: func @genbool_0d_t
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<i1>
+// CHECK: return %[[VAL_0]] : vector<i1>
+
+// -----
+
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
   return %0 : vector<8xi1>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index fb69798ee1054..63e30cf8912e7 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -882,6 +882,20 @@ func @create_mask() {
 }
 
 
+// -----
+
+func @constant_mask_0d_no_attr() {
+  // expected-error@+1 {{array attr must have length 1 for 0-D vectors}}
+  %0 = vector.constant_mask [] : vector<i1>
+}
+
+// -----
+
+func @constant_mask_0d_bad_attr() {
+  // expected-error@+1 {{mask dim size must be either 0 or 1 for 0-D vectors}}
+  %0 = vector.constant_mask [2] : vector<i1>
+}
+
 // -----
 
 func @constant_mask() {

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2bd0e13f05e4e..43c5abdd9ef8c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -376,6 +376,15 @@ func @create_vector_mask() {
   return
 }
 
+// CHECK-LABEL: @constant_vector_mask_0d
+func @constant_vector_mask_0d() {
+  // CHECK: vector.constant_mask [0] : vector<i1>
+  %0 = vector.constant_mask [0] : vector<i1>
+  // CHECK: vector.constant_mask [1] : vector<i1>
+  %1 = vector.constant_mask [1] : vector<i1>
+  return
+}
+
 // CHECK-LABEL: @constant_vector_mask
 func @constant_vector_mask() {
   // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 74bbadaac520b..a0d4c3d82974c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -68,6 +68,16 @@ func @bitcast_0d() {
 }
 
 
+func @constant_mask_0d() {
+  %1 = vector.constant_mask [0] : vector<i1>
+  // CHECK: ( 0 )
+  vector.print %1: vector<i1>
+  %2 = vector.constant_mask [1] : vector<i1>
+  // CHECK: ( 1 )
+  vector.print %2: vector<i1>
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -78,10 +88,13 @@ func @entry() {
   call  @print_vector_0d(%3) : (vector<f32>) -> ()
 
   %4 = arith.constant 42.0 : f32
+
+  // Warning: these must be called in their textual order of definition in the
+  // file to not mess up FileCheck.
   call  @splat_0d(%4) : (f32) -> ()
   call  @broadcast_0d(%4) : (f32) -> ()
-
   call  @bitcast_0d() : () -> ()
+  call  @constant_mask_0d() : () -> ()
 
   return
 }


        


More information about the Mlir-commits mailing list