[Mlir-commits] [mlir] [mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] (PR #73363)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 24 12:06:45 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

This is https://github.com/llvm/llvm-project/pull/69023 but with cleanups.
Reduced complexity by avoiding CRTP and preprocessor defines in favor of free functions.

Original description by @<!-- -->unterumarmung:

---

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two types of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops.

~~However, since the code generation for floating-point min/max operations shares the same functionality as extracting values from the vector, we have decided to refactor the existing code using the CRTP pattern.~~ This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations.

---
Full diff: https://github.com/llvm/llvm-project/pull/73363.diff


1 File Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+107-38) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dcc6449d3fe8927..29bc5f1dd73787f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -27,6 +28,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
@@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
-          class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
-    : public OpConversionPattern<vector::ReductionOp> {
+static SmallVector<Value> extractAllElements(vector::ReductionOp reduceOp,
+                                      vector::ReductionOp::Adaptor adaptor,
+                                      VectorType srcVectorType,
+                                      ConversionPatternRewriter &rewriter) {
+  int numElements = srcVectorType.getDimSize(0);
+  SmallVector<Value> values;
+  values.reserve(numElements + (adaptor.getAcc() != nullptr));
+  Location loc = reduceOp.getLoc();
+  for (int i = 0; i < numElements; ++i) {
+    values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+        loc, srcVectorType.getElementType(), adaptor.getVector(),
+        rewriter.getI32ArrayAttr({i})));
+  }
+  if (Value acc = adaptor.getAcc())
+    values.push_back(acc);
+
+  return values;
+}
+
+struct ReductionRewriteInfo {
+  Type resultType;
+  SmallVector<Value> extractedElements;
+};
+
+FailureOr<ReductionRewriteInfo> static getReductionInfo(
+    vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
+  Type resultType = typeConverter.convertType(op.getType());
+  if (!resultType)
+    return failure();
+
+  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+  if (!srcVectorType || srcVectorType.getRank() != 1)
+    return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
+
+  SmallVector<Value> extractedElements =
+      extractAllElements(op, adaptor, srcVectorType, rewriter);
+
+  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
+}
+
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+          typename SPIRVSMinOp>
+struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resultType = typeConverter->convertType(reduceOp.getType());
-    if (!resultType)
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
       return failure();
 
-    auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
-    if (!srcVectorType || srcVectorType.getRank() != 1)
-      return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
-
-    // Extract all elements.
-    int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value, 4> values;
-    values.reserve(numElements + (adaptor.getAcc() != nullptr));
-    Location loc = reduceOp.getLoc();
-    for (int i = 0; i < numElements; ++i) {
-      values.push_back(rewriter.create<spirv::CompositeExtractOp>(
-          loc, srcVectorType.getElementType(), adaptor.getVector(),
-          rewriter.getI32ArrayAttr({i})));
-    }
-    if (Value acc = adaptor.getAcc())
-      values.push_back(acc);
-
-    // Reduce them.
-    Value result = values.front();
-    for (Value next : llvm::ArrayRef(values).drop_front()) {
+    auto [resultType, extractedElements] = *reductionInfo;
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -403,10 +430,6 @@ struct VectorReductionPattern final
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
-        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
-        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,8 +439,51 @@ struct VectorReductionPattern final
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+
+template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
+      return failure();
+
+    auto [resultType, extractedElements] = *reductionInfo;
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+      switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind:                                            \
+    result = rewriter.create<fop>(loc, resultType, result, next);              \
+    break
+
+        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+      }
+    }
+#undef INT_OR_FLOAT_CASE
 
     rewriter.replaceOp(reduceOp, result);
     return success();
@@ -674,13 +740,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 };
 
 } // namespace
-#define CL_MAX_MIN_OPS                                                         \
-  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
-      spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS                                                     \
+  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS                                                     \
+  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
 
-#define GL_MAX_MIN_OPS                                                         \
-  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
-      spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
@@ -689,8 +756,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
       VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext());

``````````

</details>


https://github.com/llvm/llvm-project/pull/73363


More information about the Mlir-commits mailing list