[Mlir-commits] [mlir] a75a46d - [mlir][Vector] Enable create_mask for scalable vectors
Javier Setoain
llvmlistbot at llvm.org
Fri Mar 25 03:50:04 PDT 2022
Author: Javier Setoain
Date: 2022-03-25T10:48:59Z
New Revision: a75a46db89f3fe3f3cb7d683e2b6d0227f282e18
URL: https://github.com/llvm/llvm-project/commit/a75a46db89f3fe3f3cb7d683e2b6d0227f282e18
DIFF: https://github.com/llvm/llvm-project/commit/a75a46db89f3fe3f3cb7d683e2b6d0227f282e18.diff
LOG: [mlir][Vector] Enable create_mask for scalable vectors
The way vector.create_mask is currently lowered is
vector-length-dependent, and therefore incompatible with scalable vector
types. This patch adds an alternative lowering path for create_mask
operations that return a scalable vector mask.
Differential Revision: https://reviews.llvm.org/D118248
Added:
Modified:
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 92323b0a82a1e..94cb53f9300b8 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -63,9 +63,10 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
+/// If `indexOptimizations` is set, assume indices fit in 32 bits: the index
+/// vectors built when lowering 1-D scalable `vector.create_mask` ops then use
+/// i32 elements instead of i64.
void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions = false);
+ bool reassociateFPReductions = false, bool indexOptimizations = false);
/// Create a pass to convert vector operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
index 1d4353f176b1b..2c60f885fd10e 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -80,6 +80,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr);
+/// Create a cast from an index-like value (index or integer) to another
+/// index-like value. If the value type and the target type are the same, it
+/// returns the original value. Otherwise, casts between `index` and an integer
+/// type use `arith.index_cast`, and integer-to-integer casts use `arith.extsi`
+/// when widening and `arith.trunci` otherwise. Asserts that both types are
+/// index or integer types and that integer types share the same signedness.
+Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+ Type targetType, Value value);
+
/// Similar to the other overload, but converts multiple OpFoldResults into
/// Values.
SmallVector<Value>
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 05cd5870bf1f9..327b31a485e6a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1752,6 +1752,14 @@ def LLVM_masked_compressstore
/// Create a call to vscale intrinsic.
def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>;
+/// Create a call to the `llvm.experimental.stepvector` intrinsic. Per the LLVM
+/// LangRef it produces the vector <0, 1, 2, ...> with the element type and
+/// length (fixed or scalable) of the result type. The op takes no operands
+/// and is marked side-effect free (`NoSideEffect`).
+def LLVM_StepVectorOp
+ : LLVM_IntrOp<"experimental.stepvector", [0], [], [NoSideEffect], 1> {
+ let arguments = (ins);
+ let results = (outs LLVM_Type:$res);
+ let assemblyFormat = "attr-dict `:` type($res)";
+}
+
// Atomic operations.
//
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 697b7a8d8786b..20e51008c52b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -900,6 +901,40 @@ class VectorTypeCastOpConversion
}
};
+/// Lowers a 1-D scalable `vector.create_mask` into `stepvector < splat(bound)`,
+/// so lane i of the mask is set iff i < bound. Fixed-length (non-scalable)
+/// create_mask ops are left to the Vector Transforms patterns.
+class VectorCreateMaskOpRewritePattern
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  /// `enableIndexOpt` selects i32 (instead of i64) lane indices.
+  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+                                            bool enableIndexOpt)
+      : OpRewritePattern<vector::CreateMaskOp>(context),
+        indexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto maskType = op.getType();
+    // Only masks that are both 1-D and scalable are handled here.
+    if (maskType.getRank() != 1)
+      return failure();
+    if (!maskType.cast<VectorType>().isScalable())
+      return failure();
+
+    IntegerType laneIdxType;
+    if (indexOptimizations)
+      laneIdxType = rewriter.getI32Type();
+    else
+      laneIdxType = rewriter.getI64Type();
+
+    Location loc = op->getLoc();
+
+    // <0, 1, 2, ...>: one index per lane of the (scalable) mask.
+    Value laneIds = rewriter.create<LLVM::StepVectorOp>(
+        loc, LLVM::getVectorType(laneIdxType, maskType.getShape()[0],
+                                 /*isScalable=*/true));
+
+    // Broadcast the scalar bound, converted to the lane index type.
+    Value scalarBound = getValueOrCreateCastToIndexLike(
+        rewriter, loc, laneIdxType, op.getOperand(0));
+    Value boundVec =
+        rewriter.create<SplatOp>(loc, laneIds.getType(), scalarBound);
+
+    // Signed compare: a bound <= 0 yields an all-false mask.
+    Value mask = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, laneIds, boundVec);
+    rewriter.replaceOp(op, mask);
+    return success();
+  }
+
+private:
+  /// Use i32 rather than i64 lane indices (assumes indices fit in 32 bits).
+  const bool indexOptimizations;
+};
+
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
@@ -1157,13 +1192,15 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions) {
+void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ bool reassociateFPReductions,
+ bool indexOptimizations) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
populateVectorInsertExtractStridedSliceTransforms(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+ patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 16d57efc58588..68edc23e82375 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -80,8 +80,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateVectorMaskMaterializationPatterns(patterns, indexOptimizations);
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
- populateVectorToLLVMConversionPatterns(converter, patterns,
- reassociateFPReductions);
+ populateVectorToLLVMConversionPatterns(
+ converter, patterns, reassociateFPReductions, indexOptimizations);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
diff --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
index 4e35c5b319245..b568891df66a1 100644
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -59,6 +59,27 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}
+Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
+                                            Type targetType, Value value) {
+  Type sourceType = value.getType();
+  // Nothing to do when the types already agree.
+  if (sourceType == targetType)
+    return value;
+
+  // Exactly one side is `index`: go through arith.index_cast.
+  if (sourceType.isIndex() != targetType.isIndex())
+    return b.create<arith::IndexCastOp>(loc, targetType, value);
+
+  // Otherwise both sides must be integers; only the bit width changes.
+  auto dstIntTy = targetType.dyn_cast<IntegerType>();
+  auto srcIntTy = sourceType.dyn_cast<IntegerType>();
+  assert(dstIntTy && srcIntTy &&
+         "unexpected cast between types other than integers and index");
+  assert(dstIntTy.getSignedness() == srcIntTy.getSignedness());
+
+  // Narrowing (or equal width) truncates; widening sign-extends.
+  if (dstIntTy.getWidth() <= srcIntTy.getWidth())
+    return b.create<arith::TruncIOp>(loc, dstIntTy, value);
+  return b.create<arith::ExtSIOp>(loc, dstIntTy, value);
+}
+
SmallVector<Value>
mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> valueOrAttrVec) {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index da52e1b580efe..9cf1538dd8bc0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4232,6 +4232,14 @@ LogicalResult ConstantMaskOp::verify() {
if (anyZeros && !allZeros)
return emitOpError("expected all mask dim sizes to be zeros, "
"as a result of conjunction with zero mask dim");
+ // Verify that if the mask type is scalable, dimensions should be zero because
+ // constant scalable masks can only be defined for the "none set" or "all set"
+ // cases, and there is no VLA way to define an "all set" case for
+ // `vector.constant_mask`. In the future, a convention could be established
+ // to decide if a specific dimension value could be considered as "all set".
+ if (resultType.isScalable() &&
+ mask_dim_sizes()[0].cast<IntegerAttr>().getInt() != 0)
+ return emitOpError("expected mask dim sizes for scalable masks to be 0");
return success();
}
@@ -4269,6 +4277,19 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
};
if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
return failure();
+
+ // CreateMaskOp for scalable vectors can be folded only if all dimensions
+ // are negative or zero.
+ if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
+ if (vType.isScalable())
+ for (auto opDim : createMaskOp.getOperands()) {
+ APInt intVal;
+ if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
+ intVal.isStrictlyPositive())
+ return failure();
+ }
+ }
+
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
for (auto it : llvm::zip(createMaskOp.operands(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd16dfaf19504..4f99c7985fd31 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -16,6 +16,8 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -602,6 +604,13 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
return success();
}
+ // Scalable constant masks can only be lowered for the "none set" case.
+ if (dstType.cast<VectorType>().isScalable()) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, DenseElementsAttr::get(dstType, false));
+ return success();
+ }
+
int64_t trueDim = std::min(dstType.getDimSize(0),
dimSizes[0].cast<IntegerAttr>().getInt());
@@ -2161,27 +2170,6 @@ struct BubbleUpBitCastForStridedSliceInsert
}
};
-static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
- Type targetType, Value value) {
- if (targetType == value.getType())
- return value;
-
- bool targetIsIndex = targetType.isIndex();
- bool valueIsIndex = value.getType().isIndex();
- if (targetIsIndex ^ valueIsIndex)
- return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
-
- auto targetIntegerType = targetType.dyn_cast<IntegerType>();
- auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
- assert(targetIntegerType && valueIntegerType &&
- "unexpected cast between types other than integers and index");
- assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-
- if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
- return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
- return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
-}
-
// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -2217,12 +2205,12 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
// Add in an offset if requested.
if (off) {
- Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
+ Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
}
// Construct the vector comparison.
- Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
+ Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
@@ -2292,6 +2280,8 @@ class VectorCreateMaskOpConversion
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
+ if (dstType.cast<VectorType>().isScalable())
+ return failure();
int64_t rank = dstType.getRank();
if (rank > 1)
return failure();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 7ed8f96789bb1..3c2ac46613310 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -24,6 +24,29 @@ func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
return %0 : vector<11xi1>
}
+// CMP32-LABEL: @genbool_var_1d_scalable(
+// CMP32-SAME: %[[ARG:.*]]: index)
+// CMP32: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi32>
+// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
+// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi32>
+// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi32>, vector<[11]xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi32>
+// CMP32: return %[[T4]] : vector<[11]xi1>
+
+// CMP64-LABEL: @genbool_var_1d_scalable(
+// CMP64-SAME: %[[ARG:.*]]: index)
+// CMP64: %[[T0:.*]] = llvm.intr.experimental.stepvector : vector<[11]xi64>
+// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<[11]xi64>
+// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[11]xi64>, vector<[11]xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<[11]xi64>
+// CMP64: return %[[T4]] : vector<[11]xi1>
+
+// Scalable create_mask with a runtime bound lowers to stepvector + compare.
+func @genbool_var_1d_scalable(%n: index) -> vector<[11]xi1> {
+  %mask = vector.create_mask %n : vector<[11]xi1>
+  return %mask : vector<[11]xi1>
+}
+
// CMP32-LABEL: @transfer_read_1d
// CMP32: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
// CMP32: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 3dcbd3ae475e2..cda183df00d24 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1459,6 +1459,16 @@ func @genbool_1d() -> vector<8xi1> {
// -----
+// The only legal scalable constant mask is the all-false one.
+func @genbool_1d_scalable() -> vector<[8]xi1> {
+  %none = vector.constant_mask [0] : vector<[8]xi1>
+  return %none : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
func @genbool_2d() -> vector<4x4xi1> {
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
return %v: vector<4x4xi1>
@@ -1505,6 +1515,20 @@ func @create_mask_1d(%a : index) -> vector<4xi1> {
// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
// CHECK: return %[[result]] : vector<4xi1>
+// Default lowering of a scalable create_mask with a dynamic bound.
+func @create_mask_1d_scalable(%bound : index) -> vector<[4]xi1> {
+  %mask = vector.create_mask %bound : vector<[4]xi1>
+  return %mask : vector<[4]xi1>
+}
+
+// CHECK-LABEL: func @create_mask_1d_scalable
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]], {{.*}} : vector<[4]xi32>
+// CHECK: %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]], {{.*}} : vector<[4]xi32>, vector<[4]xi32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<[4]xi32>
+// CHECK: return %[[result]] : vector<[4]xi1>
+
// -----
func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9647fb018bcaa..ceae2452cd431 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -13,6 +13,16 @@ func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
// -----
+// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
+// A non-positive constant bound lets a scalable create_mask fold to [0].
+func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
+  %minus_one = arith.constant -1 : index
+  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  %mask = vector.create_mask %minus_one : vector<[8]xi1>
+  return %mask : vector<[8]xi1>
+}
+
+// -----
+
// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation
func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) {
%c2 = arith.constant 2 : index
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c90725e5d8d7b..f60d2b103b882 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -944,6 +944,13 @@ func @constant_mask_with_zero_mask_dim_size() {
// -----
+// Scalable constant masks may only use a zero dim size; [2] must be rejected.
+func @constant_mask_scalable_non_zero_dim_size() {
+  // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
+  %0 = vector.constant_mask [2] : vector<[8]xi1>
+}
+
+// -----
+
func @print_no_result(%arg0 : f32) -> i32 {
// expected-error at +1 {{cannot name an operation with no results}}
%0 = vector.print %arg0 : f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index ab8daca78f7c6..43b38efb242eb 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -389,6 +389,8 @@ func @constant_vector_mask_0d() {
func @constant_vector_mask() {
// CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
%0 = vector.constant_mask [3, 2] : vector<4x3xi1>
+ // CHECK: vector.constant_mask [0] : vector<[4]xi1>
+ %1 = vector.constant_mask [0] : vector<[4]xi1>
return
}
More information about the Mlir-commits
mailing list