[Mlir-commits] [mlir] 2f11ce5 - [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (#66638)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 20 06:54:46 PDT 2023


Author: Benjamin Maxwell
Date: 2023-09-20T14:54:42+01:00
New Revision: 2f11ce5579b6f7db960a5b7a1e6c7b2b75dc94c8

URL: https://github.com/llvm/llvm-project/commit/2f11ce5579b6f7db960a5b7a1e6c7b2b75dc94c8
DIFF: https://github.com/llvm/llvm-project/commit/2f11ce5579b6f7db960a5b7a1e6c7b2b75dc94c8.diff

LOG: [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (#66638)

This extends `vector.constant_mask` so that mask dim sizes that
correspond to a scalable dimension are treated as if they're implicitly
multiplied by vscale. Currently, for scalable dimensions, the mask dim
size is limited to either 0 or the base size of the dim (i.e. the dim's
runtime size divided by vscale). This allows constant masks to represent
all-true and all-false scalable masks (and some variations):

```
// All true scalable mask
%mask = vector.constant_mask [8] : vector<[8]xi1>

// All false scalable mask
%mask = vector.constant_mask [0] : vector<[8]xi1>

// First two scalable rows
%mask = vector.constant_mask [2,4] : vector<4x[4]xi1>
```

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 28b5864914f6920..64fbd722a4f02c3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp :
     define a hyper-rectangular region within which elements values are set to 1
     (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
     be non-negative and not greater than the size of the corresponding vector
-    dimension (as opposed to vector.create_mask which allows this).
+    dimension (as opposed to vector.create_mask which allows this). Sizes that
+    correspond to scalable dimensions are implicitly multiplied by vscale,
+    though currently only zero (none set) or the size of the dim/vscale
+    (all set) are supported.
 
     Example:
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a8ad05f7bc1cabf..3c68cb26fb55a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() {
   // Verify that each array attr element is in bounds of corresponding vector
   // result dimension size.
   auto resultShape = resultType.getShape();
+  auto resultScalableDims = resultType.getScalableDims();
   SmallVector<int64_t, 4> maskDimSizes;
-  for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
-    int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
-    if (attrValue < 0 || attrValue > resultShape[it.index()])
+  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
+    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+    if (maskDimSize < 0 || maskDimSize > resultShape[index])
       return emitOpError(
           "array attr of size out of bounds of vector result dimension size");
-    maskDimSizes.push_back(attrValue);
+    if (resultScalableDims[index] && maskDimSize != 0 &&
+        maskDimSize != resultShape[index])
+      return emitOpError(
+          "only supports 'none set' or 'all set' scalable dimensions");
+    maskDimSizes.push_back(maskDimSize);
   }
   // Verify that if one mask dim size is zero, they all should be zero (because
   // the mask region is a conjunction of each mask dimension interval).
@@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() {
   if (anyZeros && !allZeros)
     return emitOpError("expected all mask dim sizes to be zeros, "
                        "as a result of conjunction with zero mask dim");
-  // Verify that if the mask type is scalable, dimensions should be zero because
-  // constant scalable masks can only be defined for the "none set" or "all set"
-  // cases, and there is no VLA way to define an "all set" case for
-  // `vector.constant_mask`. In the future, a convention could be established
-  // to decide if a specific dimension value could be considered as "all set".
-  if (resultType.isScalable() &&
-      llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
-    return emitOpError("expected mask dim sizes for scalable masks to be 0");
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 9a828ec0b845e4a..95b5ea011c82569 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto dstType = op.getType();
-    auto eltType = dstType.getElementType();
     auto dimSizes = op.getMaskDimSizes();
     int64_t rank = dstType.getRank();
 
@@ -115,43 +114,43 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
           op, dstType,
-          DenseIntElementsAttr::get(
-              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
-              ArrayRef<bool>{value}));
+          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+                                    value));
       return success();
     }
 
-    // Scalable constant masks can only be lowered for the "none set" case.
-    if (cast<VectorType>(dstType).isScalable()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, DenseElementsAttr::get(dstType, false));
-      return success();
-    }
-
-    int64_t trueDim = std::min(dstType.getDimSize(0),
-                               cast<IntegerAttr>(dimSizes[0]).getInt());
+    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
 
     if (rank == 1) {
-      // Express constant 1-D case in explicit vector form:
-      //   [T,..,T,F,..,F].
-      SmallVector<bool> values(dstType.getDimSize(0));
-      for (int64_t d = 0; d < trueDim; d++)
-        values[d] = true;
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, dstType, rewriter.getBoolVectorAttr(values));
+      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+        // Use constant splat for 'all set' or 'none set' dims.
+        // This produces correct code for scalable dimensions (it will lower to
+        // a constant splat).
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+      } else {
+        // Express constant 1-D case in explicit vector form:
+        //   [T,..,T,F,..,F].
+        // Note: The verifier would reject this case for scalable vectors.
+        SmallVector<bool> values(dstType.getDimSize(0), false);
+        for (int64_t d = 0; d < trueDimSize; d++)
+          values[d] = true;
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, dstType, rewriter.getBoolVectorAttr(values));
+      }
       return success();
     }
 
-    VectorType lowType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
-    SmallVector<int64_t> newDimSizes;
-    for (int64_t r = 1; r < rank; r++)
-      newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
+    if (dstType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          op, "Cannot unroll leading scalable dim in dstType");
+
+    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
-        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+        loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
-    for (int64_t d = 0; d < trueDim; d++)
+    for (int64_t d = 0; d < trueDimSize; d++)
       result =
           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
     rewriter.replaceOp(op, result);

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7b29ef44c1f2f2e..083b3af90e8c50c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1819,16 +1819,53 @@ func.func @genbool_1d() -> vector<8xi1> {
 
 // -----
 
-func.func @genbool_1d_scalable() -> vector<[8]xi1> {
+func.func @genbool_1d_scalable_all_false() -> vector<[8]xi1> {
   %0 = vector.constant_mask [0] : vector<[8]xi1>
   return %0 : vector<[8]xi1>
 }
-// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK-LABEL: func @genbool_1d_scalable_all_false
 // CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
 // CHECK: return %[[VAL_0]] : vector<[8]xi1>
 
 // -----
 
+func.func @genbool_1d_scalable_all_true() -> vector<[8]xi1> {
+  %0 = vector.constant_mask [8] : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable_all_true
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
+func.func @genbool_2d_trailing_scalable() -> vector<4x[4]xi1> {
+  %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
+  return %0 : vector<4x[4]xi1>
+}
+// CHECK-LABEL:   func.func @genbool_2d_trailing_scalable
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
+// CHECK:           %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
+// CHECK:           return %[[VAL_5]] : vector<4x[4]xi1>
+
+// -----
+
+/// Currently, this is not supported as generating the mask would require
+/// unrolling the leading scalable dimension at compile time.
+func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+  %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+  return %0 : vector<[4]x4xi1>
+}
+// CHECK-LABEL:   func.func @cannot_genbool_2d_leading_scalable
+// CHECK:           %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+// CHECK:           return %[[VAL_0]] : vector<[4]x4xi1>
+
+// -----
+
 func.func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
   return %v: vector<4x4xi1>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 50119c2b4a36261..26772b929493585 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() {
 // -----
 
 func.func @constant_mask_scalable_non_zero_dim_size() {
-  // expected-error @+1 {{expected mask dim sizes for scalable masks to be 0}}
+  // expected-error @+1 {{only supports 'none set' or 'all set' scalable dimensions}}
   %0 = vector.constant_mask [2] : vector<[8]xi1>
 }
 

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4ea4379372e8380..18a4c202edfb0a1 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -448,6 +448,10 @@ func.func @constant_vector_mask() {
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
   // CHECK: vector.constant_mask [0] : vector<[4]xi1>
   %1 = vector.constant_mask [0] : vector<[4]xi1>
+  // CHECK: vector.constant_mask [4] : vector<[4]xi1>
+  %2 = vector.constant_mask [4] : vector<[4]xi1>
+  // CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1>
+  %3 = vector.constant_mask [1, 4] : vector<2x[4]xi1>
   return
 }
 


        


More information about the Mlir-commits mailing list