[Mlir-commits] [mlir] 408553d - [mlir][Vector] Support 0-D vectors in `CreateMaskOp`

Nicolas Vasilache llvmlistbot at llvm.org
Sun Dec 12 05:35:58 PST 2021


Author: Nicolas Vasilache
Date: 2021-12-12T13:32:29Z
New Revision: 408553dd96792929b2468bcd2d8e4764ae7c2c9e

URL: https://github.com/llvm/llvm-project/commit/408553dd96792929b2468bcd2d8e4764ae7c2c9e
DIFF: https://github.com/llvm/llvm-project/commit/408553dd96792929b2468bcd2d8e4764ae7c2c9e.diff

LOG: [mlir][Vector] Support 0-D vectors in `CreateMaskOp`

The 0-D case gets lowered in almost the same way that the 1-D case does
in VectorCreateMaskOpConversion. I also had to slightly update the
verifier for the op to always require exactly 1 operand in the 0-D case.

Depends On D115220

Reviewed by: ftynse

Differential revision: https://reviews.llvm.org/D115221

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 14afb47004f43..a05a377731665 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2149,7 +2149,8 @@ def Vector_ConstantMaskOp :
 
 def Vector_CreateMaskOp :
   Vector_Op<"create_mask", [NoSideEffect]>,
-    Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> {
+    Arguments<(ins Variadic<Index>:$operands)>,
+    Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a vector mask";
   let description = [{
     Creates and returns a vector mask where elements of the result vector

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9768483231c9d..27a0fc6d9dcbe 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3968,11 +3968,17 @@ static LogicalResult verify(ConstantMaskOp &op) {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(CreateMaskOp op) {
+  auto vectorType = op.getResult().getType().cast<VectorType>();
   // Verify that an operand was specified for each result vector each dimension.
-  if (op.getNumOperands() !=
-      op.getResult().getType().cast<VectorType>().getRank())
+  if (vectorType.getRank() == 0) {
+    if (op->getNumOperands() != 1)
+      return op.emitOpError(
+          "must specify exactly one operand for 0-D create_mask");
+  } else if (op.getNumOperands() !=
+             op.getResult().getType().cast<VectorType>().getRank()) {
     return op.emitOpError(
         "must specify an operand for each result vector dimension");
+  }
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 30335b70f4a22..008db8139df10 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -677,16 +677,17 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
 
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
     auto dstType = op.getResult().getType().cast<VectorType>();
+    int64_t rank = dstType.getRank();
+    if (rank <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "0-D and 1-D vectors are handled separately");
+
+    auto loc = op.getLoc();
     auto eltType = dstType.getElementType();
     int64_t dim = dstType.getDimSize(0);
-    int64_t rank = dstType.getRank();
     Value idx = op.getOperand(0);
 
-    if (rank == 1)
-      return failure(); // leave for lowering
-
     VectorType lowType =
         VectorType::get(dstType.getShape().drop_front(), eltType);
     Value trueVal = rewriter.create<vector::CreateMaskOp>(
@@ -2717,6 +2718,8 @@ static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
+// If `dim == 0` then the result will be a 0-D vector.
+//
 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
 //       much more compact, IR for this operation, but LLVM eventually
 //       generates more elaborate instructions for this intrinsic since it
@@ -2728,19 +2731,23 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   // If we can assume all indices fit in 32-bit, we perform the vector
   // comparison in 32-bit to get a higher degree of SIMD parallelism.
   // Otherwise we perform the vector comparison using 64-bit indices.
-  Value indices;
-  Type idxType;
-  if (indexOptimizations) {
-    indices = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
-    idxType = rewriter.getI32Type();
+  Type idxType =
+      indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+  DenseIntElementsAttr indicesAttr;
+  if (dim == 0 && indexOptimizations) {
+    indicesAttr = DenseIntElementsAttr::get(
+        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
+  } else if (dim == 0) {
+    indicesAttr = DenseIntElementsAttr::get(
+        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
+  } else if (indexOptimizations) {
+    indicesAttr = rewriter.getI32VectorAttr(
+        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
   } else {
-    indices = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI64VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
-    idxType = rewriter.getI64Type();
+    indicesAttr = rewriter.getI64VectorAttr(
+        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
   }
+  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
   // Add in an offset if requested.
   if (off) {
     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
@@ -2806,7 +2813,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
   const bool indexOptimizations;
 };
 
-/// Conversion pattern for a vector.create_mask (1-D only).
+/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
 class VectorCreateMaskOpConversion
     : public OpRewritePattern<vector::CreateMaskOp> {
 public:
@@ -2819,13 +2826,13 @@ class VectorCreateMaskOpConversion
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     int64_t rank = dstType.getRank();
-    if (rank == 1) {
-      rewriter.replaceOp(
-          op, buildVectorComparison(rewriter, op, indexOptimizations,
-                                    dstType.getDimSize(0), op.getOperand(0)));
-      return success();
-    }
-    return failure();
+    if (rank > 1)
+      return failure();
+    rewriter.replaceOp(
+        op, buildVectorComparison(rewriter, op, indexOptimizations,
+                                  rank == 0 ? 0 : dstType.getDimSize(0),
+                                  op.getOperand(0)));
+    return success();
   }
 
 private:

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ee3c7567dd3aa..ccb75b8a606a2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1442,6 +1442,35 @@ func @genbool_2d() -> vector<4x4xi1> {
 
 // -----
 
+func @create_mask_0d(%a : index) -> vector<i1> {
+  %v = vector.create_mask %a : vector<i1>
+  return %v: vector<i1>
+}
+
+// CHECK-LABEL: func @create_mask_0d
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK:  %[[indices:.*]] = arith.constant dense<0> : vector<i32>
+// CHECK:  %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK:  %[[bounds:.*]] = splat %[[arg_i32]] : vector<i32>
+// CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<i32>
+// CHECK:  return %[[result]] : vector<i1>
+// -----
+
+func @create_mask_1d(%a : index) -> vector<4xi1> {
+  %v = vector.create_mask %a : vector<4xi1>
+  return %v: vector<4xi1>
+}
+
+// CHECK-LABEL: func @create_mask_1d
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK:  %[[indices:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+// CHECK:  %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK:  %[[bounds:.*]] = splat %[[arg_i32]] : vector<4xi32>
+// CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
+// CHECK:  return %[[result]] : vector<4xi1>
+
+// -----
+
 func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
      : vector<16xf32> -> vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 195e720d2ad99..a384d42ef6112 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -874,6 +874,24 @@ func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<
 
 // -----
 
+func @create_mask_0d_no_operands() {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
+  %0 = vector.create_mask : vector<i1>
+}
+
+// -----
+
+func @create_mask_0d_many_operands() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
+  %0 = vector.create_mask %c1, %c2, %c3 : vector<i1>
+}
+
+// -----
+
 func @create_mask() {
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 67a4257fa35dc..572280f570737 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -93,6 +93,18 @@ func @arith_cmpi_0d(%smaller : vector<i32>, %bigger : vector<i32>) {
   return
 }
 
+func @create_mask_0d(%zero : index, %one : index) {
+  %zero_mask = vector.create_mask %zero : vector<i1>
+  // CHECK: ( 0 )
+  vector.print %zero_mask : vector<i1>
+
+  %one_mask = vector.create_mask %one : vector<i1>
+  // CHECK: ( 1 )
+  vector.print %one_mask : vector<i1>
+
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -115,5 +127,9 @@ func @entry() {
   %bigger = arith.constant dense<4242> : vector<i32>
   call  @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
 
+  %zero_idx = arith.constant 0 : index
+  %one_idx = arith.constant 1 : index
+  call  @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
+
   return
 }


        


More information about the Mlir-commits mailing list