[Mlir-commits] [mlir] [mlir][ArmSME] Add arith-to-arm-sme conversion pass (PR #78197)

Cullen Rhodes llvmlistbot at llvm.org
Wed Jan 17 09:02:25 PST 2024


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/78197

>From c07b16830925bd86d2440423a2fd409c26799854 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 15 Jan 2024 09:31:43 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add arith-to-arm-sme conversion pass

The existing 'arith::ConstantOp' conversion and its tests are moved from
VectorToArmSME. Only a single op is converted at the moment, but this
will grow in the future as things like in-tile add are implemented.
Also, 'createLoopOverTileSlices' is moved to the ArmSME utils since it's
relevant to both conversions.
---
 .../Conversion/ArithToArmSME/ArithToArmSME.h  |  27 ++++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |   9 ++
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  17 +++
 .../ArithToArmSME/ArithToArmSME.cpp           | 127 +++++++++++++++++
 .../Conversion/ArithToArmSME/CMakeLists.txt   |  18 +++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../VectorToArmSME/VectorToArmSME.cpp         | 134 ++++--------------
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |  20 +++
 .../ArithToArmSME/arith-to-arm-sme.mlir}      |   2 +-
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    |   2 +-
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |   3 +-
 .../Linalg/CPU/ArmSME/use-too-many-tiles.mlir |   4 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |   3 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |   3 +-
 .../CPU/ArmSME/test-transfer-write-2d.mlir    |   3 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |   3 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |   3 +-
 18 files changed, 264 insertions(+), 116 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
 create mode 100644 mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
 create mode 100644 mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
 rename mlir/test/{Dialect/ArmSME/arith-ops-to-sme.mlir => Conversion/ArithToArmSME/arith-to-arm-sme.mlir} (97%)

diff --git a/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
new file mode 100644
index 00000000000000..012e7fb5b0af2f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
@@ -0,0 +1,27 @@
+//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a25fd17ea923fb..0bfc5064c5dd72 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 71be8841ca7c03..3467e042c493e9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -164,6 +164,15 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToArmSME
+//===----------------------------------------------------------------------===//
+
+def ArithToArmSMEConversionPass : Pass<"convert-arith-to-arm-sme"> {
+  let summary = "Convert Arith dialect to ArmSME dialect";
+  let dependentDialects = ["arm_sme::ArmSMEDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmNeon2dToIntr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index b7d90195d49d76..a15eac7302077b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,9 +16,16 @@
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include <optional>
 
+namespace mlir {
+class Location;
+class PatternRewriter;
+class Value;
+} // namespace mlir
+
 namespace mlir::arm_sme {
 
 constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -42,6 +49,16 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
 /// Verifies the tile ID (if set) on this tile operation is valid.
 LogicalResult verifyOperationHasValidTileId(Operation *);
 
+using LoopBodyBuilder =
+    std::function<void(OpBuilder &, Location, Value, Value)>;
+
+/// Generates a for loop over ZA tile slices where the induction variable is
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via LoopBodyBuilder, which returns the next tile value.
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+                                    Value initTile,
+                                    LoopBodyBuilder bodyBuilder);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
new file mode 100644
index 00000000000000..9aab969881f75e
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -0,0 +1,127 @@
+//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "arith-to-arm-sme"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true if 'val' is a splat of zero (+0.0 for floats, never -0.0).
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+  if (llvm::isa<FloatType>(elemType)) // 'arm_sme.zero' cannot produce -0.0.
+    return val && val.isSplat() && val.getSplatValue<APFloat>().isPosZero();
+  if (llvm::isa<IntegerType>(elemType))
+    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+  return false;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Lowers splat 'arith.constant' SME tiles to 'arm_sme.zero' or a slice loop.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = dyn_cast<VectorType>(constantOp.getType());
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!denseAttr || !denseAttr.isSplat())
+      return failure();
+
+    auto tileElementType = tileType.getElementType();
+
+    // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+    if (isSplatZero(tileElementType, denseAttr)) {
+      rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+      return success();
+    }
+
+    // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+    // ops that broadcast the constant to each tile slice.
+    auto loc = constantOp.getLoc();
+
+    // To fill a tile with a constant, we create a 1-D splat of the constant,
+    // then move that into each tile slice (the largest unit we can set at once,
+    // outside of operations like the outerproduct).
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+    auto denseAttr1D = DenseElementsAttr::get(
+        tileSliceType, denseAttr.getSplatValue<Attribute>());
+    auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+                                            Value tileSliceIndex,
+                                            Value currentTile) {
+      // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+      // slice.
+      auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+          loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+      b.create<scf::YieldOp>(loc, nextTile.getResult());
+      return;
+    };
+    auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
+                                                         initTile, loopBody);
+    rewriter.replaceOp(constantOp, forOp.getResult(0));
+
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithToArmSMEConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ArithToArmSMEConversionPass final
+    : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
+  using impl::ArithToArmSMEConversionPassBase<
+      ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    arith::populateArithToArmSMEConversionPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
new file mode 100644
index 00000000000000..c2a6fe5398e7c8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRArithToArmSME
+  ArithToArmSME.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmSMEDialect
+  MLIRArithDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c3a2481975040c..3a5dbc12c23f5c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToArmSME)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 87d1bf9bed5a31..88252725bcff26 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -16,39 +16,6 @@
 
 using namespace mlir;
 
-/// Returns true if 'val' is a splat of zero, false otherwise.
-static bool isSplatZero(Type elemType, DenseElementsAttr val) {
-  if (llvm::isa<FloatType>(elemType))
-    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
-  if (llvm::isa<IntegerType>(elemType))
-    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
-  return false;
-}
-
-/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via the callback, which returns the next tile value.
-template <typename LoopBodyCallback>
-static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
-                                           Location loc, Value initTile,
-                                           LoopBodyCallback callback) {
-  OpBuilder::InsertionGuard g(rewriter);
-  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-      loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
-  auto vscale =
-      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto numTileSlices =
-      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
-                                           ValueRange{initTile});
-  rewriter.setInsertionPointToStart(forOp.getBody());
-  auto nextTile = callback(forOp);
-  rewriter.create<scf::YieldOp>(loc, nextTile.getResult());
-  return forOp;
-}
-
 namespace {
 
 /// Conversion pattern for vector.transfer_read.
@@ -223,56 +190,6 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
-/// Conversion pattern for dense arith.constant.
-struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
-  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
-                                PatternRewriter &rewriter) const final {
-    auto tileType = dyn_cast<VectorType>(constantOp.getType());
-    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
-      return failure();
-
-    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
-    if (!denseAttr || !denseAttr.isSplat())
-      return failure();
-
-    auto tileElementType = tileType.getElementType();
-
-    // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
-    if (isSplatZero(tileElementType, denseAttr)) {
-      rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
-      return success();
-    }
-
-    // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
-    // ops that broadcast the constant to each tile slice.
-    auto loc = constantOp.getLoc();
-
-    // To fill a tile with a constant, we create a 1-D splat of the constant,
-    // then move that into each tile slice (the largest unit we can set at once,
-    // outside of operations like the outerproduct).
-    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    auto denseAttr1D = DenseElementsAttr::get(
-        tileSliceType, denseAttr.getSplatValue<Attribute>());
-    auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
-
-    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
-          auto tileSliceIndex = forOp.getInductionVar();
-          auto currentTile = forOp.getRegionIterArg(0);
-          // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
-          // slice.
-          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-              loc, tileType, constantOp1D, currentTile, tileSliceIndex);
-        });
-    rewriter.replaceOp(constantOp, forOp.getResult(0));
-
-    return success();
-  }
-};
-
 /// Conversion pattern for vector.broadcast.
 ///
 /// Example:
@@ -322,16 +239,19 @@ struct BroadcastOpToArmSMELowering
 
     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
+    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+                                            Value tileSliceIndex,
+                                            Value currentTile) {
+      // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+      // to each tile slice.
+      auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+      b.create<scf::YieldOp>(loc, nextTile.getResult());
+      return;
+    };
+
     // Create a loop over ZA tile slices.
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
-          auto tileSliceIndex = forOp.getInductionVar();
-          auto currentTile = forOp.getRegionIterArg(0);
-          // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
-          // to each tile slice.
-          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-        });
+    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
 
     rewriter.replaceOp(broadcastOp, forOp.getResult(0));
 
@@ -381,15 +301,18 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
 
     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
+    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+                                            Value tileSliceIndex,
+                                            Value currentTile) {
+      auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+      b.create<scf::YieldOp>(loc, nextTile.getResult());
+      return;
+    };
+
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
-          auto tileSliceIndex = forOp.getInductionVar();
-          auto currentTile = forOp.getRegionIterArg(0);
-          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-        });
+    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
 
     rewriter.replaceOp(splatOp, forOp.getResult(0));
 
@@ -741,11 +664,10 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns
-      .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-           SplatOpToArmSMELowering, TransferReadToArmSMELowering,
-           TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-           VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
-           VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
-           VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
+  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+               VectorPrintToArmSMELowering>(&ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1fa060cafc0bc6..2e159abb1e89eb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,4 +72,24 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
   return success();
 }
 
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+                                    Value initTile,
+                                    LoopBodyBuilder bodyBuilder) {
+  OpBuilder::InsertionGuard g(rewriter);
+  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+      loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
+  auto vscale =
+      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto numTileSlices =
+      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+                                           ValueRange{initTile});
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  bodyBuilder(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
+              /*currentTile=*/forOp.getRegionIterArg(0));
+  return forOp;
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
similarity index 97%
rename from mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
rename to mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
index e51f2485dadbcc..49d2e2f3c182b9 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -convert-arith-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // =============================================================================
 // arith.constant dense<0> to arm_sme.zero
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index ce5bfd25cbdbcc..17a070999c20a0 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arith-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 6314e6f279952b..44ff1afe76d383 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -4,7 +4,8 @@
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// RUN:   -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index dd9f280cb75099..42fe21cccd48a7 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s \
-// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles  \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// RUN:   -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops"  \
 // RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-arm-sme-to-llvm -convert-vector-to-llvm=enable-arm-sve -cse \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 8c73c24d695cfb..5f41b37560e760 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,7 +1,8 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 965337c60b9ffd..a1bb9b7d6f80ec 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,7 +1,8 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles  \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index cb30fee4e12d72..c0c1f55d7ddd1a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,6 +1,7 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -test-lower-to-llvm
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index b45f24f6c8fdda..223bc8ce74343b 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// RUN:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// RUN:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 073c08bff1c415..f28bf19b299934 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,6 +1,7 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \

>From 716fcc1cea6b02e66c0561de7afc8e4e358c86f8 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 17 Jan 2024 16:57:10 +0000
Subject: [PATCH 2/2] address comments

---
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h | 11 ++++------
 .../ArithToArmSME/ArithToArmSME.cpp           | 10 ++++------
 .../VectorToArmSME/VectorToArmSME.cpp         | 20 ++++++++-----------
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          | 12 ++++++-----
 4 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index a15eac7302077b..e37581ce00f03c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -49,15 +49,12 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
 /// Verifies the tile ID (if set) on this tile operation is valid.
 LogicalResult verifyOperationHasValidTileId(Operation *);
 
-using LoopBodyBuilder =
-    std::function<void(OpBuilder &, Location, Value, Value)>;
-
 /// Generates a for loop over ZA tile slices where the induction variable is
 /// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via LoopBodyBuilder, which returns the next tile value.
-scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
-                                    Value initTile,
-                                    LoopBodyBuilder bodyBuilder);
+/// built via the callback, which returns the next tile value.
+scf::ForOp createLoopOverTileSlices(
+    PatternRewriter &rewriter, Location loc, Value initTile,
+    std::function<Value(OpBuilder &, Location, Value, Value)> callback);
 
 } // namespace mlir::arm_sme
 
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
index 9aab969881f75e..2f562ba3e1ce00 100644
--- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -77,18 +77,16 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
-                                            Value tileSliceIndex,
-                                            Value currentTile) {
+    auto callback = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
+                        Value currentTile) {
       // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
       // slice.
       auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
           loc, tileType, constantOp1D, currentTile, tileSliceIndex);
-      b.create<scf::YieldOp>(loc, nextTile.getResult());
-      return;
+      return nextTile.getResult();
     };
     auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
-                                                         initTile, loopBody);
+                                                         initTile, callback);
     rewriter.replaceOp(constantOp, forOp.getResult(0));
 
     return success();
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 88252725bcff26..0d1c092b2079e3 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -239,19 +239,17 @@ struct BroadcastOpToArmSMELowering
 
     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
-    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
-                                            Value tileSliceIndex,
-                                            Value currentTile) {
+    auto callback = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
+                        Value currentTile) {
       // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
       // to each tile slice.
       auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
           loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-      b.create<scf::YieldOp>(loc, nextTile.getResult());
-      return;
+      return nextTile.getResult();
     };
 
     // Create a loop over ZA tile slices.
-    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
+    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, callback);
 
     rewriter.replaceOp(broadcastOp, forOp.getResult(0));
 
@@ -301,18 +299,16 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
 
     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
-    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
-                                            Value tileSliceIndex,
-                                            Value currentTile) {
+    auto callback = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
+                        Value currentTile) {
       auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
           loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-      b.create<scf::YieldOp>(loc, nextTile.getResult());
-      return;
+      return nextTile.getResult();
     };
 
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
-    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
+    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, callback);
 
     rewriter.replaceOp(splatOp, forOp.getResult(0));
 
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 2e159abb1e89eb..916691a1c7b9bc 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,9 +72,9 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
   return success();
 }
 
-scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
-                                    Value initTile,
-                                    LoopBodyBuilder bodyBuilder) {
+scf::ForOp createLoopOverTileSlices(
+    PatternRewriter &rewriter, Location loc, Value initTile,
+    std::function<Value(OpBuilder &, Location, Value, Value)> callback) {
   OpBuilder::InsertionGuard g(rewriter);
   auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
@@ -87,8 +87,10 @@ scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
   auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
                                            ValueRange{initTile});
   rewriter.setInsertionPointToStart(forOp.getBody());
-  bodyBuilder(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
-              /*currentTile=*/forOp.getRegionIterArg(0));
+  Value nextTile =
+      callback(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
+               /*currentTile=*/forOp.getRegionIterArg(0));
+  rewriter.create<scf::YieldOp>(loc, nextTile);
   return forOp;
 }
 



More information about the Mlir-commits mailing list