[Mlir-commits] [mlir] [mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (PR #88762)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 10:35:30 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This adds a simple rewrite/legalization that decomposes constant splats larger than a single ArmSME tile into multiple splats, each the size of one SME virtual tile. For example, a constant splat of type `vector<[8]x[8]xi32>` is decomposed into four `vector<[4]x[4]xi32>` splats.

---
Full diff: https://github.com/llvm/llvm-project/pull/88762.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+31-1) 
- (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+11) 


``````````diff
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 31500c62c0d600..b595c6dd8a6848 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
   return (vectorRows * vectorCols) / (minNumElts * minNumElts);
 }
 
+/// Legalize `arith.constant dense<value>` splat operations to fit within SME
+/// tiles by decomposing them into tile-sized operations.
+struct LegalizeArithConstantOpsByDecomposition
+    : public OneToNOpConversionPattern<arith::ConstantOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!vectorType || !denseAttr || !denseAttr.isSplat())
+      return failure();
+
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(constantOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
+    auto tileSplat = rewriter.create<arith::ConstantOp>(
+        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
+                       adaptor.getResultMapping());
+
+    return success();
+  }
+};
+
 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@ struct VectorLegalizationPass
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
-    patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+    patterns.add<LegalizeArithConstantOpsByDecomposition,
+                 LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
                  LegalizeTransferWriteOpsByDecomposition>(converter, context);
     populateFuncTypeConversionPatterns(converter, patterns);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f8be697548c197..f43ef1cce787c5 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -433,3 +433,14 @@ func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: m
   %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
   return %cast : vector<[4]xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @multi_tile_splat
+func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
+{
+  // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
+  // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+  %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
+  return %0 : vector<[8]x[8]xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/88762


More information about the Mlir-commits mailing list