[Mlir-commits] [mlir] [mlir][ArmSME] Add arith-to-arm-sme conversion pass (PR #78197)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 15 10:11:28 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-linalg
Author: Cullen Rhodes (c-rhodes)
<details>
<summary>Changes</summary>
The existing 'arith::ConstantOp' conversion and tests are moved from VectorToArmSME. Only a single op is converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME utils since it's relevant for both conversions.
---
Patch is 27.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78197.diff
18 Files Affected:
- (added) mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h (+27)
- (modified) mlir/include/mlir/Conversion/Passes.h (+1)
- (modified) mlir/include/mlir/Conversion/Passes.td (+9)
- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+17)
- (added) mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp (+127)
- (added) mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt (+18)
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+28-106)
- (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+20)
- (renamed) mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir (+1-1)
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+1-1)
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir (+2-1)
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir (+2-2)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+2-1)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+2-1)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir (+2-1)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir (+2-1)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir (+2-1)
``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
new file mode 100644
index 00000000000000..012e7fb5b0af2f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
@@ -0,0 +1,27 @@
+//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a25fd17ea923fb..0bfc5064c5dd72 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 71be8841ca7c03..3467e042c493e9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -164,6 +164,15 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToArmSME
+//===----------------------------------------------------------------------===//
+
+def ArithToArmSMEConversionPass : Pass<"convert-arith-to-arm-sme"> {
+ let summary = "Convert Arith dialect to ArmSME dialect";
+ let dependentDialects = ["arm_sme::ArmSMEDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArmNeon2dToIntr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index b7d90195d49d76..a15eac7302077b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,9 +16,16 @@
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>
+namespace mlir {
+class Location;
+class PatternRewriter;
+class Value;
+} // namespace mlir
+
namespace mlir::arm_sme {
constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -42,6 +49,16 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
/// Verifies the tile ID (if set) on this tile operation is valid.
LogicalResult verifyOperationHasValidTileId(Operation *);
+using LoopBodyBuilder =
+ std::function<void(OpBuilder &, Location, Value, Value)>;
+
+/// Generates a for loop over ZA tile slices where the induction variable is
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via LoopBodyBuilder, which must yield the next tile value.
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+ Value initTile,
+ LoopBodyBuilder bodyBuilder);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
new file mode 100644
index 00000000000000..9aab969881f75e
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -0,0 +1,127 @@
+//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "arith-to-arm-sme"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+ if (llvm::isa<FloatType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+ if (llvm::isa<IntegerType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+ return false;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern for dense arith.constant.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+ using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+ PatternRewriter &rewriter) const final {
+ auto tileType = dyn_cast<VectorType>(constantOp.getType());
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+ return failure();
+
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+ if (!denseAttr || !denseAttr.isSplat())
+ return failure();
+
+ auto tileElementType = tileType.getElementType();
+
+ // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+ if (isSplatZero(tileElementType, denseAttr)) {
+ rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+ return success();
+ }
+
+ // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+ // ops that broadcast the constant to each tile slice.
+ auto loc = constantOp.getLoc();
+
+ // To fill a tile with a constant, we create a 1-D splat of the constant,
+ // then move that into each tile slice (the largest unit we can set at once,
+ // outside of operations like the outerproduct).
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+ auto denseAttr1D = DenseElementsAttr::get(
+ tileSliceType, denseAttr.getSplatValue<Attribute>());
+ auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+ // slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+ auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
+ initTile, loopBody);
+ rewriter.replaceOp(constantOp, forOp.getResult(0));
+
+ return success();
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithToArmSMEConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ArithToArmSMEConversionPass final
+ : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
+ using impl::ArithToArmSMEConversionPassBase<
+ ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ arith::populateArithToArmSMEConversionPatterns(patterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
new file mode 100644
index 00000000000000..c2a6fe5398e7c8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRArithToArmSME
+ ArithToArmSME.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
+ MLIRArithDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c3a2481975040c..3a5dbc12c23f5c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 87d1bf9bed5a31..88252725bcff26 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -16,39 +16,6 @@
using namespace mlir;
-/// Returns true if 'val' is a splat of zero, false otherwise.
-static bool isSplatZero(Type elemType, DenseElementsAttr val) {
- if (llvm::isa<FloatType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
- if (llvm::isa<IntegerType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
- return false;
-}
-
-/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via the callback, which returns the next tile value.
-template <typename LoopBodyCallback>
-static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
- Location loc, Value initTile,
- LoopBodyCallback callback) {
- OpBuilder::InsertionGuard g(rewriter);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
- ValueRange{initTile});
- rewriter.setInsertionPointToStart(forOp.getBody());
- auto nextTile = callback(forOp);
- rewriter.create<scf::YieldOp>(loc, nextTile.getResult());
- return forOp;
-}
-
namespace {
/// Conversion pattern for vector.transfer_read.
@@ -223,56 +190,6 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
}
};
-/// Conversion pattern for dense arith.constant.
-struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
- using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
- PatternRewriter &rewriter) const final {
- auto tileType = dyn_cast<VectorType>(constantOp.getType());
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
- if (!denseAttr || !denseAttr.isSplat())
- return failure();
-
- auto tileElementType = tileType.getElementType();
-
- // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
- if (isSplatZero(tileElementType, denseAttr)) {
- rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
- return success();
- }
-
- // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
- // ops that broadcast the constant to each tile slice.
- auto loc = constantOp.getLoc();
-
- // To fill a tile with a constant, we create a 1-D splat of the constant,
- // then move that into each tile slice (the largest unit we can set at once,
- // outside of operations like the outerproduct).
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto denseAttr1D = DenseElementsAttr::get(
- tileSliceType, denseAttr.getSplatValue<Attribute>());
- auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
-
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
- // slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, constantOp1D, currentTile, tileSliceIndex);
- });
- rewriter.replaceOp(constantOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.broadcast.
///
/// Example:
@@ -322,16 +239,19 @@ struct BroadcastOpToArmSMELowering
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+ // to each tile slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Create a loop over ZA tile slices.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
- // to each tile slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(broadcastOp, forOp.getResult(0));
@@ -381,15 +301,18 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(splatOp, forOp.getResult(0));
@@ -741,11 +664,10 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns
- .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
- SplatOpToArmSMELowering, TransferReadToArmSMELowering,
- TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
- VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
- VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
+ patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+ VectorPrintToArmSMELowering>(&ctx);
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1fa060cafc0bc6..2e159abb1e89eb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,4 +72,24 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
return success();
}
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+ Value initTile,
+ LoopBodyBuilder bodyBuilder) {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+ ValueRange{initTile});
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ bodyBuilder(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
+ /*currentTile=*/forOp.getRegionIterArg(0));
+ return forOp;
+}
+
} // namespace mlir::arm_sme
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
similarity index 97%
rename from mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
rename to mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
index e51f2485dadbcc..49d2e2f3c182b9 100644
--- a...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/78197
More information about the Mlir-commits
mailing list