[Mlir-commits] [mlir] 10063c5 - [mlir][ArmSME] Move vector.print -> ArmSME lowering to VectorToArmSME (#74063)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 4 01:42:16 PST 2023


Author: Benjamin Maxwell
Date: 2023-12-04T09:42:11Z
New Revision: 10063c5a29b9dd5041ea7e0ab9569ed74ef54d32

URL: https://github.com/llvm/llvm-project/commit/10063c5a29b9dd5041ea7e0ab9569ed74ef54d32
DIFF: https://github.com/llvm/llvm-project/commit/10063c5a29b9dd5041ea7e0ab9569ed74ef54d32.diff

LOG: [mlir][ArmSME] Move vector.print -> ArmSME lowering to VectorToArmSME (#74063)

This moves the lowering of `vector.print` on SME tiles from
`-convert-arm-sme-to-scf` to `-convert-vector-to-arm-sme`. The latter is
the more logical place, since the pattern lowers a vector op to ArmSME
ops. It also stops `vector.print` from blocking tile allocation:
previously, printing a tile required running `-allocate-arm-sme-tiles`
after `-convert-arm-sme-to-scf`.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 69c68663070b6..fece03040dbb8 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -447,75 +447,13 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
   }
 };
 
-/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
-/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
-/// a 1D `vector.print`.
-///
-///  BEFORE:
-///  ```mlir
-///  vector.print %tile : vector<[4]x[4]xf32>
-///  ```
-///  AFTER:
-///  ```mlir
-///  %c0 = arith.constant 0 : index
-///  %c1 = arith.constant 1 : index
-///  %c4 = arith.constant 4 : index
-///  %vscale = vector.vscale
-///  %svl_s = arith.muli %c4, %vscale : index
-///  scf.for %i = %c0 to %svl_s step %c1 {
-///    %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
-///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
-///    vector.print %tile_slice : vector<[4]xf32>
-///  }
-///  ```
-struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
-  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::PrintOp printOp,
-                                PatternRewriter &rewriter) const override {
-    if (!printOp.getSource())
-      return failure();
-
-    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
-    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
-      return failure();
-
-    auto loc = printOp.getLoc();
-
-    // Create a loop over the rows of the tile.
-    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
-    auto minTileRows =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
-    {
-      // Loop body.
-      rewriter.setInsertionPointToStart(forOp.getBody());
-      // Extract the current row from the tile.
-      Value rowIndex = forOp.getInductionVar();
-      // FIXME: Forward tile IDs.
-      // For now, if you vector.print a SME tile you need to do
-      // -allocate-arm-sme-tiles after -convert-arm-sme-to-scf.
-      auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
-          loc, printOp.getSource(), rowIndex);
-      // Print the row with a 1D vector.print.
-      rewriter.create<vector::PrintOp>(loc, tileSlice,
-                                       printOp.getPunctuation());
-    }
-
-    rewriter.eraseOp(printOp);
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
-               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
-               TileVectorPrintOpConversion>(patterns.getContext());
+  patterns
+      .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+           TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
+          patterns.getContext());
 }
 
 namespace {

diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3016c7b0a8477..4b3fd26c6d59e 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -666,15 +666,75 @@ struct VectorInsertToArmSMELowering
   }
 };
 
+/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
+/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
+/// a 1D `vector.print`.
+///
+///  BEFORE:
+///  ```mlir
+///  vector.print %tile : vector<[4]x[4]xf32>
+///  ```
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %c4 = arith.constant 4 : index
+///  %vscale = vector.vscale
+///  %svl_s = arith.muli %c4, %vscale : index
+///  scf.for %i = %c0 to %svl_s step %c1 {
+///    %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
+///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
+///    vector.print %tile_slice : vector<[4]xf32>
+///  }
+///  ```
+struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
+  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::PrintOp printOp,
+                                PatternRewriter &rewriter) const override {
+    if (!printOp.getSource())
+      return failure();
+
+    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+      return failure();
+
+    auto loc = printOp.getLoc();
+
+    // Create a loop over the rows of the tile.
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto minTileRows =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    {
+      // Loop body.
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      // Extract the current row from the tile.
+      Value rowIndex = forOp.getInductionVar();
+      auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+          loc, printOp.getSource(), rowIndex);
+      // Print the row with a 1D vector.print.
+      rewriter.create<vector::PrintOp>(loc, tileSlice,
+                                       printOp.getPunctuation());
+    }
+
+    rewriter.eraseOp(printOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-               SplatOpToArmSMELowering, TransferReadToArmSMELowering,
-               TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
-               VectorOuterProductToArmSMELowering,
-               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
-      &ctx);
+  patterns
+      .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
+           SplatOpToArmSMELowering, TransferReadToArmSMELowering,
+           TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+           VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+           VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
+           VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
 }

diff  --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index fc28645a7acf7..efefc6c49e08f 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -160,25 +160,3 @@ func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest :
   arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
-
-//===----------------------------------------------------------------------===//
-// vector.print
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
-{
-  vector.print %tile : vector<[4]x[4]xf32>
-  return
-}
-// CHECK-LABEL:   func.func @arm_sme_tile_print(
-// CHECK-SAME:                                  %[[TILE:.*]]: vector<[4]x[4]xf32>) {
-// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
-// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
-// CHECK-NEXT:      scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT:        %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
-// CHECK-NEXT:        vector.print %[[TILE_SLICE]] : vector<[4]xf32>

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 2491b2e2468cd..5bc147c60f3a6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -736,3 +736,25 @@ func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64
   %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
 }
+
+//===----------------------------------------------------------------------===//
+// vector.print
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @vector_print_tile(%tile: vector<[4]x[4]xf32>)
+{
+  vector.print %tile : vector<[4]x[4]xf32>
+  return
+}
+// CHECK-LABEL:   func.func @vector_print_tile(
+// CHECK-SAME:                                  %[[TILE:.*]]: vector<[4]x[4]xf32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// CHECK-NEXT:      scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:        %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+// CHECK-NEXT:        vector.print %[[TILE_SLICE]] : vector<[4]xf32>


        


More information about the Mlir-commits mailing list