[Mlir-commits] [mlir] 25055a4 - [mlir] add unsigned comparison builders to Affine EDSC

Alex Zinenko llvmlistbot at llvm.org
Mon Jun 29 14:30:56 PDT 2020


Author: Adam D Straw
Date: 2020-06-29T23:30:49+02:00
New Revision: 25055a4fb90292e49f44a0a708390a730cd1116e

URL: https://github.com/llvm/llvm-project/commit/25055a4fb90292e49f44a0a708390a730cd1116e
DIFF: https://github.com/llvm/llvm-project/commit/25055a4fb90292e49f44a0a708390a730cd1116e.diff

LOG: [mlir] add unsigned comparison builders to Affine EDSC

The current Affine comparison builders, which use operator overloading, default to signed comparison. This makes the builders easy to misuse and can cause correctness issues when dealing with unsigned integers. This change distinguishes between signed and unsigned comparison builders and forces the caller to choose between the two.

Differential Revision: https://reviews.llvm.org/D82323

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
    mlir/include/mlir/EDSC/Builders.h
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Affine/EDSC/Builders.cpp
    mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/EDSC/builder-api-test.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
index 5c99a430c862..96191e01296a 100644
--- a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
@@ -66,10 +66,14 @@ Value operator^(Value lhs, Value rhs);
 /// Comparison operator overloadings.
 Value eq(Value lhs, Value rhs);
 Value ne(Value lhs, Value rhs);
-Value operator<(Value lhs, Value rhs);
-Value operator<=(Value lhs, Value rhs);
-Value operator>(Value lhs, Value rhs);
-Value operator>=(Value lhs, Value rhs);
+Value slt(Value lhs, Value rhs);
+Value sle(Value lhs, Value rhs);
+Value sgt(Value lhs, Value rhs);
+Value sge(Value lhs, Value rhs);
+Value ult(Value lhs, Value rhs);
+Value ule(Value lhs, Value rhs);
+Value ugt(Value lhs, Value rhs);
+Value uge(Value lhs, Value rhs);
 
 } // namespace op
 
@@ -159,24 +163,44 @@ Value TemplatedIndexedValue<Load, Store>::ne(Value e) {
   return ne(value, e);
 }
 template <typename Load, typename Store>
-Value TemplatedIndexedValue<Load, Store>::operator<(Value e) {
-  using op::operator<;
-  return static_cast<Value>(*this) < e;
+Value TemplatedIndexedValue<Load, Store>::slt(Value e) {
+  using op::slt;
+  return slt(static_cast<Value>(*this), e);
 }
 template <typename Load, typename Store>
-Value TemplatedIndexedValue<Load, Store>::operator<=(Value e) {
-  using op::operator<=;
-  return static_cast<Value>(*this) <= e;
+Value TemplatedIndexedValue<Load, Store>::sle(Value e) {
+  using op::sle;
+  return sle(static_cast<Value>(*this), e);
 }
 template <typename Load, typename Store>
-Value TemplatedIndexedValue<Load, Store>::operator>(Value e) {
-  using op::operator>;
-  return static_cast<Value>(*this) > e;
+Value TemplatedIndexedValue<Load, Store>::sgt(Value e) {
+  using op::sgt;
+  return sgt(static_cast<Value>(*this), e);
 }
 template <typename Load, typename Store>
-Value TemplatedIndexedValue<Load, Store>::operator>=(Value e) {
-  using op::operator>=;
-  return static_cast<Value>(*this) >= e;
+Value TemplatedIndexedValue<Load, Store>::sge(Value e) {
+  using op::sge;
+  return sge(static_cast<Value>(*this), e);
+}
+template <typename Load, typename Store>
+Value TemplatedIndexedValue<Load, Store>::ult(Value e) {
+  using op::ult;
+  return ult(static_cast<Value>(*this), e);
+}
+template <typename Load, typename Store>
+Value TemplatedIndexedValue<Load, Store>::ule(Value e) {
+  using op::ule;
+  return ule(static_cast<Value>(*this), e);
+}
+template <typename Load, typename Store>
+Value TemplatedIndexedValue<Load, Store>::ugt(Value e) {
+  using op::ugt;
+  return ugt(static_cast<Value>(*this), e);
+}
+template <typename Load, typename Store>
+Value TemplatedIndexedValue<Load, Store>::uge(Value e) {
+  using op::uge;
+  return uge(static_cast<Value>(*this), e);
 }
 
 } // namespace edsc

diff  --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index a7c5506f7ab0..64df2c9fe367 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -288,21 +288,37 @@ class TemplatedIndexedValue {
   /// Comparison operator overloadings.
   Value eq(Value e);
   Value ne(Value e);
-  Value operator<(Value e);
-  Value operator<=(Value e);
-  Value operator>(Value e);
-  Value operator>=(Value e);
-  Value operator<(TemplatedIndexedValue e) {
-    return *this < static_cast<Value>(e);
+  Value slt(Value e);
+  Value sle(Value e);
+  Value sgt(Value e);
+  Value sge(Value e);
+  Value ult(Value e);
+  Value ule(Value e);
+  Value ugt(Value e);
+  Value uge(Value e);
+  Value slt(TemplatedIndexedValue e) {
+    return slt(*this, static_cast<Value>(e));
   }
-  Value operator<=(TemplatedIndexedValue e) {
-    return *this <= static_cast<Value>(e);
+  Value sle(TemplatedIndexedValue e) {
+    return sle(*this, static_cast<Value>(e));
   }
-  Value operator>(TemplatedIndexedValue e) {
-    return *this > static_cast<Value>(e);
+  Value sgt(TemplatedIndexedValue e) {
+    return sgt(*this, static_cast<Value>(e));
   }
-  Value operator>=(TemplatedIndexedValue e) {
-    return *this >= static_cast<Value>(e);
+  Value sge(TemplatedIndexedValue e) {
+    return sge(*this, static_cast<Value>(e));
+  }
+  Value ult(TemplatedIndexedValue e) {
+    return ult(*this, static_cast<Value>(e));
+  }
+  Value ule(TemplatedIndexedValue e) {
+    return ule(*this, static_cast<Value>(e));
+  }
+  Value ugt(TemplatedIndexedValue e) {
+    return ugt(*this, static_cast<Value>(e));
+  }
+  Value uge(TemplatedIndexedValue e) {
+    return uge(*this, static_cast<Value>(e));
   }
 
 private:

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 99ded0686a54..cf3d9653d7df 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -187,7 +187,7 @@ Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
     using namespace mlir::edsc::op;
     majorIvsPlusOffsets.push_back(iv + off);
     if (xferOp.isMaskedDim(leadingRank + idx)) {
-      Value inBounds = majorIvsPlusOffsets.back() < ub;
+      Value inBounds = slt(majorIvsPlusOffsets.back(), ub);
       inBoundsCondition =
           (inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds;
     }
@@ -433,16 +433,16 @@ clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef<Value> ivs) {
     auto i = memRefAccess[memRefDim];
     if (loopIndex < 0) {
       auto N_minus_1 = N - one;
-      auto select_1 = std_select(i < N, i, N_minus_1);
+      auto select_1 = std_select(slt(i, N), i, N_minus_1);
       clippedScalarAccessExprs[memRefDim] =
-          std_select(i < zero, zero, select_1);
+          std_select(slt(i, zero), zero, select_1);
     } else {
       auto ii = ivs[loopIndex];
       auto i_plus_ii = i + ii;
       auto N_minus_1 = N - one;
-      auto select_1 = std_select(i_plus_ii < N, i_plus_ii, N_minus_1);
+      auto select_1 = std_select(slt(i_plus_ii, N), i_plus_ii, N_minus_1);
       clippedScalarAccessExprs[memRefDim] =
-          std_select(i_plus_ii < zero, zero, select_1);
+          std_select(slt(i_plus_ii, zero), zero, select_1);
     }
   }
 

diff  --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
index 559f375c5dff..e5bf1c015e02 100644
--- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
@@ -221,29 +221,51 @@ Value mlir::edsc::op::ne(Value lhs, Value rhs) {
              ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
              : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
 }
-Value mlir::edsc::op::operator<(Value lhs, Value rhs) {
+Value mlir::edsc::op::slt(Value lhs, Value rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
-             :
-             // TODO(ntv,zinenko): signed by default, how about unsigned?
-             createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
+             : createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
 }
-Value mlir::edsc::op::operator<=(Value lhs, Value rhs) {
+Value mlir::edsc::op::sle(Value lhs, Value rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
              : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
 }
-Value mlir::edsc::op::operator>(Value lhs, Value rhs) {
+Value mlir::edsc::op::sgt(Value lhs, Value rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
              : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
 }
-Value mlir::edsc::op::operator>=(Value lhs, Value rhs) {
+Value mlir::edsc::op::sge(Value lhs, Value rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
              : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
 }
+Value mlir::edsc::op::ult(Value lhs, Value rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::ult, lhs, rhs);
+}
+Value mlir::edsc::op::ule(Value lhs, Value rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::ule, lhs, rhs);
+}
+Value mlir::edsc::op::ugt(Value lhs, Value rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::ugt, lhs, rhs);
+}
+Value mlir::edsc::op::uge(Value lhs, Value rhs) {
+  auto type = lhs.getType();
+  return type.isa<FloatType>()
+             ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
+             : createIComparisonExpr(CmpIPredicate::uge, lhs, rhs);
+}

diff  --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 5016ca9b3055..8cfc25d2ff8e 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -169,8 +169,8 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
                                                          StructuredIndexed I2,
                                                          StructuredIndexed O) {
   BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value {
-    using edsc::op::operator>;
-    return std_select(a > b, a, b);
+    using edsc::op::sgt;
+    return std_select(sgt(a, b), a, b);
   });
   return linalg_generic_pointwise(binOp, I1, I2, O);
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index d031712ce5a9..ec57717eaca9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -263,16 +263,16 @@ Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
       continue;
     }
 
-    using edsc::op::operator<;
-    using edsc::op::operator>=;
+    using edsc::op::sge;
+    using edsc::op::slt;
     using edsc::op::operator||;
-    Value leftOutOfBound = dim < zeroIndex;
+    Value leftOutOfBound = slt(dim, zeroIndex);
     if (conds.empty())
       conds.push_back(leftOutOfBound);
     else
       conds.push_back(conds.back() || leftOutOfBound);
     Value rightBound = std_dim(convOp.input(), idx);
-    conds.push_back(conds.back() || (dim >= rightBound));
+    conds.push_back(conds.back() || (sge(dim, rightBound)));
 
     // When padding is involved, the indices will only be shifted to negative,
     // so having a max op is enough.
@@ -337,8 +337,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
   // Emit scalar form.
   Value lhs = std_load(op.output(), indices.outputs);
   Value rhs = std_load(op.input(), indices.inputs);
-  using edsc::op::operator>;
-  Value maxValue = std_select(lhs > rhs, lhs, rhs);
+  using edsc::op::sgt;
+  Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
   std_store(maxValue, op.output(), indices.outputs);
 }
 template <typename IndexedValueType>
@@ -347,8 +347,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
   // Emit scalar form.
   Value lhs = std_load(op.output(), indices.outputs);
   Value rhs = std_load(op.input(), indices.inputs);
-  using edsc::op::operator<;
-  Value minValue = std_select(lhs < rhs, lhs, rhs);
+  using edsc::op::slt;
+  Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
   std_store(minValue, op.output(), indices.outputs);
 }
 template <typename IndexedValueType>

diff  --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 31748810f899..73f7adeeaf71 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -459,9 +459,9 @@ TEST_FUNC(diviu_op_i32) {
 
 TEST_FUNC(select_op_i32) {
   using namespace edsc::op;
-  auto f32Type = FloatType::getF32(&globalContext());
+  auto i32Type = IntegerType::get(32, &globalContext());
   auto memrefType = MemRefType::get(
-      {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
+      {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0);
   auto f = makeFunction("select_op", {}, {memrefType});
 
   OpBuilder builder(f.getBody());
@@ -470,7 +470,18 @@ TEST_FUNC(select_op_i32) {
   MemRefBoundsCapture vA(f.getArgument(0));
   AffineIndexedValue A(f.getArgument(0));
   affineLoopNestBuilder({zero, zero}, {one, one}, {1, 1}, [&](ValueRange ivs) {
-    std_select(eq(ivs[0], zero), A(zero, zero), A(ivs[0], ivs[1]));
+    using namespace edsc::op;
+    Value i = ivs[0], j = ivs[1];
+    std_select(eq(i, zero), A(zero, zero), A(i, j));
+    std_select(ne(i, zero), A(zero, zero), A(i, j));
+    std_select(slt(i, zero), A(zero, zero), A(i, j));
+    std_select(sle(i, zero), A(zero, zero), A(i, j));
+    std_select(sgt(i, zero), A(zero, zero), A(i, j));
+    std_select(sge(i, zero), A(zero, zero), A(i, j));
+    std_select(ult(i, zero), A(zero, zero), A(i, j));
+    std_select(ule(i, zero), A(zero, zero), A(i, j));
+    std_select(ugt(i, zero), A(zero, zero), A(i, j));
+    std_select(uge(i, zero), A(zero, zero), A(i, j));
   });
 
   // clang-format off
@@ -481,6 +492,42 @@ TEST_FUNC(select_op_i32) {
   //  CHECK-DAG:     {{.*}} = affine.load
   //  CHECK-DAG:     {{.*}} = affine.load
   // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "ne"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "slt"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "sle"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "sgt"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "sge"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "ult"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "ule"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "ugt"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
+  //  CHECK-DAG:     {{.*}} = cmpi "uge"
+  //  CHECK-DAG:     {{.*}} = affine.load
+  //  CHECK-DAG:     {{.*}} = affine.load
+  // CHECK-NEXT:     {{.*}} = select
   // clang-format on
   f.print(llvm::outs());
   f.erase();
@@ -503,10 +550,14 @@ TEST_FUNC(select_op_f32) {
     Value i = ivs[0], j = ivs[1];
     std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
     std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
-    std_select(B(i, j) >= B(i + one, j), A(zero, zero), A(i, j));
-    std_select(B(i, j) <= B(i + one, j), A(zero, zero), A(i, j));
-    std_select(B(i, j) < B(i + one, j), A(zero, zero), A(i, j));
-    std_select(B(i, j) > B(i + one, j), A(zero, zero), A(i, j));
+    std_select(sge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
+    std_select(sle(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
+    std_select(slt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
+    std_select(sgt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
+    std_select(uge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
+    std_select(ule(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
+    std_select(ult(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
+    std_select(ugt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
   });
 
   // CHECK-LABEL: @select_op
@@ -554,6 +605,34 @@ TEST_FUNC(select_op_f32) {
   //  CHECK-DAG:     affine.load
   //  CHECK-DAG:     affine.apply
   // CHECK-NEXT:     select
+  //  CHECK-DAG:     cmpf "oge"
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.apply
+  // CHECK-NEXT:     select
+  //  CHECK-DAG:     cmpf "ole"
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.apply
+  // CHECK-NEXT:     select
+  //  CHECK-DAG:     cmpf "olt"
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.apply
+  // CHECK-NEXT:     select
+  //  CHECK-DAG:     cmpf "ogt"
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.load
+  //  CHECK-DAG:     affine.apply
+  // CHECK-NEXT:     select
   // clang-format on
   f.print(llvm::outs());
   f.erase();


        


More information about the Mlir-commits mailing list