[Mlir-commits] [mlir] 9fb57c8 - [mlir] Add min/max operations to Standard.

Alexander Belyaev llvmlistbot at llvm.org
Tue Sep 28 00:40:36 PDT 2021


Author: Alexander Belyaev
Date: 2021-09-28T09:40:22+02:00
New Revision: 9fb57c8c1dd87df36daf7b6f7dee3b7423475afc

URL: https://github.com/llvm/llvm-project/commit/9fb57c8c1dd87df36daf7b6f7dee3b7423475afc
DIFF: https://github.com/llvm/llvm-project/commit/9fb57c8c1dd87df36daf7b6f7dee3b7423475afc.diff

LOG: [mlir] Add min/max operations to Standard.

[RFC: Add min/max ops](https://llvm.discourse.group/t/rfc-add-min-max-operations/4353)

I followed the naming style of the Arith dialect in
https://reviews.llvm.org/D110200,
i.e. by analogy with DivSIOp and DivUIOp I defined MaxSIOp and MaxUIOp.
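
As a rough illustration of why the integer ops come in SI/UI pairs: MLIR integer types such as i32 are signless, so the op itself has to say how the bits are compared. The sketch below models that in plain C++. The helper names are hypothetical and are not MLIR APIs; operands are raw 32-bit patterns, and the signed variants assume the usual two's-complement conversion.

```cpp
#include <cstdint>

// Hypothetical helpers for illustration only; these are not MLIR APIs.
// Operands are raw 32-bit patterns, standing in for MLIR's signless i32.

// Signed maximum: compare as signed (like `cmpi sgt`), then select.
uint32_t maxsi(uint32_t lhs, uint32_t rhs) {
  return static_cast<int32_t>(lhs) > static_cast<int32_t>(rhs) ? lhs : rhs;
}

// Unsigned maximum: compare as unsigned (like `cmpi ugt`), then select.
uint32_t maxui(uint32_t lhs, uint32_t rhs) {
  return lhs > rhs ? lhs : rhs;
}

// Signed minimum: compare as signed (like `cmpi slt`), then select.
uint32_t minsi(uint32_t lhs, uint32_t rhs) {
  return static_cast<int32_t>(lhs) < static_cast<int32_t>(rhs) ? lhs : rhs;
}

// Unsigned minimum: compare as unsigned (like `cmpi ult`), then select.
uint32_t minui(uint32_t lhs, uint32_t rhs) {
  return lhs < rhs ? lhs : rhs;
}
```

With the bit pattern 0xFFFFFFFF (which is -1 when read as signed) and 1 as operands, the signed and unsigned variants pick different operands, which is the distinction the op names encode.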

Once the Arith PR lands, I will migrate these ops as well.

Differential Revision: https://reviews.llvm.org/D110540
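
For readers skimming the ExpandOps.cpp change in the diff, the floating-point expansion can be sketched in scalar C++ roughly as follows. This is only an approximation of the emitted op sequence (cmpf ogt/olt, select, cmpf uno, constant NaN, select); the function names are hypothetical and the vector splat case is not modelled.

```cpp
#include <cmath>
#include <limits>

// Hypothetical scalar model of the maxf expansion; not an MLIR API.
float expandMaxF(float lhs, float rhs) {
  // cmpf ogt: ordered greater-than, false if either operand is NaN.
  bool cmp = lhs > rhs;
  // select %cmp, %lhs, %rhs
  float selected = cmp ? lhs : rhs;
  // cmpf uno: true if at least one operand is NaN.
  bool isNaN = std::isunordered(lhs, rhs);
  // constant quiet NaN, then select %isNaN, %nan, %selected
  float nan = std::numeric_limits<float>::quiet_NaN();
  return isNaN ? nan : selected;
}

// Same shape for minf, using ordered less-than (cmpf olt).
float expandMinF(float lhs, float rhs) {
  bool cmp = lhs < rhs;
  float selected = cmp ? lhs : rhs;
  bool isNaN = std::isunordered(lhs, rhs);
  float nan = std::numeric_limits<float>::quiet_NaN();
  return isNaN ? nan : selected;
}
```

The second select is what gives the NaN-propagating behaviour described in the op documentation. One caveat of this sketch: an ordered comparison treats -0.0 and +0.0 as equal, so for a pair of zeros the first select simply falls through to the right-hand operand.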

Added: 
    

Modified: 
    mlir/docs/Rationale/Rationale.md
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Standard/expand-ops.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md
index 08d3c3d0ff8c6..1386df8deea98 100644
--- a/mlir/docs/Rationale/Rationale.md
+++ b/mlir/docs/Rationale/Rationale.md
@@ -344,33 +344,6 @@ possible to store the predicate as string attribute, it would have rendered
 impossible to implement switching logic based on the comparison kind and made
 attribute validity checks (one out of ten possible kinds) more complex.
 
-### 'select' operation to implement min/max
-
-Although `min` and `max` operations are likely to occur as a result of
-transforming affine loops in ML functions, we did not make them first-class
-operations. Instead, we provide the `select` operation that can be combined with
-`cmpi` to implement the minimum and maximum computation. Although they now
-require two operations, they are likely to be emitted automatically during the
-transformation inside MLIR. On the other hand, there are multiple benefits of
-introducing `select`: standalone min/max would concern themselves with the
-signedness of the comparison, already taken into account by `cmpi`; `select` can
-support floats transparently if used after a float-comparison operation; the
-lower-level targets provide `select`-like instructions making the translation
-trivial.
-
-This operation could have been implemented with additional control flow: `%r =
-select %cond, %t, %f` is equivalent to
-
-```mlir
-^bb0:
-  cond_br %cond, ^bb1(%t), ^bb1(%f)
-^bb1(%r):
-```
-
-However, this control flow granularity is not available in the ML functions
-where min/max, and thus `select`, are likely to appear. In addition, simpler
-control flow may be beneficial for optimization in general.
-
 ### Regions
 
 #### Attributes of type 'Block'

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 8436ac796a97e..14b4b55d1a832 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1247,6 +1247,152 @@ def IndexCastOp : ArithmeticCastOp<"index_cast"> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MaxFOp
+//===----------------------------------------------------------------------===//
+
+def MaxFOp : FloatBinaryOp<"maxf"> {
+  let summary = "floating-point maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the maximum of the two arguments, treating -0.0 as less than +0.0.
+    If one of the arguments is NaN, then the result is also NaN.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point maximum.
+    %a = maxf %b, %c : f64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+def MaxSIOp : IntBinaryOp<"maxsi"> {
+  let summary = "signed integer maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the larger of the two operands, comparing them as signed integers.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer maximum.
+    %a = maxsi %b, %c : i64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+def MaxUIOp : IntBinaryOp<"maxui"> {
+  let summary = "unsigned integer maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the larger of the two operands, comparing them as unsigned integers.
+
+    Example:
+
+    ```mlir
+    // Scalar unsigned integer maximum.
+    %a = maxui %b, %c : i64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MinFOp
+//===----------------------------------------------------------------------===//
+
+def MinFOp : FloatBinaryOp<"minf"> {
+  let summary = "floating-point minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the minimum of the two arguments, treating -0.0 as less than +0.0.
+    If one of the arguments is NaN, then the result is also NaN.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point minimum.
+    %a = minf %b, %c : f64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+def MinSIOp : IntBinaryOp<"minsi"> {
+  let summary = "signed integer minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the smaller of the two operands, comparing them as signed integers.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer minimum.
+    %a = minsi %b, %c : i64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+def MinUIOp : IntBinaryOp<"minui"> {
+  let summary = "unsigned integer minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `minui` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the smaller of the two operands, comparing them as unsigned integers.
+
+    Example:
+
+    ```mlir
+    // Scalar unsigned integer minimum.
+    %a = minui %b, %c : i64
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index d4810ac40f31b..dc342d2033596 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -215,6 +215,55 @@ struct SignedFloorDivIOpConverter : public OpRewritePattern<SignedFloorDivIOp> {
   }
 };
 
+static Type getElementTypeOrSelf(Type type) {
+  if (auto st = type.dyn_cast<ShapedType>())
+    return st.getElementType();
+  return type;
+}
+
+template <typename OpTy, CmpFPredicate pred>
+struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final {
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+
+    Location loc = op.getLoc();
+    Value cmp = rewriter.create<CmpFOp>(loc, pred, lhs, rhs);
+    Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
+
+    auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+    Value isNaN = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, lhs, rhs);
+
+    Value nan = rewriter.create<ConstantFloatOp>(
+        loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
+    if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
+      nan = rewriter.create<SplatOp>(loc, vectorType, nan);
+
+    rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
+    return success();
+  }
+};
+
+template <typename OpTy, CmpIPredicate pred>
+struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final {
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+
+    Location loc = op.getLoc();
+    Value cmp = rewriter.create<CmpIOp>(loc, pred, lhs, rhs);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
+    return success();
+  }
+};
+
 struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
@@ -232,8 +281,18 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
       return !op.shape().getType().cast<MemRefType>().hasStaticShape();
     });
-    target.addIllegalOp<SignedCeilDivIOp>();
-    target.addIllegalOp<SignedFloorDivIOp>();
+    // clang-format off
+    target.addIllegalOp<
+      MaxFOp,
+      MaxSIOp,
+      MaxUIOp,
+      MinFOp,
+      MinSIOp,
+      MinUIOp,
+      SignedCeilDivIOp,
+      SignedFloorDivIOp
+    >();
+    // clang-format on
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
@@ -243,9 +302,20 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
 } // namespace
 
 void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
-  patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter,
-               SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
-      patterns.getContext());
+  // clang-format off
+  patterns.add<
+    AtomicRMWOpConverter,
+    MaxMinFOpConverter<MaxFOp, CmpFPredicate::OGT>,
+    MaxMinFOpConverter<MinFOp, CmpFPredicate::OLT>,
+    MaxMinIOpConverter<MaxSIOp, CmpIPredicate::sgt>,
+    MaxMinIOpConverter<MaxUIOp, CmpIPredicate::ugt>,
+    MaxMinIOpConverter<MinSIOp, CmpIPredicate::slt>,
+    MaxMinIOpConverter<MinUIOp, CmpIPredicate::ult>,
+    MemRefReshapeOpConverter,
+    SignedCeilDivIOpConverter,
+    SignedFloorDivIOpConverter
+  >(patterns.getContext());
+  // clang-format on
 }
 
 std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {

diff  --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
index 44637a8f136be..04587725f8bff 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -109,3 +109,92 @@ func @memref_reshape(%input: memref<*xf32>,
 // CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
 // CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
 // CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
+
+// -----
+
+// CHECK-LABEL: func @maxf
+func @maxf(%a: f32, %b: f32) -> f32 {
+  %result = maxf(%a, %b): (f32, f32) -> f32
+  return %result : f32
+}
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @maxf_vector
+func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
+  %result = maxf(%a, %b): (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
+  return %result : vector<4xf16>
+}
+// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
+// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
+// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[NAN:.*]] = constant 0x7E00 : f16
+// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
+// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
+
+// -----
+
+// CHECK-LABEL: func @minf
+func @minf(%a: f32, %b: f32) -> f32 {
+  %result = minf(%a, %b): (f32, f32) -> f32
+  return %result : f32
+}
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
+// CHECK-LABEL: func @maxsi
+func @maxsi(%a: i32, %b: i32) -> i32 {
+  %result = maxsi(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS]], %[[RHS]] : i32
+
+// -----
+
+// CHECK-LABEL: func @minsi
+func @minsi(%a: i32, %b: i32) -> i32 {
+  %result = minsi(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS]], %[[RHS]] : i32
+
+
+// -----
+
+// CHECK-LABEL: func @maxui
+func @maxui(%a: i32, %b: i32) -> i32 {
+  %result = maxui(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS]], %[[RHS]] : i32
+
+
+// -----
+
+// CHECK-LABEL: func @minui
+func @minui(%a: i32, %b: i32) -> i32 {
+  %result = minui(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi ult, %[[LHS]], %[[RHS]] : i32

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 0f73a8d9358a4..69253e0d7761b 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -86,3 +86,27 @@ func @bitcast(%arg : f32) -> i32 {
   %res = bitcast %arg : f32 to i32
   return %res : i32
 }
+
+// CHECK-LABEL: func @maximum
+func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
+               %f1: f32, %f2: f32,
+               %i1: i32, %i2: i32) {
+  %max_vector = maxf(%v1, %v2)
+    : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+  %max_float = maxf(%f1, %f2) : (f32, f32) -> f32
+  %max_signed = maxsi(%i1, %i2) : (i32, i32) -> i32
+  %max_unsigned = maxui(%i1, %i2) : (i32, i32) -> i32
+  return
+}
+
+// CHECK-LABEL: func @minimum
+func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
+               %f1: f32, %f2: f32,
+               %i1: i32, %i2: i32) {
+  %min_vector = minf(%v1, %v2)
+    : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+  %min_float = minf(%f1, %f2) : (f32, f32) -> f32
+  %min_signed = minsi(%i1, %i2) : (i32, i32) -> i32
+  %min_unsigned = minui(%i1, %i2) : (i32, i32) -> i32
+  return
+}


        

