[Mlir-commits] [mlir] [mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] (PR #73363)

Jakub Kuderski llvmlistbot at llvm.org
Fri Nov 24 12:09:41 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/73363

>From e8976950b559805233851f9aec8ae65c70969a9b Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Thu, 21 Sep 2023 22:47:21 +0300
Subject: [PATCH 1/3] [mlir][spirv] Split codegen for float min/max reductions
 and others (NFC)

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

There are two kinds of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. Their code generation must differ from that of the other vector reduction kinds. This is because the SPIR-V CL and GL floating-point min/max instructions do not handle NaNs with the same semantics as these MLIR operations. Therefore, we must enforce the desired semantics with additional ops.

Because the code generation for floating-point min/max reductions shares the element-extraction logic with the other reduction kinds, we refactor the existing code using CRTP so that this logic is shared. This change does not alter the behavior of the code and is a prerequisite for future fixes to the codegen of floating-point min/max operations.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 116 ++++++++++++++----
 1 file changed, 94 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dcc6449d3fe8927..f1d6c849fef6842 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -351,15 +352,13 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
-          class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
-    : public OpConversionPattern<vector::ReductionOp> {
+template <typename Derived>
+struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+                  ConversionPatternRewriter &rewriter) const final {
     Type resultType = typeConverter->convertType(reduceOp.getType());
     if (!resultType)
       return failure();
@@ -368,9 +367,22 @@ struct VectorReductionPattern final
     if (!srcVectorType || srcVectorType.getRank() != 1)
       return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
 
-    // Extract all elements.
+    SmallVector<Value> extractedElements =
+        extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);
+
+    const auto &self = static_cast<const Derived &>(*this);
+
+    return self.reduceExtracted(reduceOp, extractedElements, resultType,
+                                rewriter);
+  }
+
+private:
+  SmallVector<Value>
+  extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                     VectorType srcVectorType,
+                     ConversionPatternRewriter &rewriter) const {
     int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value, 4> values;
+    SmallVector<Value> values;
     values.reserve(numElements + (adaptor.getAcc() != nullptr));
     Location loc = reduceOp.getLoc();
     for (int i = 0; i < numElements; ++i) {
@@ -381,9 +393,26 @@ struct VectorReductionPattern final
     if (Value acc = adaptor.getAcc())
       values.push_back(acc);
 
-    // Reduce them.
-    Value result = values.front();
-    for (Value next : llvm::ArrayRef(values).drop_front()) {
+    return values;
+  }
+};
+
+#define VECTOR_REDUCTION_BASE                                                  \
+  VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp,  \
+                                                    SPIRVSMaxOp, SPIRVSMinOp>>
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+          typename SPIRVSMinOp>
+struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
+  using Base = VECTOR_REDUCTION_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -403,10 +432,6 @@ struct VectorReductionPattern final
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
-        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
-        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,6 +441,8 @@ struct VectorReductionPattern final
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
 
@@ -423,6 +450,48 @@ struct VectorReductionPattern final
     return success();
   }
 };
+#undef VECTOR_REDUCTION_BASE
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+
+#define MIN_MAX_PATTERN_BASE                                                   \
+  VectorReductionPatternBase<                                                  \
+      VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
+template <class SPIRVFMaxOp, class SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
+  using Base = MIN_MAX_PATTERN_BASE;
+  using Base::Base;
+
+  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
+                                ArrayRef<Value> extractedElements,
+                                Type resultType,
+                                ConversionPatternRewriter &rewriter) const {
+    mlir::Location loc = reduceOp->getLoc();
+    Value result = extractedElements.front();
+    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+      switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind:                                            \
+    result = rewriter.create<fop>(loc, resultType, result, next);              \
+    break
+
+        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+        INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+      default:
+        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+      }
+    }
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+#undef MIN_MAX_PATTERN_BASE
+#undef INT_OR_FLOAT_CASE
 
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
 public:
@@ -674,13 +743,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
 };
 
 } // namespace
-#define CL_MAX_MIN_OPS                                                         \
-  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
-      spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS                                                     \
+  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS                                                     \
+  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
 
-#define GL_MAX_MIN_OPS                                                         \
-  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
-      spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
@@ -689,8 +759,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
       VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext());

>From c283ffc7c54efecd2520c057b9871f65780f1caf Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 Nov 2023 15:01:30 -0500
Subject: [PATCH 2/3] [mlir][spirv] Simplify vector reduction to spirv
 conversion patterns

Use free functions instead of CRTP.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 131 +++++++++---------
 1 file changed, 64 insertions(+), 67 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index f1d6c849fef6842..29bc5f1dd73787f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -28,6 +28,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
@@ -352,67 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-template <typename Derived>
-struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
-  using OpConversionPattern::OpConversionPattern;
+static SmallVector<Value> extractAllElements(vector::ReductionOp reduceOp,
+                                      vector::ReductionOp::Adaptor adaptor,
+                                      VectorType srcVectorType,
+                                      ConversionPatternRewriter &rewriter) {
+  int numElements = srcVectorType.getDimSize(0);
+  SmallVector<Value> values;
+  values.reserve(numElements + (adaptor.getAcc() != nullptr));
+  Location loc = reduceOp.getLoc();
+  for (int i = 0; i < numElements; ++i) {
+    values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+        loc, srcVectorType.getElementType(), adaptor.getVector(),
+        rewriter.getI32ArrayAttr({i})));
+  }
+  if (Value acc = adaptor.getAcc())
+    values.push_back(acc);
 
-  LogicalResult
-  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    Type resultType = typeConverter->convertType(reduceOp.getType());
-    if (!resultType)
-      return failure();
+  return values;
+}
 
-    auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
-    if (!srcVectorType || srcVectorType.getRank() != 1)
-      return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
+struct ReductionRewriteInfo {
+  Type resultType;
+  SmallVector<Value> extractedElements;
+};
 
-    SmallVector<Value> extractedElements =
-        extractAllElements(reduceOp, adaptor, srcVectorType, rewriter);
+FailureOr<ReductionRewriteInfo> static getReductionInfo(
+    vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
+  Type resultType = typeConverter.convertType(op.getType());
+  if (!resultType)
+    return failure();
 
-    const auto &self = static_cast<const Derived &>(*this);
+  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+  if (!srcVectorType || srcVectorType.getRank() != 1)
+    return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
 
-    return self.reduceExtracted(reduceOp, extractedElements, resultType,
-                                rewriter);
-  }
+  SmallVector<Value> extractedElements =
+      extractAllElements(op, adaptor, srcVectorType, rewriter);
 
-private:
-  SmallVector<Value>
-  extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor,
-                     VectorType srcVectorType,
-                     ConversionPatternRewriter &rewriter) const {
-    int numElements = srcVectorType.getDimSize(0);
-    SmallVector<Value> values;
-    values.reserve(numElements + (adaptor.getAcc() != nullptr));
-    Location loc = reduceOp.getLoc();
-    for (int i = 0; i < numElements; ++i) {
-      values.push_back(rewriter.create<spirv::CompositeExtractOp>(
-          loc, srcVectorType.getElementType(), adaptor.getVector(),
-          rewriter.getI32ArrayAttr({i})));
-    }
-    if (Value acc = adaptor.getAcc())
-      values.push_back(acc);
-
-    return values;
-  }
-};
+  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
+}
 
-#define VECTOR_REDUCTION_BASE                                                  \
-  VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp,  \
-                                                    SPIRVSMaxOp, SPIRVSMinOp>>
 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
           typename SPIRVSMinOp>
-struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
-  using Base = VECTOR_REDUCTION_BASE;
-  using Base::Base;
-
-  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
-                                ArrayRef<Value> extractedElements,
-                                Type resultType,
-                                ConversionPatternRewriter &rewriter) const {
+struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
+      return failure();
+
+    auto [resultType, extractedElements] = *reductionInfo;
     mlir::Location loc = reduceOp->getLoc();
     Value result = extractedElements.front();
-    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+    for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
@@ -445,27 +443,27 @@ struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
         return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
 
     rewriter.replaceOp(reduceOp, result);
     return success();
   }
 };
-#undef VECTOR_REDUCTION_BASE
-#undef INT_AND_FLOAT_CASE
-#undef INT_OR_FLOAT_CASE
 
-#define MIN_MAX_PATTERN_BASE                                                   \
-  VectorReductionPatternBase<                                                  \
-      VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
-template <class SPIRVFMaxOp, class SPIRVFMinOp>
-struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
-  using Base = MIN_MAX_PATTERN_BASE;
-  using Base::Base;
-
-  LogicalResult reduceExtracted(vector::ReductionOp reduceOp,
-                                ArrayRef<Value> extractedElements,
-                                Type resultType,
-                                ConversionPatternRewriter &rewriter) const {
+template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
+struct VectorReductionFloatMinMax final : OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionInfo =
+        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+    if (failed(reductionInfo))
+      return failure();
+
+    auto [resultType, extractedElements] = *reductionInfo;
     mlir::Location loc = reduceOp->getLoc();
     Value result = extractedElements.front();
     for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
@@ -485,13 +483,12 @@ struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
         return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
     }
+#undef INT_OR_FLOAT_CASE
 
     rewriter.replaceOp(reduceOp, result);
     return success();
   }
 };
-#undef MIN_MAX_PATTERN_BASE
-#undef INT_OR_FLOAT_CASE
 
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
 public:

>From bac12cc97ad062759580da4ed90df5ecfb7f3fa1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 Nov 2023 15:09:27 -0500
Subject: [PATCH 3/3] Use llvm::drop_begin instead of ArrayRef + drop_front

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 29bc5f1dd73787f..57531b8f05e923b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -353,10 +353,9 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
-static SmallVector<Value> extractAllElements(vector::ReductionOp reduceOp,
-                                      vector::ReductionOp::Adaptor adaptor,
-                                      VectorType srcVectorType,
-                                      ConversionPatternRewriter &rewriter) {
+static SmallVector<Value> extractAllElements(
+    vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
+    VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
   int numElements = srcVectorType.getDimSize(0);
   SmallVector<Value> values;
   values.reserve(numElements + (adaptor.getAcc() != nullptr));
@@ -452,7 +451,8 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
 };
 
 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
-struct VectorReductionFloatMinMax final : OpConversionPattern<vector::ReductionOp> {
+struct VectorReductionFloatMinMax final
+    : OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -466,7 +466,7 @@ struct VectorReductionFloatMinMax final : OpConversionPattern<vector::ReductionO
     auto [resultType, extractedElements] = *reductionInfo;
     mlir::Location loc = reduceOp->getLoc();
     Value result = extractedElements.front();
-    for (Value next : llvm::ArrayRef(extractedElements).drop_front()) {
+    for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {
 
 #define INT_OR_FLOAT_CASE(kind, fop)                                           \



More information about the Mlir-commits mailing list