[Mlir-commits] [mlir] 89837a0 - Adding min(f/s/u) and max(f/s/u) cases for vector reduction

Alexander Slepko llvmlistbot at llvm.org
Thu Sep 9 12:25:09 PDT 2021


Author: Alexander Slepko
Date: 2021-09-09T12:00:43-07:00
New Revision: 89837a0e1b536c88651b9e68b73eda9c18659db2

URL: https://github.com/llvm/llvm-project/commit/89837a0e1b536c88651b9e68b73eda9c18659db2
DIFF: https://github.com/llvm/llvm-project/commit/89837a0e1b536c88651b9e68b73eda9c18659db2.diff

LOG: Adding min(f/s/u) and max(f/s/u) cases for vector reduction

This patch adds the missing AtomicRMWKind min/max cases (minf, mins, minu, maxf, maxs, maxu), which we would like to use when vectorizing min/max reduction loops.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D104881

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 2c99e85f7f071..ed0417634e309 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -350,9 +350,32 @@ static LogicalResult verify(AtomicRMWOp op) {
 Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                      OpBuilder &builder, Location loc) {
   switch (kind) {
+  case AtomicRMWKind::maxf:
+    return builder.getFloatAttr(
+        resultType,
+        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+                        /*Negative=*/true));
   case AtomicRMWKind::addf:
   case AtomicRMWKind::addi:
+  case AtomicRMWKind::maxu:
     return builder.getZeroAttr(resultType);
+  case AtomicRMWKind::maxs:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::minf:
+    return builder.getFloatAttr(
+        resultType,
+        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+                        /*Negative=*/false));
+  case AtomicRMWKind::mins:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::minu:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
   case AtomicRMWKind::muli:
     return builder.getIntegerAttr(resultType, 1);
   case AtomicRMWKind::mulf:
@@ -385,6 +408,30 @@ Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
     return builder.create<MulFOp>(loc, lhs, rhs);
   case AtomicRMWKind::muli:
     return builder.create<MulIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::maxf:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::minf:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::maxs:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::mins:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::maxu:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::minu:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::ult, lhs, rhs), lhs,
+        rhs);
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 279cc1846769f..40741fa4d9d0e 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -357,6 +357,18 @@ Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("mul"),
                                                vector, ValueRange{});
+  case AtomicRMWKind::minf:
+  case AtomicRMWKind::mins:
+  case AtomicRMWKind::minu:
+    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
+                                               builder.getStringAttr("min"),
+                                               vector, ValueRange{});
+  case AtomicRMWKind::maxf:
+  case AtomicRMWKind::maxs:
+  case AtomicRMWKind::maxu:
+    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
+                                               builder.getStringAttr("max"),
+                                               vector, ValueRange{});
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");


        


More information about the Mlir-commits mailing list