[Mlir-commits] [mlir] a75a46d - [mlir][Vector] Enable create_mask for scalable vectors

Javier Setoain llvmlistbot at llvm.org
Fri Mar 25 03:50:04 PDT 2022


Author: Javier Setoain
Date: 2022-03-25T10:48:59Z
New Revision: a75a46db89f3fe3f3cb7d683e2b6d0227f282e18

URL: https://github.com/llvm/llvm-project/commit/a75a46db89f3fe3f3cb7d683e2b6d0227f282e18
DIFF: https://github.com/llvm/llvm-project/commit/a75a46db89f3fe3f3cb7d683e2b6d0227f282e18.diff

LOG: [mlir][Vector] Enable create_mask for scalable vectors

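The current lowering of vector.create_mask depends on the vector length,
and is therefore incompatible with scalable vector types. This patch adds
an alternative lowering path for create_mask operations that return a
scalable vector mask, built on the new llvm.intr.experimental.stepvector op.
It also allows scalable vector.constant_mask, restricted to the all-false case.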

Differential Revision: https://reviews.llvm.org/D118248

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 92323b0a82a1e..94cb53f9300b8 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -63,9 +63,10 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns);
 
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
+/// If `indexOptimizations` is set, assume indices fit in 32 bits.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false);
+    bool reassociateFPReductions = false, bool indexOptimizations = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
index 1d4353f176b1b..2c60f885fd10e 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -80,6 +80,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       OpFoldResult ofr);
 
+/// Create a cast of the index-like (index or integer) `value` to `targetType`.
+/// index <-> integer uses arith.index_cast; integer -> integer sign-extends or
+/// truncates (same signedness required). Returns `value` if types match.
+Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                      Type targetType, Value value);
+
 /// Similar to the other overload, but converts multiple OpFoldResults into
 /// Values.
 SmallVector<Value>

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 05cd5870bf1f9..327b31a485e6a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1752,6 +1752,14 @@ def LLVM_masked_compressstore
 /// Create a call to vscale intrinsic.
 def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>;
 
+/// Create a call to the llvm.experimental.stepvector intrinsic.
+def LLVM_StepVectorOp
+    : LLVM_IntrOp<"experimental.stepvector", [0], [], [NoSideEffect], 1> {
+  let arguments = (ins); // No operands.
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = "attr-dict `:` type($res)";
+}
+
 // Atomic operations.
 //
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 697b7a8d8786b..20e51008c52b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -900,6 +901,40 @@ class VectorTypeCastOpConversion
   }
 };
 
+/// Lowers a 1-D scalable `vector.create_mask` to `stepvector < splat(bound)`.
+/// Non-scalable masks are left to the patterns in VectorTransforms.cpp.
+class VectorCreateMaskOpRewritePattern
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+                                            bool enableIndexOpt)
+      : OpRewritePattern<vector::CreateMaskOp>(context),
+        indexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
+      return failure();
+    IntegerType idxType =
+        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+    auto loc = op->getLoc();
+    Value indices = rewriter.create<LLVM::StepVectorOp>(
+        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
+                                 /*isScalable=*/true));
+    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
+                                                 op.getOperand(0));
+    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                indices, bounds);
+    rewriter.replaceOp(op, comp);
+    return success();
+  }
+
+private:
+  const bool indexOptimizations; // Use i32, not i64, index vectors.
+};
+
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 public:
   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
@@ -1157,13 +1192,15 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns,
+                                                  bool reassociateFPReductions,
+                                                  bool indexOptimizations) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 16d57efc58588..68edc23e82375 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -80,8 +80,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
   populateVectorMaskMaterializationPatterns(patterns, indexOptimizations);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(converter, patterns,
-                                         reassociateFPReductions);
+  populateVectorToLLVMConversionPatterns(
+      converter, patterns, reassociateFPReductions, indexOptimizations);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.

diff  --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
index 4e35c5b319245..b568891df66a1 100644
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -59,6 +59,27 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
 }
 
+Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                            Type targetType, Value value) {
+  if (targetType == value.getType())
+    return value;
+
+  bool targetIsIndex = targetType.isIndex();
+  bool valueIsIndex = value.getType().isIndex();
+  if (targetIsIndex ^ valueIsIndex) // Exactly one side is `index`.
+    return b.create<arith::IndexCastOp>(loc, targetType, value);
+
+  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
+  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
+  assert(targetIntegerType && valueIntegerType &&
+         "unexpected cast between types other than integers and index");
+  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
+
+  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
+    return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
+  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
+}
+
 SmallVector<Value>
 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       ArrayRef<OpFoldResult> valueOrAttrVec) {

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index da52e1b580efe..9cf1538dd8bc0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4232,6 +4232,14 @@ LogicalResult ConstantMaskOp::verify() {
   if (anyZeros && !allZeros)
     return emitOpError("expected all mask dim sizes to be zeros, "
                        "as a result of conjunction with zero mask dim");
+  // Verify that if the mask type is scalable, dimensions should be zero because
+  // constant scalable masks can only be defined for the "none set" or "all set"
+  // cases, and there is no VLA way to define an "all set" case for
+  // `vector.constant_mask`. In the future, a convention could be established
+  // to decide if a specific dimension value could be considered as "all set".
+  if (resultType.isScalable() &&
+      mask_dim_sizes()[0].cast<IntegerAttr>().getInt() != 0)
+    return emitOpError("expected mask dim sizes for scalable masks to be 0");
   return success();
 }
 
@@ -4269,6 +4277,19 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
     };
     if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
       return failure();
+
+    // CreateMaskOp for scalable vectors can be folded only if all dimensions
+    // are negative or zero.
+    if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
+      if (vType.isScalable())
+        for (auto opDim : createMaskOp.getOperands()) {
+          APInt intVal;
+          if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
+              intVal.isStrictlyPositive())
+            return failure();
+        }
+    }
+
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;
     for (auto it : llvm::zip(createMaskOp.operands(),

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd16dfaf19504..4f99c7985fd31 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -16,6 +16,8 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -602,6 +604,13 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       return success();
     }
 
+    // Scalable constant masks can only be lowered for the "none set" case.
+    if (dstType.cast<VectorType>().isScalable()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, DenseElementsAttr::get(dstType, false));
+      return success();
+    }
+
     int64_t trueDim = std::min(dstType.getDimSize(0),
                                dimSizes[0].cast<IntegerAttr>().getInt());
 
@@ -2161,27 +2170,6 @@ struct BubbleUpBitCastForStridedSliceInsert
   }
 };
 
-static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
-                                   Type targetType, Value value) {
-  if (targetType == value.getType())
-    return value;
-
-  bool targetIsIndex = targetType.isIndex();
-  bool valueIsIndex = value.getType().isIndex();
-  if (targetIsIndex ^ valueIsIndex)
-    return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
-
-  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
-  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
-  assert(targetIntegerType && valueIntegerType &&
-         "unexpected cast between types other than integers and index");
-  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-
-  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
-    return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
-  return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
-}
-
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2217,12 +2205,12 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
   // Add in an offset if requested.
   if (off) {
-    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
+    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
   }
   // Construct the vector comparison.
-  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
+  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
   Value bounds =
       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
@@ -2292,6 +2280,8 @@ class VectorCreateMaskOpConversion
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    if (dstType.cast<VectorType>().isScalable())
+      return failure();
     int64_t rank = dstType.getRank();
     if (rank > 1)
       return failure();

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 7ed8f96789bb1..3c2ac46613310 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -24,6 +24,29 @@ func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   return %0 : vector<11xi1>
 }
 
+// CMP32-LABEL: @genbool_var_1d_scalable(
+// CMP32-SAME: %[[ARG:.*]]: index)
+// CMP32: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi32>
+// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
+// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi32>
+// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi32>, vector<[11]xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi32>
+// CMP32: return %[[T4]] : vector<[11]xi1>
+
+// CMP64-LABEL: @genbool_var_1d_scalable(
+// CMP64-SAME: %[[ARG:.*]]: index)
+// CMP64: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi64>
+// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi64>
+// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi64>, vector<[11]xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi64>
+// CMP64: return %[[T4]] : vector<[11]xi1>
+
+func @genbool_var_1d_scalable(%arg0: index) -> vector<[11]xi1> {
+  %0 = vector.create_mask %arg0 : vector<[11]xi1>
+  return %0 : vector<[11]xi1>
+}
+
 // CMP32-LABEL: @transfer_read_1d
 // CMP32: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
 // CMP32: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 3dcbd3ae475e2..cda183df00d24 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1459,6 +1459,16 @@ func @genbool_1d() -> vector<8xi1> {
 
 // -----
 
+func @genbool_1d_scalable() -> vector<[8]xi1> {
+  %0 = vector.constant_mask [0] : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
 func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
   return %v: vector<4x4xi1>
@@ -1505,6 +1515,20 @@ func @create_mask_1d(%a : index) -> vector<4xi1> {
 // CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK:  return %[[result]] : vector<4xi1>
 
+func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
+  %v = vector.create_mask %a : vector<[4]xi1>
+  return %v: vector<[4]xi1>
+}
+
+// CHECK-LABEL: func @create_mask_1d_scalable
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK:  %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32>
+// CHECK:  %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK:  %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]], {{.*}} : vector<[4]xi32>
+// CHECK:  %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]], {{.*}} : vector<[4]xi32>, vector<[4]xi32>
+// CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<[4]xi32>
+// CHECK: return %[[result]] : vector<[4]xi1>
+
 // -----
 
 func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9647fb018bcaa..ceae2452cd431 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -13,6 +13,16 @@ func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 
 // -----
 
+// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
+func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
+  %c-1 = arith.constant -1 : index
+  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  %0 = vector.create_mask %c-1 : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+
+// -----
+
 // CHECK-LABEL: create_vector_mask_to_constant_mask_truncation
 func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c90725e5d8d7b..f60d2b103b882 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -944,6 +944,13 @@ func @constant_mask_with_zero_mask_dim_size() {
 
 // -----
 
+func @constant_mask_scalable_non_zero_dim_size() {
+  // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
+  %0 = vector.constant_mask [2] : vector<[8]xi1>
+}
+
+// -----
+
 func @print_no_result(%arg0 : f32) -> i32 {
   // expected-error@+1 {{cannot name an operation with no results}}
   %0 = vector.print %arg0 : f32

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index ab8daca78f7c6..43b38efb242eb 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -389,6 +389,8 @@ func @constant_vector_mask_0d() {
 func @constant_vector_mask() {
   // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
+  // CHECK: vector.constant_mask [0] : vector<[4]xi1>
+  %1 = vector.constant_mask [0] : vector<[4]xi1>
   return
 }
 


        


More information about the Mlir-commits mailing list