[Mlir-commits] [mlir] [mlir][ArmSME] Add arith-to-arm-sme conversion pass (PR #78197)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 15 10:11:29 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sme

Author: Cullen Rhodes (c-rhodes)

<details>
<summary>Changes</summary>

The existing 'arith::ConstantOp' conversion and tests are moved from VectorToArmSME. Only a single op is converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME utils since it's relevant for both conversions.

---

Patch is 27.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78197.diff


18 Files Affected:

- (added) mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h (+27) 
- (modified) mlir/include/mlir/Conversion/Passes.h (+1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+9) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+17) 
- (added) mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp (+127) 
- (added) mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt (+18) 
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+28-106) 
- (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+20) 
- (renamed) mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir (+1-1) 
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir (+2-1) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+2-1) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+2-1) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir (+2-1) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir (+2-1) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir (+2-1) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
new file mode 100644
index 00000000000000..012e7fb5b0af2f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
@@ -0,0 +1,27 @@
+//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a25fd17ea923fb..0bfc5064c5dd72 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 71be8841ca7c03..3467e042c493e9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -164,6 +164,15 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToArmSME
+//===----------------------------------------------------------------------===//
+
+def ArithToArmSMEConversionPass : Pass<"convert-arith-to-arm-sme"> {
+  let summary = "Convert Arith dialect to ArmSME dialect";
+  let dependentDialects = ["arm_sme::ArmSMEDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmNeon2dToIntr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index b7d90195d49d76..a15eac7302077b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,9 +16,16 @@
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include <optional>
 
+namespace mlir {
+class Location;
+class PatternRewriter;
+class Value;
+} // namespace mlir
+
 namespace mlir::arm_sme {
 
 constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -42,6 +49,16 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
 /// Verifies the tile ID (if set) on this tile operation is valid.
 LogicalResult verifyOperationHasValidTileId(Operation *);
 
+using LoopBodyBuilder =
+    std::function<void(OpBuilder &, Location, Value, Value)>;
+
+/// Generates a for loop over ZA tile slices where the induction variable is
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via LoopBodyBuilder, which must yield the next tile value.
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+                                    Value initTile,
+                                    LoopBodyBuilder bodyBuilder);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
new file mode 100644
index 00000000000000..9aab969881f75e
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -0,0 +1,127 @@
+//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "arith-to-arm-sme"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+  if (llvm::isa<FloatType>(elemType))
+    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+  if (llvm::isa<IntegerType>(elemType))
+    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+  return false;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern for dense arith.constant.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = dyn_cast<VectorType>(constantOp.getType());
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!denseAttr || !denseAttr.isSplat())
+      return failure();
+
+    auto tileElementType = tileType.getElementType();
+
+    // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+    if (isSplatZero(tileElementType, denseAttr)) {
+      rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+      return success();
+    }
+
+    // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+    // ops that broadcast the constant to each tile slice.
+    auto loc = constantOp.getLoc();
+
+    // To fill a tile with a constant, we create a 1-D splat of the constant,
+    // then move that into each tile slice (the largest unit we can set at once,
+    // outside of operations like the outerproduct).
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+    auto denseAttr1D = DenseElementsAttr::get(
+        tileSliceType, denseAttr.getSplatValue<Attribute>());
+    auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+                                            Value tileSliceIndex,
+                                            Value currentTile) {
+      // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+      // slice.
+      auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+          loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+      b.create<scf::YieldOp>(loc, nextTile.getResult());
+      return;
+    };
+    auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
+                                                         initTile, loopBody);
+    rewriter.replaceOp(constantOp, forOp.getResult(0));
+
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithToArmSMEConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ArithToArmSMEConversionPass final
+    : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
+  using impl::ArithToArmSMEConversionPassBase<
+      ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    arith::populateArithToArmSMEConversionPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
new file mode 100644
index 00000000000000..c2a6fe5398e7c8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRArithToArmSME
+  ArithToArmSME.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmSMEDialect
+  MLIRArithDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c3a2481975040c..3a5dbc12c23f5c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToArmSME)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 87d1bf9bed5a31..88252725bcff26 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -16,39 +16,6 @@
 
 using namespace mlir;
 
-/// Returns true if 'val' is a splat of zero, false otherwise.
-static bool isSplatZero(Type elemType, DenseElementsAttr val) {
-  if (llvm::isa<FloatType>(elemType))
-    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
-  if (llvm::isa<IntegerType>(elemType))
-    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
-  return false;
-}
-
-/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via the callback, which returns the next tile value.
-template <typename LoopBodyCallback>
-static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
-                                           Location loc, Value initTile,
-                                           LoopBodyCallback callback) {
-  OpBuilder::InsertionGuard g(rewriter);
-  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-      loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
-  auto vscale =
-      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto numTileSlices =
-      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
-                                           ValueRange{initTile});
-  rewriter.setInsertionPointToStart(forOp.getBody());
-  auto nextTile = callback(forOp);
-  rewriter.create<scf::YieldOp>(loc, nextTile.getResult());
-  return forOp;
-}
-
 namespace {
 
 /// Conversion pattern for vector.transfer_read.
@@ -223,56 +190,6 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
-/// Conversion pattern for dense arith.constant.
-struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
-  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
-                                PatternRewriter &rewriter) const final {
-    auto tileType = dyn_cast<VectorType>(constantOp.getType());
-    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
-      return failure();
-
-    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
-    if (!denseAttr || !denseAttr.isSplat())
-      return failure();
-
-    auto tileElementType = tileType.getElementType();
-
-    // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
-    if (isSplatZero(tileElementType, denseAttr)) {
-      rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
-      return success();
-    }
-
-    // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
-    // ops that broadcast the constant to each tile slice.
-    auto loc = constantOp.getLoc();
-
-    // To fill a tile with a constant, we create a 1-D splat of the constant,
-    // then move that into each tile slice (the largest unit we can set at once,
-    // outside of operations like the outerproduct).
-    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    auto denseAttr1D = DenseElementsAttr::get(
-        tileSliceType, denseAttr.getSplatValue<Attribute>());
-    auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
-
-    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
-          auto tileSliceIndex = forOp.getInductionVar();
-          auto currentTile = forOp.getRegionIterArg(0);
-          // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
-          // slice.
-          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-              loc, tileType, constantOp1D, currentTile, tileSliceIndex);
-        });
-    rewriter.replaceOp(constantOp, forOp.getResult(0));
-
-    return success();
-  }
-};
-
 /// Conversion pattern for vector.broadcast.
 ///
 /// Example:
@@ -322,16 +239,19 @@ struct BroadcastOpToArmSMELowering
 
     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
+    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+                                            Value tileSliceIndex,
+                                            Value currentTile) {
+      // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+      // to each tile slice.
+      auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+      b.create<scf::YieldOp>(loc, nextTile.getResult());
+      return;
+    };
+
     // Create a loop over ZA tile slices.
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
-          auto tileSliceIndex = forOp.getInductionVar();
-          auto currentTile = forOp.getRegionIterArg(0);
-          // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
-          // to each tile slice.
-          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-        });
+    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
 
     rewriter.replaceOp(broadcastOp, forOp.getResult(0));
 
@@ -381,15 +301,18 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
 
     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
+    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+                                            Value tileSliceIndex,
+                                            Value currentTile) {
+      auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+      b.create<scf::YieldOp>(loc, nextTile.getResult());
+      return;
+    };
+
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
-          auto tileSliceIndex = forOp.getInductionVar();
-          auto currentTile = forOp.getRegionIterArg(0);
-          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-        });
+    auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
 
     rewriter.replaceOp(splatOp, forOp.getResult(0));
 
@@ -741,11 +664,10 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns
-      .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-           SplatOpToArmSMELowering, TransferReadToArmSMELowering,
-           TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-           VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
-           VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
-           VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
+  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+               VectorPrintToArmSMELowering>(&ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1fa060cafc0bc6..2e159abb1e89eb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,4 +72,24 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
   return success();
 }
 
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+                                    Value initTile,
+                                    LoopBodyBuilder bodyBuilder) {
+  OpBuilder::InsertionGuard g(rewriter);
+  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+      loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
+  auto vscale =
+      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto numTileSlices =
+      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+                                           ValueRange{initTile});
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  bodyBuilder(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
+              /*currentTile=*/forOp.getRegionIterArg(0));
+  return forOp;
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
similarity index 97%
rename from mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
rename to mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
index e51f2485dadbcc..49d2e2f3c182b9 100644
--- a...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/78197


More information about the Mlir-commits mailing list