[Mlir-commits] [mlir] 5abf116 - [mlir][vector] Allow values outside of [0; dim-size] in create_mask

Sergei Grechanik llvmlistbot at llvm.org
Thu Jan 20 10:00:30 PST 2022


Author: Sergei Grechanik
Date: 2022-01-20T09:34:42-08:00
New Revision: 5abf1163224549902bf34c9d07b822e5283beb7a

URL: https://github.com/llvm/llvm-project/commit/5abf1163224549902bf34c9d07b822e5283beb7a
DIFF: https://github.com/llvm/llvm-project/commit/5abf1163224549902bf34c9d07b822e5283beb7a.diff

LOG: [mlir][vector] Allow values outside of [0; dim-size] in create_mask

This commit explicitly states that negative values and values exceeding
vector dimensions are allowed in vector.create_mask (but not in
vector.constant_mask). These values are now truncated when
canonicalizing vector.create_mask to vector.constant_mask.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D116069

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 20b431a7b7b25..826c7d0338f0b 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2131,7 +2131,9 @@ def Vector_ConstantMaskOp :
     specifies an exclusive upper bound [0, mask-dim-size-element-value)
     for a unique dimension in the vector result. The conjunction of the ranges
     define a hyper-rectangular region within which elements values are set to 1
-    (otherwise element values are set to 0).
+    (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
+    be non-negative and not greater than the size of the corresponding vector
+    dimension (unlike vector.create_mask, which allows out-of-range values).
 
     Example:
 
@@ -2169,7 +2171,9 @@ def Vector_CreateMaskOp :
     each operand specifies a range [0, operand-value) for a unique dimension in
     the vector result. The conjunction of the operand ranges define a
     hyper-rectangular region within which elements values are set to 1
-    (otherwise element values are set to 0).
+    (otherwise element values are set to 0). If operand-value is negative, it is
+    treated as if it were zero, and if it is greater than the corresponding
+    dimension size, it is treated as if it were equal to the dimension size.
 
     Example:
 

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index eaa4f4e97e1dd..2e22fc0495bf2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -4235,9 +4235,18 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
       return failure();
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;
-    for (auto operand : createMaskOp.operands()) {
-      auto *defOp = operand.getDefiningOp();
-      maskDimSizes.push_back(cast<arith::ConstantIndexOp>(defOp).value());
+    for (auto it : llvm::zip(createMaskOp.operands(),
+                             createMaskOp.getType().getShape())) {
+      auto *defOp = std::get<0>(it).getDefiningOp();
+      int64_t maxDimSize = std::get<1>(it);
+      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
+      dimSize = std::min(dimSize, maxDimSize);
+      // If any dim size is non-positive, the mask is empty; zero all dims.
+      if (dimSize <= 0) {
+        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
+        break;
+      }
+      maskDimSizes.push_back(dimSize);
     }
     // Replace 'createMaskOp' with ConstantMaskOp.
     rewriter.replaceOpWithNewOp<ConstantMaskOp>(

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5da33821eeaf0..3d1923ac09ace 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -13,6 +13,39 @@ func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 
 // -----
 
+// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation
+func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) {
+  %c2 = arith.constant 2 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: vector.constant_mask [4, 2] : vector<4x3xi1>
+  %0 = vector.create_mask %c5, %c2 : vector<4x3xi1>
+  return %0 : vector<4x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation_neg
+func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
+  %cneg2 = arith.constant -2 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  %0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
+  return %0 : vector<4x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: create_vector_mask_to_constant_mask_truncation_zero
+func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  %0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
+  return %0 : vector<4x3xi1>
+}
+
+// -----
+
 func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir
index 4a3113bdbe5a1..5834f14c6d22a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-create-mask.mlir
@@ -4,11 +4,13 @@
 // RUN: FileCheck %s
 
 func @entry() {
+  %cneg1 = arith.constant -1 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
   %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
 
   //
   // 1-D.
@@ -18,16 +20,18 @@ func @entry() {
   vector.print %1 : vector<5xi1>
   // CHECK: ( 1, 1, 0, 0, 0 )
 
-  scf.for %i = %c0 to %c6 step %c1 {
+  scf.for %i = %cneg1 to %c7 step %c1 {
     %2 = vector.create_mask %i : vector<5xi1>
     vector.print %2 : vector<5xi1>
   }
   // CHECK: ( 0, 0, 0, 0, 0 )
+  // CHECK: ( 0, 0, 0, 0, 0 )
   // CHECK: ( 1, 0, 0, 0, 0 )
   // CHECK: ( 1, 1, 0, 0, 0 )
   // CHECK: ( 1, 1, 1, 0, 0 )
   // CHECK: ( 1, 1, 1, 1, 0 )
   // CHECK: ( 1, 1, 1, 1, 1 )
+  // CHECK: ( 1, 1, 1, 1, 1 )
 
   //
   // 2-D.


        


More information about the Mlir-commits mailing list