[Mlir-commits] [mlir] d625ea1 - [mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] (#73363)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 24 12:24:49 PST 2023


Author: Jakub Kuderski
Date: 2023-11-24T15:24:45-05:00
New Revision: d625ea12c71813db0da4c2e5867e907da22e22f2

URL: https://github.com/llvm/llvm-project/commit/d625ea12c71813db0da4c2e5867e907da22e22f2
DIFF: https://github.com/llvm/llvm-project/commit/d625ea12c71813db0da4c2e5867e907da22e22f2.diff

LOG: [mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] (#73363)

This is https://github.com/llvm/llvm-project/pull/69023 but with
cleanups.
Reduced complexity by avoiding CRTP and preprocessor defines in favor of
free functions.

Original description by @unterumarmung:

---

This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers:
`minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these
operations should differ from that of other vector reduction kinds. This
difference arises because CL and GL operations for floating-point min
and max do not have the same semantics when handling NaNs. Therefore, we
must enforce the desired semantics with additional ops.

~~However, since the code generation for floating-point min/max
operations shares the same functionality as extracting values for the
vector, we have decided to refactor the existing code using the CRTP
pattern.~~ This change does not alter the actual behavior of the code
and is necessary for future fixes to the codegen for floating-point
min/max operations.

---------

Co-authored-by: Daniil Dudkin <unterumarmung at yandex.ru>

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dcc6449d3fe8927..05ef535dde4b5c7 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -27,6 +28,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
@@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
-          class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
-    : public OpConversionPattern<vector::ReductionOp> {
+static SmallVector<Value> extractAllElements(
+    vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
+    VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
+  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
+  SmallVector<Value> values;
+  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
+  Location loc = reduceOp.getLoc();
+
+  for (int i = 0; i < numElements; ++i) {
+    values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+        loc, srcVectorType.getElementType(), adaptor.getVector(),
+        rewriter.getI32ArrayAttr({i})));
+  }
+  if (Value acc = adaptor.getAcc())
+    values.push_back(acc);
+
+  return values;
+}
+
+struct ReductionRewriteInfo {
+  Type resultType;
+  SmallVector<Value> extractedElements;
+};
+
+FailureOr<ReductionRewriteInfo> static getReductionInfo(
+    vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
+  Type resultType = typeConverter.convertType(op.getType());
+  if (!resultType)
+    return failure();
+
+  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+  if (!srcVectorType || srcVectorType.getRank() != 1)
+    return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
+
+  SmallVector<Value> extractedElements =
+      extractAllElements(op, adaptor, srcVectorType, rewriter);
+
+  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
+}
+
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+          typename SPIRVSMinOp>
+struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resultType = typeConverter->convertType(reduceOp.getType());
-    if (!resultType)
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
       return failure();
 
-    auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
-    if (!srcVectorType || srcVectorType.getRank() != 1)
-      return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
-
-    // Extract all elements.
-    int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value, 4> values;
-    values.reserve(numElements + (adaptor.getAcc() != nullptr));
-    Location loc = reduceOp.getLoc();
-    for (int i = 0; i < numElements; ++i) {
-      values.push_back(rewriter.create<spirv::CompositeExtractOp>(
-          loc, srcVectorType.getElementType(), adaptor.getVector(),
-          rewriter.getI32ArrayAttr({i})));
-    }
-    if (Value acc = adaptor.getAcc())
-      values.push_back(acc);
-
-    // Reduce them.
-    Value result = values.front();
-    for (Value next : llvm::ArrayRef(values).drop_front()) {
+    auto [resultType, extractedElements] = *reductionInfo;
+    Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -403,10 +430,6 @@ struct VectorReductionPattern final
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
-        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
-        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,7 +439,51 @@ struct VectorReductionPattern final
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+    }
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+
+template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
+struct VectorReductionFloatMinMax final
+    : OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
+      return failure();
+
+    auto [resultType, extractedElements] = *reductionInfo;
+    Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::drop_begin(extractedElements)) {
+      switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind:                                            \
+    result = rewriter.create<fop>(loc, resultType, result, next);              \
+    break
+
+        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+      }
+#undef INT_OR_FLOAT_CASE
     }
 
     rewriter.replaceOp(reduceOp, result);
@@ -674,13 +741,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 };
 
 } // namespace
-#define CL_MAX_MIN_OPS                                                         \
-  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
-      spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS                                                     \
+  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS                                                     \
+  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
 
-#define GL_MAX_MIN_OPS                                                         \
-  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
-      spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
@@ -689,8 +757,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
       VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext());


        


More information about the Mlir-commits mailing list