[Mlir-commits] [mlir] d625ea1 - [mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] (#73363)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 24 12:24:49 PST 2023
Author: Jakub Kuderski
Date: 2023-11-24T15:24:45-05:00
New Revision: d625ea12c71813db0da4c2e5867e907da22e22f2
URL: https://github.com/llvm/llvm-project/commit/d625ea12c71813db0da4c2e5867e907da22e22f2
DIFF: https://github.com/llvm/llvm-project/commit/d625ea12c71813db0da4c2e5867e907da22e22f2.diff
LOG: [mlir][spirv] Split codegen for float min/max reductions and others v2. [NFC] (#73363)
This is https://github.com/llvm/llvm-project/pull/69023 but with
cleanups.
Reduced complexity by avoiding CRTP and preprocessor defines in favor of
free functions.
Original description by @unterumarmung:
---
This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
There are two types of min/max operations for floating-point numbers:
`minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these
operations should differ from that of other vector reduction kinds. This
difference arises because CL and GL operations for floating-point min
and max do not have the same semantics when handling NaNs. Therefore, we
must enforce the desired semantics with additional ops.
~~However, since the code generation for floating-point min/max
operations shares the same functionality as extracting values for the
vector, we have decided to refactor the existing code using the CRTP
pattern.~~ This change does not alter the actual behavior of the code
and is necessary for future fixes to the codegen for floating-point
min/max operations.
---------
Co-authored-by: Daniil Dudkin <unterumarmung at yandex.ru>
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dcc6449d3fe8927..05ef535dde4b5c7 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -27,6 +28,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
@@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
}
};
-template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
- class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
-struct VectorReductionPattern final
- : public OpConversionPattern<vector::ReductionOp> {
+static SmallVector<Value> extractAllElements(
+ vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
+ VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
+ int numElements = static_cast<int>(srcVectorType.getDimSize(0));
+ SmallVector<Value> values;
+ values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
+ Location loc = reduceOp.getLoc();
+
+ for (int i = 0; i < numElements; ++i) {
+ values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+ loc, srcVectorType.getElementType(), adaptor.getVector(),
+ rewriter.getI32ArrayAttr({i})));
+ }
+ if (Value acc = adaptor.getAcc())
+ values.push_back(acc);
+
+ return values;
+}
+
+struct ReductionRewriteInfo {
+ Type resultType;
+ SmallVector<Value> extractedElements;
+};
+
+FailureOr<ReductionRewriteInfo> static getReductionInfo(
+ vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
+ Type resultType = typeConverter.convertType(op.getType());
+ if (!resultType)
+ return failure();
+
+ auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+ if (!srcVectorType || srcVectorType.getRank() != 1)
+ return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
+
+ SmallVector<Value> extractedElements =
+ extractAllElements(op, adaptor, srcVectorType, rewriter);
+
+ return ReductionRewriteInfo{resultType, std::move(extractedElements)};
+}
+
+template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
+ typename SPIRVSMinOp>
+struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type resultType = typeConverter->convertType(reduceOp.getType());
- if (!resultType)
+ auto reductionInfo =
+ getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+ if (failed(reductionInfo))
return failure();
- auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
- if (!srcVectorType || srcVectorType.getRank() != 1)
- return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
-
- // Extract all elements.
- int numElements = srcVectorType.getDimSize(0);
- SmallVector<Value, 4> values;
- values.reserve(numElements + (adaptor.getAcc() != nullptr));
- Location loc = reduceOp.getLoc();
- for (int i = 0; i < numElements; ++i) {
- values.push_back(rewriter.create<spirv::CompositeExtractOp>(
- loc, srcVectorType.getElementType(), adaptor.getVector(),
- rewriter.getI32ArrayAttr({i})));
- }
- if (Value acc = adaptor.getAcc())
- values.push_back(acc);
-
- // Reduce them.
- Value result = values.front();
- for (Value next : llvm::ArrayRef(values).drop_front()) {
+ auto [resultType, extractedElements] = *reductionInfo;
+ Location loc = reduceOp->getLoc();
+ Value result = extractedElements.front();
+ for (Value next : llvm::drop_begin(extractedElements)) {
switch (reduceOp.getKind()) {
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
@@ -403,10 +430,6 @@ struct VectorReductionPattern final
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
- INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
- INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
- INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
- INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
@@ -416,7 +439,51 @@ struct VectorReductionPattern final
case vector::CombiningKind::OR:
case vector::CombiningKind::XOR:
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+ default:
+ return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
+#undef INT_AND_FLOAT_CASE
+#undef INT_OR_FLOAT_CASE
+ }
+
+ rewriter.replaceOp(reduceOp, result);
+ return success();
+ }
+};
+
+template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
+struct VectorReductionFloatMinMax final
+ : OpConversionPattern<vector::ReductionOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reductionInfo =
+ getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
+ if (failed(reductionInfo))
+ return failure();
+
+ auto [resultType, extractedElements] = *reductionInfo;
+ Location loc = reduceOp->getLoc();
+ Value result = extractedElements.front();
+ for (Value next : llvm::drop_begin(extractedElements)) {
+ switch (reduceOp.getKind()) {
+
+#define INT_OR_FLOAT_CASE(kind, fop) \
+ case vector::CombiningKind::kind: \
+ result = rewriter.create<fop>(loc, resultType, result, next); \
+ break
+
+ INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
+ INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
+ INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
+ INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
+
+ default:
+ return rewriter.notifyMatchFailure(reduceOp, "not handled here");
+ }
+#undef INT_OR_FLOAT_CASE
}
rewriter.replaceOp(reduceOp, result);
@@ -674,13 +741,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
};
} // namespace
-#define CL_MAX_MIN_OPS \
- spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
- spirv::CLSMaxOp, spirv::CLSMinOp
+#define CL_INT_MAX_MIN_OPS \
+ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_INT_MAX_MIN_OPS \
+ spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
-#define GL_MAX_MIN_OPS \
- spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
- spirv::GLSMaxOp, spirv::GLSMinOp
+#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
+#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
@@ -689,8 +757,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
+ VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext());
More information about the Mlir-commits
mailing list