[Mlir-commits] [mlir] [mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (PR #88762)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 15 10:35:30 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This adds a simple rewrite/legalization that decomposes constant splats larger than a single ArmSME tile into multiple splats, each the size of one SME virtual tile. For example, a constant splat of type `vector<[8]x[8]xi32>` is decomposed into four `vector<[4]x[4]xi32>` splats (a single tile-sized constant is materialized and reused for each of the four tiles).
---
Full diff: https://github.com/llvm/llvm-project/pull/88762.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+31-1)
- (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 31500c62c0d600..b595c6dd8a6848 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}
+/// Legalize `arith.constant dense<value>` splat operations to fit within SME
+/// tiles by decomposing them into tile-sized operations.
+struct LegalizeArithConstantOpsByDecomposition
+ : public OneToNOpConversionPattern<arith::ConstantOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = dyn_cast<VectorType>(constantOp.getType());
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+ if (!vectorType || !denseAttr || !denseAttr.isSplat())
+ return failure();
+
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return rewriter.notifyMatchFailure(constantOp,
+ kMatchFailureNotSMETileTypeMultiple);
+
+ auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+ auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
+ auto tileSplat = rewriter.create<arith::ConstantOp>(
+ constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+ rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
+ adaptor.getResultMapping());
+
+ return success();
+ }
+};
+
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@ struct VectorLegalizationPass
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
- patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+ patterns.add<LegalizeArithConstantOpsByDecomposition,
+ LegalizeVectorOuterProductOpsByDecomposition,
LegalizeTransferReadOpsByDecomposition,
LegalizeTransferWriteOpsByDecomposition>(converter, context);
populateFuncTypeConversionPatterns(converter, patterns);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f8be697548c197..f43ef1cce787c5 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -433,3 +433,14 @@ func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: m
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
return %cast : vector<[4]xf32>
}
+
+// -----
+
+// CHECK-LABEL: @multi_tile_splat
+func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
+{
+ // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
+ // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+ %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
+ return %0 : vector<[8]x[8]xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/88762
More information about the Mlir-commits
mailing list