[llvm] 7f85adb - [mlir][Standard] Allow select to use an i1 for vector and tensor values

River Riddle via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 23 04:51:10 PDT 2020


Author: River Riddle
Date: 2020-04-23T04:50:09-07:00
New Revision: 7f85adb54d1956183630eb43c2f3e578f7366276

URL: https://github.com/llvm/llvm-project/commit/7f85adb54d1956183630eb43c2f3e578f7366276
DIFF: https://github.com/llvm/llvm-project/commit/7f85adb54d1956183630eb43c2f3e578f7366276.diff

LOG: [mlir][Standard] Allow select to use an i1 for vector and tensor values

It currently requires that the condition match the shape of the selected value, but this is only really useful for things like masks. This revision allows for the use of i1 to mean that all of the vector/tensor is selected. This also matches the behavior of LLVM select. A benefit of this change is that transformations that want to generate selects, like those on the CFG, don't have to special-case vector/tensor. Previously the only way to generate a select from an i1 was to use a splat, but that doesn't support dynamically shaped/unranked tensors.

Differential Revision: https://reviews.llvm.org/D78690

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize-cf.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index c10f4d233e50..d9a1d0cce2c5 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1124,7 +1124,8 @@ class indexed_accessor_range_base {
 
   /// Compare this range with another.
   template <typename OtherT> bool operator==(const OtherT &other) const {
-    return size() == std::distance(other.begin(), other.end()) &&
+    return size() ==
+               static_cast<size_t>(std::distance(other.begin(), other.end())) &&
            std::equal(begin(), end(), other.begin());
   }
   template <typename OtherT> bool operator!=(const OtherT &other) const {

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 39c9597d866a..54800a579cd9 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1915,11 +1915,8 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> {
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
-     AllTypesMatch<["true_value", "false_value", "result"]>,
-     TypesMatchWith<"condition type matches i1 equivalent of result type",
-                     "result", "condition",
-                     "getI1SameShape($_self)">]> {
+def SelectOp : Std_Op<"select", [NoSideEffect,
+     AllTypesMatch<["true_value", "false_value", "result"]>]> {
   let summary = "select operation";
   let description = [{
     The `select` operation chooses one value based on a binary condition
@@ -1930,7 +1927,8 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
     The operation applies to vectors and tensors elementwise given the _shape_
     of all operands is identical. The choice is made for each element
     individually based on the value at the same position as the element in the
-    condition operand.
+    condition operand. If an i1 is provided as the condition, the entire vector
+    or tensor is chosen.
 
     The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used
     to implement `min` and `max` with signed or unsigned comparison semantics.
@@ -1944,9 +1942,11 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
     // Generic form of the same operation.
     %x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
 
-    // Vector selection is element-wise
-    %vx = "std.select"(%vcond, %vtrue, %vfalse)
-        : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
+    // Element-wise vector selection.
+    %vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32>
+
+    // Full vector selection.
+    %vx = std.select %cond, %vtrue, %vfalse : vector<42xf32>
     ```
   }];
 
@@ -1954,7 +1954,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
                        AnyType:$true_value,
                        AnyType:$false_value);
   let results = (outs AnyType:$result);
-  let verifier = ?;
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value condition,"
@@ -1970,10 +1969,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
   }];
 
   let hasFolder = 1;
-
-  let assemblyFormat = [{
-    $condition `,` $true_value `,` $false_value attr-dict `:` type($result)
-  }];
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3294210d5218..bf4bfc8bdef6 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -999,15 +999,6 @@ struct SimplifyCondBranchIdenticalSuccessors
     if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
       return failure();
 
-    // TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
-    // shape as the operands. We should relax that to allow an i1 to signify
-    // that everything is selected.
-    auto doesntSupportsScalarI1 = [](Type type) {
-      return type.isa<TensorType>() || type.isa<VectorType>();
-    };
-    if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
-      return failure();
-
     // Generate a select for any operands that differ between the two.
     SmallVector<Value, 8> mergedOperands;
     mergedOperands.reserve(trueOperands.size());
@@ -1925,6 +1916,59 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+static void print(OpAsmPrinter &p, SelectOp op) { // Custom assembly printer (inverse of parseSelectOp).
+  p << "select " << op.getOperands();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : ";
+  if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
+    p << condType << ", "; // Shaped (mask) condition: print its type first.
+  p << op.getType(); // A non-shaped condition's type is omitted; the parser assumes i1.
+}
+
+static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
+  Type conditionType, resultType;
+  SmallVector<OpAsmParser::OperandType, 3> operands;
+  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(resultType))
+    return failure();
+  // Syntax: `select` %cond, %true, %false attr-dict `:` type (`,` type)?
+  // If a second type follows, the first one was the explicit (mask) condition type.
+  if (succeeded(parser.parseOptionalComma())) {
+    conditionType = resultType;
+    if (parser.parseType(resultType))
+      return failure();
+  } else {
+    conditionType = parser.getBuilder().getI1Type(); // No mask type given: i1.
+  }
+  // Operand types resolve in order: condition, true value, false value.
+  result.addTypes(resultType);
+  return parser.resolveOperands(operands,
+                                {conditionType, resultType, resultType},
+                                parser.getNameLoc(), result.operands);
+}
+
+static LogicalResult verify(SelectOp op) {
+  Type conditionType = op.getCondition().getType();
+  if (conditionType.isSignlessInteger(1))
+    return success(); // A scalar i1 selects the entire value, of any type.
+  // Otherwise the condition must be an element-wise mask. A mask is only
+  // valid for vector or tensor results, and its type must be the i1
+  // counterpart of the result type (same shape, i1 elements).
+  Type resultType = op.getType();
+  if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>())
+    return op.emitOpError()
+           << "expected condition to be a signless i1, but got "
+           << conditionType;
+  Type shapedConditionType = getI1SameShape(resultType);
+  if (conditionType != shapedConditionType)
+    return op.emitOpError()
+           << "expected condition type to have the same shape "
+              "as the result type, expected "
+           << shapedConditionType << ", but got " << conditionType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SignExtendIOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 71ee7f1fcfe0..b0fd84448b1a 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -69,39 +69,18 @@ func @cond_br_same_successor(%cond : i1, %a : i32) {
 
 // CHECK-LABEL: func @cond_br_same_successor_insert_select(
 // CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
-func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2xi32>, %[[ARG3:.*]]: tensor<2xi32>
+func @cond_br_same_successor_insert_select(
+      %cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32>
+    ) -> (i32, tensor<2xi32>)  {
   // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
-  // CHECK: return %[[RES]]
-
-  cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
-
-^bb1(%result : i32):
-  return %result : i32
-}
-
-/// Check that we don't generate a select if the type requires a splat.
-/// TODO: SelectOp should allow for matching a vector/tensor with i1.
-
-// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
-func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
-                                              %b : tensor<2xi32>) -> tensor<2xi32>{
-  // CHECK: cond_br
-
-  cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
-
-^bb1(%result : tensor<2xi32>):
-  return %result : tensor<2xi32>
-}
-
-// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
-func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
-                                              %b : vector<2xi32>) -> vector<2xi32> {
-  // CHECK: cond_br
+  // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG2]], %[[ARG3]]
+  // CHECK: return %[[RES]], %[[RES2]]
 
-  cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
+  cond_br %cond, ^bb1(%a, %c : i32, tensor<2xi32>), ^bb1(%b, %d : i32, tensor<2xi32>)
 
-^bb1(%result : vector<2xi32>):
-  return %result : vector<2xi32>
+^bb1(%result : i32, %result2 : tensor<2xi32>):
+  return %result, %result2 : i32, tensor<2xi32>
 }
 
 /// Test the compound folding of BranchOp and CondBranchOp.

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index d19f3445655a..d0a27ec68468 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -141,17 +141,17 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
   %21 = select %18, %idx, %idx : index
 
-  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
-  %22 = select %19, %tci32, %tci32 : tensor<42 x i32>
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi1>, tensor<42xi32>
+  %22 = select %19, %tci32, %tci32 : tensor<42 x i1>, tensor<42 x i32>
 
-  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi32>
-  %23 = select %20, %vci32, %vci32 : vector<42 x i32>
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi1>, vector<42xi32>
+  %23 = select %20, %vci32, %vci32 : vector<42 x i1>, vector<42 x i32>
 
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
   %24 = "std.select"(%18, %idx, %idx) : (i1, index, index) -> index
 
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
-  %25 = "std.select"(%19, %tci32, %tci32) : (tensor<42 x i1>, tensor<42 x i32>, tensor<42 x i32>) -> tensor<42 x i32>
+  %25 = std.select %18, %tci32, %tci32 : tensor<42 x i32>
 
   // CHECK: %{{[0-9]+}} = divi_signed %arg2, %arg2 : i32
   %26 = divi_signed %i, %i : i32

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index c7b290517f02..80fdf3342995 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -281,18 +281,18 @@ func @func_with_ops(i1, i32, i64) {
 
 // -----
 
-func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) {
-^bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
+func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) { // Invalid: mask shape differs from result shape.
+^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
+  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}}
+  %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
 // -----
 
-func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
-^bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
-  // expected-error@+1 {{ op requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
+func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) { // Invalid: mask shape differs from result shape.
+^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
+  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}}
+  %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 
 // -----


        


More information about the llvm-commits mailing list