[Mlir-commits] [mlir] [mlir][ArmSME] Lower extract from 2D scalable create_mask to psel (PR #96066)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Jun 19 08:56:48 PDT 2024


================
@@ -549,6 +550,77 @@ struct VectorExtractToArmSMELowering
   }
 };
 
+/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
+/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
+/// SVE 2.1), so this is currently the most logical place for this lowering.
+///
+/// Example:
+/// ```mlir
+/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+/// %slice = vector.extract %mask[%index]
+///                                   : vector<[8]xi1> from vector<[4]x[8]xi1>
+/// ```
+/// Becomes:
+/// ```
+/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
+/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
+/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
+///                                    : vector<[8]xi1>, vector<[4]xi1>
+/// ```
+struct VectorExtractFromMaskToPselLowering
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.getNumIndices() != 1)
+      return rewriter.notifyMatchFailure(extractOp, "not single extract index");
+
+    auto resultType = extractOp.getResult().getType();
+    auto resultVectorType = dyn_cast<VectorType>(resultType);
----------------
MacDue wrote:

The line count would be the same either way, since the inlined form wraps to:
```c++
    auto resultVectorType =
        dyn_cast<VectorType>(extractOp.getResult().getType());
```
I think the extra variable looks a little nicer :)
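
For anyone reading the hunk without the rest of the patch, here is a rough model of why the rewrite in the doc comment preserves semantics. It is only a sketch and not code from the PR: it uses fixed-size `std::vector<bool>` in place of scalable vectors, assumes non-negative mask bounds, and assumes `psel` yields its first operand when the selected lane of the second operand is set and an all-false predicate otherwise. All helper names (`createMask1D`, `createMask2D`, `psel`, `loweringMatches`) are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Mask1D = std::vector<bool>;
using Mask2D = std::vector<Mask1D>;

// Models a 1-D vector.create_mask: lane i is set iff i < bound.
Mask1D createMask1D(std::size_t size, std::size_t bound) {
  Mask1D mask(size);
  for (std::size_t i = 0; i < size; ++i)
    mask[i] = i < bound;
  return mask;
}

// Models a 2-D vector.create_mask: element (r, c) is set iff r < a and c < b.
Mask2D createMask2D(std::size_t rows, std::size_t cols, std::size_t a,
                    std::size_t b) {
  Mask2D mask(rows, Mask1D(cols));
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      mask[r][c] = r < a && c < b;
  return mask;
}

// Assumed psel behaviour: return p1 if lane `index` of p2 is set,
// otherwise an all-false predicate of the same length as p1.
Mask1D psel(const Mask1D &p1, const Mask1D &p2, std::size_t index) {
  return p2[index] ? p1 : Mask1D(p1.size(), false);
}

// Checks, for every row index, that extracting a row of the 2-D mask equals
// psel applied to the column mask and the row mask.
bool loweringMatches(std::size_t rows, std::size_t cols, std::size_t a,
                     std::size_t b) {
  Mask2D mask = createMask2D(rows, cols, a, b);
  Mask1D maskRows = createMask1D(rows, a);
  Mask1D maskCols = createMask1D(cols, b);
  for (std::size_t index = 0; index < rows; ++index)
    if (mask[index] != psel(maskCols, maskRows, index))
      return false;
  return true;
}
```

Under these assumptions, row `index` of the 2-D mask is the column mask when `index < a` and all-false otherwise, which is the same selection `psel` makes on `%mask_rows[%index]`.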

https://github.com/llvm/llvm-project/pull/96066

