[Mlir-commits] [mlir] [mlir][ArmSME] Lower extract from 2D scalable create_mask to psel (PR #96066)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Jun 19 07:41:33 PDT 2024
================
@@ -549,6 +550,77 @@ struct VectorExtractToArmSMELowering
}
};
+/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
+/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
+/// SVE 2.1), so this is currently the most logical place for this lowering.
+///
+/// Example:
+/// ```mlir
+/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+/// %slice = vector.extract %mask[%index]
+/// : vector<[8]xi1> from vector<[4]x[8]xi1>
+/// ```
+/// Becomes:
+/// ```
+/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
+/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
+/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
+/// : vector<[8]xi1>, vector<[4]xi1>
+/// ```
+struct VectorExtractFromMaskToPselLowering
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.getNumIndices() != 1)
+      return rewriter.notifyMatchFailure(extractOp, "not single extract index");
+
+    auto resultType = extractOp.getResult().getType();
+    auto resultVectorType = dyn_cast<VectorType>(resultType);
----------------
c-rhodes wrote:
nit: this could be simplified by calling `getType()` on the op directly
```suggestion
auto resultVectorType = dyn_cast<VectorType>(extractOp.getType());
```
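
For readers following the doc comment in the hunk above, here is a minimal standalone sketch of why the rewrite preserves semantics. It is not MLIR code and not part of the patch: `Mask`, `createMask1D`, `extractRowFrom2DMask`, `psel`, and `loweringMatches` are hypothetical helpers, fixed-size `std::vector<bool>` stands in for scalable vectors, and bounds are assumed to be non-negative.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model: a fixed-size predicate stands in for a scalable i1 vector.
using Mask = std::vector<bool>;

// Models a 1-D vector.create_mask: the first `bound` lanes are active.
// Assumption: `bound` is non-negative, so an unsigned type is enough here.
Mask createMask1D(std::size_t bound, std::size_t size) {
  Mask mask(size, false);
  for (std::size_t i = 0; i < size && i < bound; ++i)
    mask[i] = true;
  return mask;
}

// Models extracting row `index` from a 2-D create_mask with bounds (a, b):
// lane j of the row is active iff index < a and j < b.
Mask extractRowFrom2DMask(std::size_t a, std::size_t b, std::size_t rows,
                          std::size_t cols, std::size_t index) {
  assert(index < rows && "row index out of range");
  Mask row(cols, false);
  for (std::size_t j = 0; j < cols; ++j)
    row[j] = index < a && j < b;
  return row;
}

// Models psel: yields `p1` if lane `index` of `p2` is active, else all-false.
Mask psel(const Mask &p1, const Mask &p2, std::size_t index) {
  return p2[index] ? p1 : Mask(p1.size(), false);
}

// Checks, for every row, that the direct extract and the psel form agree.
bool loweringMatches(std::size_t a, std::size_t b, std::size_t rows,
                     std::size_t cols) {
  const Mask maskRows = createMask1D(a, rows);
  const Mask maskCols = createMask1D(b, cols);
  for (std::size_t index = 0; index < rows; ++index)
    if (extractRowFrom2DMask(a, b, rows, cols, index) !=
        psel(maskCols, maskRows, index))
      return false;
  return true;
}
```

Calling `loweringMatches(a, b, 4, 8)` mirrors the `vector<[4]x[8]xi1>` example with the scalable factor fixed at 1, and it should hold for any non-negative bounds, including bounds larger than the vector shape.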
https://github.com/llvm/llvm-project/pull/96066