[Mlir-commits] [mlir] dadcaf8 - [mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (#88762)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 16 04:54:05 PDT 2024


Author: Benjamin Maxwell
Date: 2024-04-16T12:54:01+01:00
New Revision: dadcaf82274805456b7d85131cf94f921b5398b7

URL: https://github.com/llvm/llvm-project/commit/dadcaf82274805456b7d85131cf94f921b5398b7
DIFF: https://github.com/llvm/llvm-project/commit/dadcaf82274805456b7d85131cf94f921b5398b7.diff

LOG: [mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (#88762)

This adds a simple rewrite/legalization to decompose constant splats
larger than a single ArmSME tile into multiple SME virtual-tile-sized
splats. E.g. a constant splat to `vector<[8]x[8]xi32>` would decompose
into four `vector<[4]x[4]xi32>` splats.

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
    mlir/test/Dialect/ArmSME/vector-legalization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 31500c62c0d600..b595c6dd8a6848 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
   return (vectorRows * vectorCols) / (minNumElts * minNumElts);
 }
 
+/// Legalize `arith.constant dense<value>` splat operations to fit within SME
+/// tiles by decomposing them into tile-sized operations.
+struct LegalizeArithConstantOpsByDecomposition
+    : public OneToNOpConversionPattern<arith::ConstantOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!vectorType || !denseAttr || !denseAttr.isSplat())
+      return failure();
+
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(constantOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
+    auto tileSplat = rewriter.create<arith::ConstantOp>(
+        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
+                       adaptor.getResultMapping());
+
+    return success();
+  }
+};
+
 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@ struct VectorLegalizationPass
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
-    patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+    patterns.add<LegalizeArithConstantOpsByDecomposition,
+                 LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
                  LegalizeTransferWriteOpsByDecomposition>(converter, context);
     populateFuncTypeConversionPatterns(converter, patterns);

diff  --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f8be697548c197..f43ef1cce787c5 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -433,3 +433,14 @@ func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: m
   %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
   return %cast : vector<[4]xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @multi_tile_splat
+func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
+{
+  // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
+  // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+  %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
+  return %0 : vector<[8]x[8]xi32>
+}


        


More information about the Mlir-commits mailing list