[Mlir-commits] [mlir] [mlir][VectorOp] Move VectorMaskOpConversionBase template to header (NFC) (PR #69341)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 17 08:07:04 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This template will be needed by the ArmSME conversions to lower masked outer products. It has been renamed from `VectorMaskOpConversionBase` to `ConvertVectorMaskOpToLLVMPattern` to be more consistent with the other pattern base classes.

---
Full diff: https://github.com/llvm/llvm-project/pull/69341.diff


2 Files Affected:

- (modified) mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h (+32) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+4-36) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index ecd33779236cc34..20f654a4ac245fa 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -8,6 +8,8 @@
 #ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
 #define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
 
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -24,6 +26,36 @@ void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
 
+/// Base class to convert a `vector.mask` operation while matching traits
+/// of the maskable operation nested inside. A
+/// `ConvertVectorMaskOpToLLVMPattern` instance matches against a `vector.mask`
+/// operation. The `matchAndRewrite` method performs a second match against the
+/// maskable operation `MaskedOp`. Finally, it invokes the virtual method
+/// `matchAndRewriteMaskableOp` to be implemented by the concrete conversion
+/// classes. This method can match against specific traits of the `vector.mask`
+/// and the maskable operation. It must replace the `vector.mask` operation.
+template <class MaskedOp>
+class ConvertVectorMaskOpToLLVMPattern
+    : public ConvertOpToLLVMPattern<vector::MaskOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Match against the maskable operation kind.
+    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
+    if (!maskedOp)
+      return failure();
+    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
+  }
+
+protected:
+  virtual LogicalResult
+  matchAndRewriteMaskableOp(vector::MaskOp maskOp, MaskedOp maskableOp,
+                            ConversionPatternRewriter &rewriter) const = 0;
+};
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..d2864fd3ea67bb2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -849,48 +849,16 @@ class VectorReductionOpConversion
   const bool reassociateFPReductions;
 };
 
-/// Base class to convert a `vector.mask` operation while matching traits
-/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
-/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
-/// method performs a second match against the maskable operation `MaskedOp`.
-/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
-/// implemented by the concrete conversion classes. This method can match
-/// against specific traits of the `vector.mask` and the maskable operation. It
-/// must replace the `vector.mask` operation.
-template <class MaskedOp>
-class VectorMaskOpConversionBase
-    : public ConvertOpToLLVMPattern<vector::MaskOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    // Match against the maskable operation kind.
-    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
-    if (!maskedOp)
-      return failure();
-    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
-  }
-
-protected:
-  virtual LogicalResult
-  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
-                            vector::MaskableOpInterface maskableOp,
-                            ConversionPatternRewriter &rewriter) const = 0;
-};
-
 class MaskedReductionOpConversion
-    : public VectorMaskOpConversionBase<vector::ReductionOp> {
+    : public ConvertVectorMaskOpToLLVMPattern<vector::ReductionOp> {
 
 public:
-  using VectorMaskOpConversionBase<
-      vector::ReductionOp>::VectorMaskOpConversionBase;
+  using ConvertVectorMaskOpToLLVMPattern<
+      vector::ReductionOp>::ConvertVectorMaskOpToLLVMPattern;
 
   LogicalResult matchAndRewriteMaskableOp(
-      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+      vector::MaskOp maskOp, vector::ReductionOp reductionOp,
       ConversionPatternRewriter &rewriter) const override {
-    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
     auto kind = reductionOp.getKind();
     Type eltType = reductionOp.getDest().getType();
     Type llvmType = typeConverter->convertType(eltType);

``````````

</details>


https://github.com/llvm/llvm-project/pull/69341


More information about the Mlir-commits mailing list