[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrites to swap extract of extend (PR #80407)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Feb 5 03:51:38 PST 2024


================
@@ -338,6 +338,105 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+// Moves an arith extend op (extsi/extui/extf) below a vector.extract that
+// consumes its result, so the extract operates on the narrower source type.
+//
+// This transforms IR like:
+//   %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
+//   %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
+// Into:
+//   %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
+//   %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
+//
+// The extend ops act element-wise, so extracting before or after extending
+// yields the same value. Placing the extend last enables outer product fusion
+// in the `-arm-sme-outer-product-fusion` pass when the result is the input to
+// an outer product.
+//
+// The pattern only fires when the extracted value is a vector with exactly
+// one scalable dimension; extracts that produce a scalar are rejected.
+struct SwapVectorExtractOfArithExtend
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Only vector results are handled; extracting a single scalar element
+    // fails the dyn_cast below.
+    VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extracted type is not a vector type");
+
+    // NOTE(review): this counts scalable dims but does not check the rank, so
+    // a result such as vector<2x[4]xi32> also passes despite the "1-D" wording
+    // of the failure message. Confirm whether a rank-1 check is intended.
+    auto numScalableDims = llvm::count(resultType.getScalableDims(), true);
+    if (numScalableDims != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "extracted type is not a 1-D scalable vector type");
+
+    // The extracted-from vector must be produced by one of the supported
+    // extend ops. isa_and_present also rejects a null defining op (e.g. when
+    // the vector is a block argument).
+    auto *extendOp = extractOp.getVector().getDefiningOp();
+    if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
+            extendOp))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extract not from extend op");
+
+    auto loc = extractOp.getLoc();
+    // Record the extend's op name so the same kind of extend can be rebuilt
+    // generically below, without switching on the concrete op class.
+    StringAttr extendOpName = extendOp->getName().getIdentifier();
+    Value extendSource = extendOp->getOperand(0);
+
+    // Create new extract from source of extend, at the same position(s) as
+    // the original extract.
+    Value newExtract = rewriter.create<vector::ExtractOp>(
+        loc, extendSource, extractOp.getMixedPosition());
+
+    // Extend new extract to original result type. Only the operand and the
+    // result type are passed; attributes of the original extend op (if any)
+    // are not copied. NOTE(review): confirm none of the matched extend ops
+    // carry attributes that must be preserved.
+    Operation *newExtend =
+        rewriter.create(loc, extendOpName, Value(newExtract), resultType);
+
+    // Only the extract is replaced; the original extend op is left in place
+    // and keeps serving any other users it may have.
+    rewriter.replaceOp(extractOp, newExtend->getResult(0));
+
+    return success();
+  }
+};
+
+// Shuffles arith extend ops after vector.scalable.extract op.
+//
+// This transforms IR like:
+//   %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
+//   %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
+// Into:
+//   %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
+//   %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
+//
+// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
+// pass when the result is the input to an outer product.
+struct SwapVectorScalableExtractOfArithExtend
+    : public OpRewritePattern<vector::ScalableExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto *extendOp = extractOp.getSource().getDefiningOp();
+    if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
+            extendOp))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extract not from extend op");
+
+    auto loc = extractOp.getLoc();
+    VectorType resultType = extractOp.getResultVectorType();
+
+    Value extendSource = extendOp->getOperand(0);
+    StringAttr extendOpName = extendOp->getName().getIdentifier();
+    VectorType extendSourceVectorType =
+        cast<VectorType>(extendSource.getType());
+
+    // Create new extract from source of extend.
+    VectorType extractResultVectorType =
+        VectorType::Builder(resultType)
+            .setElementType(extendSourceVectorType.getElementType());
----------------
MacDue wrote:

```suggestion
    VectorType extractResultVectorType = resultType.clone(extendSourceVectorType.getElementType());
```

https://github.com/llvm/llvm-project/pull/80407


More information about the Mlir-commits mailing list