[Mlir-commits] [mlir] 5abf116 - [mlir][vector] Allow values outside of [0; dim-size] in create_mask
Sergei Grechanik
llvmlistbot at llvm.org
Thu Jan 20 10:00:30 PST 2022
Author: Sergei Grechanik
Date: 2022-01-20T09:34:42-08:00
New Revision: 5abf1163224549902bf34c9d07b822e5283beb7a
URL: https://github.com/llvm/llvm-project/commit/5abf1163224549902bf34c9d07b822e5283beb7a
DIFF: https://github.com/llvm/llvm-project/commit/5abf1163224549902bf34c9d07b822e5283beb7a.diff
LOG: [mlir][vector] Allow values outside of [0; dim-size] in create_mask
This commit explicitly states that negative values and values exceeding
vector dimensions are allowed in vector.create_mask (but not in
vector.constant_mask). These values are now truncated when
canonicalizing vector.create_mask to vector.constant_mask.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D116069
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 20b431a7b7b25..826c7d0338f0b 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2131,7 +2131,9 @@ def Vector_ConstantMaskOp :
specifies an exclusive upper bound [0, mask-dim-size-element-value)
for a unique dimension in the vector result. The conjunction of the ranges
define a hyper-rectangular region within which elements values are set to 1
- (otherwise element values are set to 0).
+ (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
+ be non-negative and not greater than the size of the corresponding vector
+ dimension (unlike vector.create_mask, which accepts out-of-range values).
Example:
@@ -2169,7 +2171,9 @@ def Vector_CreateMaskOp :
each operand specifies a range [0, operand-value) for a unique dimension in
the vector result. The conjunction of the operand ranges define a
hyper-rectangular region within which elements values are set to 1
- (otherwise element values are set to 0).
+ (otherwise element values are set to 0). If operand-value is negative, it is
+ treated as if it were zero, and if it is greater than the corresponding
+ dimension size, it is treated as if it were equal to the dimension size.
Example:
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index eaa4f4e97e1dd..2e22fc0495bf2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -4235,9 +4235,18 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
- for (auto operand : createMaskOp.operands()) {
- auto *defOp = operand.getDefiningOp();
- maskDimSizes.push_back(cast<arith::ConstantIndexOp>(defOp).value());
+ for (auto it : llvm::zip(createMaskOp.operands(),
+ createMaskOp.getType().getShape())) {
+ auto *defOp = std::get<0>(it).getDefiningOp();
+ int64_t maxDimSize = std::get<1>(it);
+ int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
+ dimSize = std::min(dimSize, maxDimSize);
+ // If one of the dim sizes is zero or negative, set all dims to zero.
+ if (dimSize <= 0) {
+ maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
+ break;
+ }
+ maskDimSizes.push_back(dimSize);
}
// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5da33821eeaf0..3d1923ac09ace 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -13,6 +13,39 @@ func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
// -----
+// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation
+func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) {
+ %c2 = arith.constant 2 : index
+ %c5 = arith.constant 5 : index
+ // CHECK: vector.constant_mask [4, 2] : vector<4x3xi1>
+ %0 = vector.create_mask %c5, %c2 : vector<4x3xi1>
+ return %0 : vector<4x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation_neg
+func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
+ %cneg2 = arith.constant -2 : index
+ %c5 = arith.constant 5 : index
+ // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ %0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
+ return %0 : vector<4x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation_zero
+func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ %0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
+ return %0 : vector<4x3xi1>
+}
+
+// -----
+
func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%0 = vector.constant_mask [2, 2] : vector<4x3xi1>
%1 = vector.extract_strided_slice %0
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir
index 4a3113bdbe5a1..5834f14c6d22a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir
@@ -4,11 +4,13 @@
// RUN: FileCheck %s
func @entry() {
+ %cneg1 = arith.constant -1 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c6 = arith.constant 6 : index
+ %c7 = arith.constant 7 : index
//
// 1-D.
@@ -18,16 +20,18 @@ func @entry() {
vector.print %1 : vector<5xi1>
// CHECK: ( 1, 1, 0, 0, 0 )
- scf.for %i = %c0 to %c6 step %c1 {
+ scf.for %i = %cneg1 to %c7 step %c1 {
%2 = vector.create_mask %i : vector<5xi1>
vector.print %2 : vector<5xi1>
}
// CHECK: ( 0, 0, 0, 0, 0 )
+ // CHECK: ( 0, 0, 0, 0, 0 )
// CHECK: ( 1, 0, 0, 0, 0 )
// CHECK: ( 1, 1, 0, 0, 0 )
// CHECK: ( 1, 1, 1, 0, 0 )
// CHECK: ( 1, 1, 1, 1, 0 )
// CHECK: ( 1, 1, 1, 1, 1 )
+ // CHECK: ( 1, 1, 1, 1, 1 )
//
// 2-D.
More information about the Mlir-commits
mailing list