[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrites to swap extract of extend (PR #80407)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Feb 5 03:51:38 PST 2024
================
@@ -338,6 +338,105 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+// Rewrites `vector.extract(arith.ext*(x))` into `arith.ext*(vector.extract(x))`,
+// i.e. sinks an arith extend op (extsi, extui or extf) below a vector.extract.
+// The rewrite only fires when the extracted value is a vector with exactly one
+// scalable dimension.
+//
+// Example:
+//   %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
+//   %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
+// becomes:
+//   %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
+//   %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
+//
+// Exposing the extend directly on the extracted value lets the
+// `-arm-sme-outer-product-fusion` pass fuse it when that value feeds an outer
+// product. The original extend op is not erased by this pattern.
+struct SwapVectorExtractOfArithExtend
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Guard 1: the extracted value must be a vector...
+    auto extractedVecTy = llvm::dyn_cast<VectorType>(extractOp.getType());
+    if (!extractedVecTy)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extracted type is not a vector type");
+
+    // Guard 2: ...with exactly one scalable dimension.
+    if (llvm::count(extractedVecTy.getScalableDims(), true) != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "extracted type is not a 1-D scalable vector type");
+
+    // Guard 3: the vector being extracted from must come from an extend op.
+    Operation *producer = extractOp.getVector().getDefiningOp();
+    bool isExtend =
+        isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
+            producer);
+    if (!isExtend)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extract not from extend op");
+
+    auto loc = extractOp.getLoc();
+
+    // Take the same slice from the narrow (pre-extension) operand...
+    Value narrowSlice = rewriter.create<vector::ExtractOp>(
+        loc, producer->getOperand(0), extractOp.getMixedPosition());
+
+    // ...then widen it with a new op of the same kind as `producer`
+    // (created generically by op name), yielding the original result type.
+    Operation *widened =
+        rewriter.create(loc, producer->getName().getIdentifier(),
+                        Value(narrowSlice), extractedVecTy);
+
+    rewriter.replaceOp(extractOp, widened->getResult(0));
+    return success();
+  }
+};
+
+// Sinks arith extend ops below vector.scalable.extract ops.
+//
+// This transforms IR like:
+// %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
+// %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
+// Into:
+// %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
+// %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
+//
+// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
+// pass when the result is the input to an outer product.
+struct SwapVectorScalableExtractOfArithExtend
+ : public OpRewritePattern<vector::ScalableExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto *extendOp = extractOp.getSource().getDefiningOp();
+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
+ extendOp))
+ return rewriter.notifyMatchFailure(extractOp,
+ "extract not from extend op");
+
+ auto loc = extractOp.getLoc();
+ VectorType resultType = extractOp.getResultVectorType();
+
+ Value extendSource = extendOp->getOperand(0);
+ StringAttr extendOpName = extendOp->getName().getIdentifier();
+ VectorType extendSourceVectorType =
+ cast<VectorType>(extendSource.getType());
+
+ // Create new extract from source of extend.
+ VectorType extractResultVectorType =
+ VectorType::Builder(resultType)
+ .setElementType(extendSourceVectorType.getElementType());
----------------
MacDue wrote:
```suggestion
VectorType extractResultVectorType = resultType.clone(extendSourceVectorType.getElementType());
```
https://github.com/llvm/llvm-project/pull/80407
More information about the Mlir-commits
mailing list