[Mlir-commits] [mlir] [mlir][VectorOp] Move VectorMaskOpConversionBase template to header (NFC) (PR #69341)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 17 08:07:04 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
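This will be needed within the ArmSME conversions to lower masked outer products. The class has been renamed to `ConvertVectorMaskOpToLLVMPattern` to be consistent with the other pattern base classes. In addition, the `matchAndRewriteMaskableOp` hook now receives the nested maskable operation already cast to `MaskedOp`, rather than as a `MaskableOpInterface`, so implementations no longer need to cast it themselves.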
---
Full diff: https://github.com/llvm/llvm-project/pull/69341.diff
2 Files Affected:
- (modified) mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h (+32)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+4-36)
``````````diff
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index ecd33779236cc34..20f654a4ac245fa 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -8,6 +8,8 @@
#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -24,6 +26,36 @@ void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
+/// Base class to convert a `vector.mask` operation while matching traits
+/// of the maskable operation nested inside. A
+/// `ConvertVectorMaskOpToLLVMPattern` instance matches against a `vector.mask`
+/// operation. The `matchAndRewrite` method performs a second match against the
+/// maskable operation `MaskedOp`. Finally, it invokes the virtual method
+/// `matchAndRewriteMaskableOp` to be implemented by the concrete conversion
+/// classes. This method can match against specific traits of the `vector.mask`
+/// and the maskable operation. It must replace the `vector.mask` operation.
+template <class MaskedOp>
+class ConvertVectorMaskOpToLLVMPattern
+ : public ConvertOpToLLVMPattern<vector::MaskOp> {
+public:
+ using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // Match against the maskable operation kind.
+ auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
+ if (!maskedOp)
+ return failure();
+ return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
+ }
+
+protected:
+ virtual LogicalResult
+ matchAndRewriteMaskableOp(vector::MaskOp maskOp, MaskedOp maskableOp,
+ ConversionPatternRewriter &rewriter) const = 0;
+};
+
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..d2864fd3ea67bb2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -849,48 +849,16 @@ class VectorReductionOpConversion
const bool reassociateFPReductions;
};
-/// Base class to convert a `vector.mask` operation while matching traits
-/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
-/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
-/// method performs a second match against the maskable operation `MaskedOp`.
-/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
-/// implemented by the concrete conversion classes. This method can match
-/// against specific traits of the `vector.mask` and the maskable operation. It
-/// must replace the `vector.mask` operation.
-template <class MaskedOp>
-class VectorMaskOpConversionBase
- : public ConvertOpToLLVMPattern<vector::MaskOp> {
-public:
- using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- // Match against the maskable operation kind.
- auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
- if (!maskedOp)
- return failure();
- return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
- }
-
-protected:
- virtual LogicalResult
- matchAndRewriteMaskableOp(vector::MaskOp maskOp,
- vector::MaskableOpInterface maskableOp,
- ConversionPatternRewriter &rewriter) const = 0;
-};
-
class MaskedReductionOpConversion
- : public VectorMaskOpConversionBase<vector::ReductionOp> {
+ : public ConvertVectorMaskOpToLLVMPattern<vector::ReductionOp> {
public:
- using VectorMaskOpConversionBase<
- vector::ReductionOp>::VectorMaskOpConversionBase;
+ using ConvertVectorMaskOpToLLVMPattern<
+ vector::ReductionOp>::ConvertVectorMaskOpToLLVMPattern;
LogicalResult matchAndRewriteMaskableOp(
- vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+ vector::MaskOp maskOp, vector::ReductionOp reductionOp,
ConversionPatternRewriter &rewriter) const override {
- auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
auto kind = reductionOp.getKind();
Type eltType = reductionOp.getDest().getType();
Type llvmType = typeConverter->convertType(eltType);
``````````
</details>
https://github.com/llvm/llvm-project/pull/69341
More information about the Mlir-commits
mailing list