[Mlir-commits] [mlir] [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (PR #66638)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 05:37:01 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This extends `vector.constant_mask` so that mask dim sizes corresponding to a scalable dimension are treated as if they were implicitly multiplied by vscale. Currently, the mask size for a scalable dimension is limited to either 0 (none set) or the dimension's base size, i.e. its runtime size divided by vscale (all set). This allows constant masks to represent all-true and all-false scalable masks (and some variations):

```
// All true scalable mask
%mask = vector.constant_mask [8] : vector<[8]xi1>

// All false scalable mask
%mask = vector.constant_mask [0] : vector<[8]xi1>

// First two rows set (each row is a full scalable vector<[4]xi1>)
%mask = vector.constant_mask [2,4] : vector<4x[4]xi1>
```
---
Full diff: https://github.com/llvm/llvm-project/pull/66638.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+4-1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+9-12) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+24-27) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+41-2) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+5-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 28b5864914f6920..64fbd722a4f02c3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp :
     define a hyper-rectangular region within which elements values are set to 1
     (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
     be non-negative and not greater than the size of the corresponding vector
-    dimension (as opposed to vector.create_mask which allows this).
+    dimension (as opposed to vector.create_mask which allows this). Sizes that
+    correspond to scalable dimensions are implicitly multiplied by vscale,
+    though currently only zero (none set) or the size of the dim/vscale
+    (all set) are supported.
 
     Example:
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a8ad05f7bc1cabf..3c68cb26fb55a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() {
   // Verify that each array attr element is in bounds of corresponding vector
   // result dimension size.
   auto resultShape = resultType.getShape();
+  auto resultScalableDims = resultType.getScalableDims();
   SmallVector<int64_t, 4> maskDimSizes;
-  for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
-    int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
-    if (attrValue < 0 || attrValue > resultShape[it.index()])
+  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
+    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+    if (maskDimSize < 0 || maskDimSize > resultShape[index])
       return emitOpError(
           "array attr of size out of bounds of vector result dimension size");
-    maskDimSizes.push_back(attrValue);
+    if (resultScalableDims[index] && maskDimSize != 0 &&
+        maskDimSize != resultShape[index])
+      return emitOpError(
+          "only supports 'none set' or 'all set' scalable dimensions");
+    maskDimSizes.push_back(maskDimSize);
   }
   // Verify that if one mask dim size is zero, they all should be zero (because
   // the mask region is a conjunction of each mask dimension interval).
@@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() {
   if (anyZeros && !allZeros)
     return emitOpError("expected all mask dim sizes to be zeros, "
                        "as a result of conjunction with zero mask dim");
-  // Verify that if the mask type is scalable, dimensions should be zero because
-  // constant scalable masks can only be defined for the "none set" or "all set"
-  // cases, and there is no VLA way to define an "all set" case for
-  // `vector.constant_mask`. In the future, a convention could be established
-  // to decide if a specific dimension value could be considered as "all set".
-  if (resultType.isScalable() &&
-      llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
-    return emitOpError("expected mask dim sizes for scalable masks to be 0");
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 9a828ec0b845e4a..418dc6786a76ed4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto dstType = op.getType();
-    auto eltType = dstType.getElementType();
     auto dimSizes = op.getMaskDimSizes();
     int64_t rank = dstType.getRank();
 
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
           op, dstType,
-          DenseIntElementsAttr::get(
-              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
-              ArrayRef<bool>{value}));
+          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+                                    value));
       return success();
     }
 
-    // Scalable constant masks can only be lowered for the "none set" case.
-    if (cast<VectorType>(dstType).isScalable()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, DenseElementsAttr::get(dstType, false));
-      return success();
-    }
-
-    int64_t trueDim = std::min(dstType.getDimSize(0),
-                               cast<IntegerAttr>(dimSizes[0]).getInt());
+    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
 
     if (rank == 1) {
-      // Express constant 1-D case in explicit vector form:
-      //   [T,..,T,F,..,F].
-      SmallVector<bool> values(dstType.getDimSize(0));
-      for (int64_t d = 0; d < trueDim; d++)
-        values[d] = true;
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, dstType, rewriter.getBoolVectorAttr(values));
+      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+        // Use constant splat for 'all set' or 'none set' dims.
+        // This produces correct code for scalable dimensions.
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+      } else {
+        // Express constant 1-D case in explicit vector form:
+        //   [T,..,T,F,..,F].
+        SmallVector<bool> values(dstType.getDimSize(0));
+        for (int64_t d = 0; d < trueDimSize; d++)
+          values[d] = true;
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, dstType, rewriter.getBoolVectorAttr(values));
+      }
       return success();
     }
 
-    VectorType lowType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
-    SmallVector<int64_t> newDimSizes;
-    for (int64_t r = 1; r < rank; r++)
-      newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
+    if (dstType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          op, "Cannot unroll leading scalable dim in dstType");
+
+    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
-        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+        loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
-    for (int64_t d = 0; d < trueDim; d++)
+    for (int64_t d = 0; d < trueDimSize; d++)
       result =
           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
     rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7b29ef44c1f2f2e..27bd5b5ea0eed7b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1819,16 +1819,55 @@ func.func @genbool_1d() -> vector<8xi1> {
 
 // -----
 
-func.func @genbool_1d_scalable() -> vector<[8]xi1> {
+func.func @genbool_1d_scalable_pfalse() -> vector<[8]xi1> {
   %0 = vector.constant_mask [0] : vector<[8]xi1>
   return %0 : vector<[8]xi1>
 }
-// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK-LABEL: func @genbool_1d_scalable_pfalse
 // CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
 // CHECK: return %[[VAL_0]] : vector<[8]xi1>
 
 // -----
 
+func.func @genbool_1d_scalable_ptrue() -> vector<[8]xi1> {
+  %0 = vector.constant_mask [8] : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable_ptrue
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
+func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+  %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
+  return %0 : vector<4x[4]xi1>
+}
+// CHECK-LABEL:   func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
+// CHECK:           %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
+// CHECK:           return %[[VAL_5]] : vector<4x[4]xi1>
+// CHECK:         }
+
+// -----
+
+/// Currently, this is not supported as generating the mask would require
+/// unrolling the leading scalable dimension at compile time.
+func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+  %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+  return %0 : vector<[4]x4xi1>
+}
+// CHECK-LABEL:   func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+// CHECK:           %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+// CHECK:           return %[[VAL_0]] : vector<[4]x4xi1>
+// CHECK:         }
+
+// -----
+
 func.func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
   return %v: vector<4x4xi1>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 50119c2b4a36261..26772b929493585 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() {
 // -----
 
 func.func @constant_mask_scalable_non_zero_dim_size() {
-  // expected-error at +1 {{expected mask dim sizes for scalable masks to be 0}}
+  // expected-error at +1 {{only supports 'none set' or 'all set' scalable dimensions}}
   %0 = vector.constant_mask [2] : vector<[8]xi1>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4ea4379372e8380..96c56946cd1cfff 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -448,6 +448,10 @@ func.func @constant_vector_mask() {
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
   // CHECK: vector.constant_mask [0] : vector<[4]xi1>
   %1 = vector.constant_mask [0] : vector<[4]xi1>
+  // CHECK: vector.constant_mask [4] : vector<[4]xi1>
+  %2 = vector.constant_mask [4] : vector<[4]xi1>
+  // CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1>
+  %3 = vector.constant_mask [1, 4] : vector<2x[4]xi1>
   return
 }
 
@@ -1003,7 +1007,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
                                     %C: vector<3x[8]xf32>,
                                     %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
  // CHECK:  vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
-  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } 
+  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
     : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
   return %0 : vector<3x[8]xf32>
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/66638


More information about the Mlir-commits mailing list