[Mlir-commits] [mlir] [mlir][ArmSME] Add arith-to-arm-sme conversion pass (PR #78197)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Jan 17 09:02:25 PST 2024
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/78197
>From c07b16830925bd86d2440423a2fd409c26799854 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 15 Jan 2024 09:31:43 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add arith-to-arm-sme conversion pass
Existing 'arith::ConstantOp' conversion and tests are moved from
VectorToArmSME. Only a single op is converted at the moment, but this
will grow in the future as things like in-tile add
are implemented. Also, 'createLoopOverTileSlices' is moved to ArmSME
utils since it's relevant for both conversions.
---
.../Conversion/ArithToArmSME/ArithToArmSME.h | 27 ++++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 9 ++
.../include/mlir/Dialect/ArmSME/Utils/Utils.h | 17 +++
.../ArithToArmSME/ArithToArmSME.cpp | 127 +++++++++++++++++
.../Conversion/ArithToArmSME/CMakeLists.txt | 18 +++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../VectorToArmSME/VectorToArmSME.cpp | 134 ++++--------------
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 20 +++
.../ArithToArmSME/arith-to-arm-sme.mlir} | 2 +-
.../Dialect/ArmSME/vector-ops-to-llvm.mlir | 2 +-
.../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir | 3 +-
.../Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 4 +-
.../CPU/ArmSME/test-outerproduct-f32.mlir | 3 +-
.../CPU/ArmSME/test-outerproduct-f64.mlir | 3 +-
.../CPU/ArmSME/test-transfer-write-2d.mlir | 3 +-
.../Dialect/Vector/CPU/ArmSME/tile_fill.mlir | 3 +-
.../Dialect/Vector/CPU/ArmSME/vector-ops.mlir | 3 +-
18 files changed, 264 insertions(+), 116 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
create mode 100644 mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
create mode 100644 mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
rename mlir/test/{Dialect/ArmSME/arith-ops-to-sme.mlir => Conversion/ArithToArmSME/arith-to-arm-sme.mlir} (97%)
diff --git a/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
new file mode 100644
index 00000000000000..012e7fb5b0af2f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToArmSME/ArithToArmSME.h
@@ -0,0 +1,27 @@
+//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a25fd17ea923fb..0bfc5064c5dd72 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 71be8841ca7c03..3467e042c493e9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -164,6 +164,15 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToArmSME
+//===----------------------------------------------------------------------===//
+
+def ArithToArmSMEConversionPass : Pass<"convert-arith-to-arm-sme"> {
+ let summary = "Convert Arith dialect to ArmSME dialect";
+ let dependentDialects = ["arm_sme::ArmSMEDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArmNeon2dToIntr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index b7d90195d49d76..a15eac7302077b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,9 +16,16 @@
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>
+namespace mlir {
+class Location;
+class PatternRewriter;
+class Value;
+} // namespace mlir
+
namespace mlir::arm_sme {
constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -42,6 +49,16 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
/// Verifies the tile ID (if set) on this tile operation is valid.
LogicalResult verifyOperationHasValidTileId(Operation *);
+using LoopBodyBuilder =
+ std::function<void(OpBuilder &, Location, Value, Value)>;
+
+/// Generates a for loop over ZA tile slices where the induction variable is
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via LoopBodyBuilder, which returns the next tile value.
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+ Value initTile,
+ LoopBodyBuilder bodyBuilder);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
new file mode 100644
index 00000000000000..9aab969881f75e
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -0,0 +1,127 @@
+//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "arith-to-arm-sme"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+ if (llvm::isa<FloatType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+ if (llvm::isa<IntegerType>(elemType))
+ return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+ return false;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern for dense arith.constant.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+ using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+ PatternRewriter &rewriter) const final {
+ auto tileType = dyn_cast<VectorType>(constantOp.getType());
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+ return failure();
+
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+ if (!denseAttr || !denseAttr.isSplat())
+ return failure();
+
+ auto tileElementType = tileType.getElementType();
+
+ // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+ if (isSplatZero(tileElementType, denseAttr)) {
+ rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+ return success();
+ }
+
+ // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+ // ops that broadcast the constant to each tile slice.
+ auto loc = constantOp.getLoc();
+
+ // To fill a tile with a constant, we create a 1-D splat of the constant,
+ // then move that into each tile slice (the largest unit we can set at once,
+ // outside of operations like the outerproduct).
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+ auto denseAttr1D = DenseElementsAttr::get(
+ tileSliceType, denseAttr.getSplatValue<Attribute>());
+ auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+ // slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+ auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
+ initTile, loopBody);
+ rewriter.replaceOp(constantOp, forOp.getResult(0));
+
+ return success();
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithToArmSMEConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ArithToArmSMEConversionPass final
+ : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
+ using impl::ArithToArmSMEConversionPassBase<
+ ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ arith::populateArithToArmSMEConversionPatterns(patterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
new file mode 100644
index 00000000000000..c2a6fe5398e7c8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToArmSME/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRArithToArmSME
+ ArithToArmSME.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
+ MLIRArithDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c3a2481975040c..3a5dbc12c23f5c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 87d1bf9bed5a31..88252725bcff26 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -16,39 +16,6 @@
using namespace mlir;
-/// Returns true if 'val' is a splat of zero, false otherwise.
-static bool isSplatZero(Type elemType, DenseElementsAttr val) {
- if (llvm::isa<FloatType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
- if (llvm::isa<IntegerType>(elemType))
- return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
- return false;
-}
-
-/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via the callback, which returns the next tile value.
-template <typename LoopBodyCallback>
-static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
- Location loc, Value initTile,
- LoopBodyCallback callback) {
- OpBuilder::InsertionGuard g(rewriter);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
- ValueRange{initTile});
- rewriter.setInsertionPointToStart(forOp.getBody());
- auto nextTile = callback(forOp);
- rewriter.create<scf::YieldOp>(loc, nextTile.getResult());
- return forOp;
-}
-
namespace {
/// Conversion pattern for vector.transfer_read.
@@ -223,56 +190,6 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
}
};
-/// Conversion pattern for dense arith.constant.
-struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
- using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
- PatternRewriter &rewriter) const final {
- auto tileType = dyn_cast<VectorType>(constantOp.getType());
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
- if (!denseAttr || !denseAttr.isSplat())
- return failure();
-
- auto tileElementType = tileType.getElementType();
-
- // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
- if (isSplatZero(tileElementType, denseAttr)) {
- rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
- return success();
- }
-
- // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
- // ops that broadcast the constant to each tile slice.
- auto loc = constantOp.getLoc();
-
- // To fill a tile with a constant, we create a 1-D splat of the constant,
- // then move that into each tile slice (the largest unit we can set at once,
- // outside of operations like the outerproduct).
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto denseAttr1D = DenseElementsAttr::get(
- tileSliceType, denseAttr.getSplatValue<Attribute>());
- auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
-
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
- // slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, constantOp1D, currentTile, tileSliceIndex);
- });
- rewriter.replaceOp(constantOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.broadcast.
///
/// Example:
@@ -322,16 +239,19 @@ struct BroadcastOpToArmSMELowering
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+ // to each tile slice.
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Create a loop over ZA tile slices.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
- // to each tile slice.
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(broadcastOp, forOp.getResult(0));
@@ -381,15 +301,18 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
+ Value tileSliceIndex,
+ Value currentTile) {
+ auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ b.create<scf::YieldOp>(loc, nextTile.getResult());
+ return;
+ };
+
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- });
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
rewriter.replaceOp(splatOp, forOp.getResult(0));
@@ -741,11 +664,10 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns
- .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
- SplatOpToArmSMELowering, TransferReadToArmSMELowering,
- TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
- VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
- VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
+ patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+ VectorPrintToArmSMELowering>(&ctx);
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1fa060cafc0bc6..2e159abb1e89eb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,4 +72,24 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
return success();
}
+scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+ Value initTile,
+ LoopBodyBuilder bodyBuilder) {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+ ValueRange{initTile});
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ bodyBuilder(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
+ /*currentTile=*/forOp.getRegionIterArg(0));
+ return forOp;
+}
+
} // namespace mlir::arm_sme
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
similarity index 97%
rename from mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
rename to mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
index e51f2485dadbcc..49d2e2f3c182b9 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Conversion/ArithToArmSME/arith-to-arm-sme.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -convert-arith-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
// =============================================================================
// arith.constant dense<0> to arm_sme.zero
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index ce5bfd25cbdbcc..17a070999c20a0 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arith-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
//===----------------------------------------------------------------------===//
// vector.transfer_write
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 6314e6f279952b..44ff1afe76d383 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -4,7 +4,8 @@
// RUN: -lower-vector-mask \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index dd9f280cb75099..42fe21cccd48a7 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s \
-// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles \
-// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
// RUN: -convert-arm-sme-to-llvm -convert-vector-to-llvm=enable-arm-sve -cse \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 8c73c24d695cfb..5f41b37560e760 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,7 +1,8 @@
// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 965337c60b9ffd..a1bb9b7d6f80ec 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,7 +1,8 @@
// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index cb30fee4e12d72..c0c1f55d7ddd1a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,6 +1,7 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -test-lower-to-llvm
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index b45f24f6c8fdda..223bc8ce74343b 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// RUN: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 073c08bff1c415..f28bf19b299934 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,6 +1,7 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
>From 716fcc1cea6b02e66c0561de7afc8e4e358c86f8 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 17 Jan 2024 16:57:10 +0000
Subject: [PATCH 2/2] address comments
---
.../include/mlir/Dialect/ArmSME/Utils/Utils.h | 11 ++++------
.../ArithToArmSME/ArithToArmSME.cpp | 10 ++++------
.../VectorToArmSME/VectorToArmSME.cpp | 20 ++++++++-----------
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 12 ++++++-----
4 files changed, 23 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index a15eac7302077b..e37581ce00f03c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -49,15 +49,12 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
/// Verifies the tile ID (if set) on this tile operation is valid.
LogicalResult verifyOperationHasValidTileId(Operation *);
-using LoopBodyBuilder =
- std::function<void(OpBuilder &, Location, Value, Value)>;
-
/// Generates a for loop over ZA tile slices where the induction variable is
/// the tile slice index and each iteration yields a new tile. Loop body is
-/// built via LoopBodyBuilder, which returns the next tile value.
-scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
- Value initTile,
- LoopBodyBuilder bodyBuilder);
+/// built via the callback, which returns the next tile value.
+scf::ForOp createLoopOverTileSlices(
+ PatternRewriter &rewriter, Location loc, Value initTile,
+ std::function<Value(OpBuilder &, Location, Value, Value)> callback);
} // namespace mlir::arm_sme
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
index 9aab969881f75e..2f562ba3e1ce00 100644
--- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -77,18 +77,16 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
- arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
- Value tileSliceIndex,
- Value currentTile) {
+ auto callback = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
+ Value currentTile) {
// Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
// slice.
auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
loc, tileType, constantOp1D, currentTile, tileSliceIndex);
- b.create<scf::YieldOp>(loc, nextTile.getResult());
- return;
+ return nextTile.getResult();
};
auto forOp = mlir::arm_sme::createLoopOverTileSlices(rewriter, loc,
- initTile, loopBody);
+ initTile, callback);
rewriter.replaceOp(constantOp, forOp.getResult(0));
return success();
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 88252725bcff26..0d1c092b2079e3 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -239,19 +239,17 @@ struct BroadcastOpToArmSMELowering
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
- arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
- Value tileSliceIndex,
- Value currentTile) {
+ auto callback = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
+ Value currentTile) {
// Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
// to each tile slice.
auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- b.create<scf::YieldOp>(loc, nextTile.getResult());
- return;
+ return nextTile.getResult();
};
// Create a loop over ZA tile slices.
- auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, callback);
rewriter.replaceOp(broadcastOp, forOp.getResult(0));
@@ -301,18 +299,16 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
- arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
- Value tileSliceIndex,
- Value currentTile) {
+ auto callback = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
+ Value currentTile) {
auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- b.create<scf::YieldOp>(loc, nextTile.getResult());
- return;
+ return nextTile.getResult();
};
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
- auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, loopBody);
+ auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, callback);
rewriter.replaceOp(splatOp, forOp.getResult(0));
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 2e159abb1e89eb..916691a1c7b9bc 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -72,9 +72,9 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
return success();
}
-scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
- Value initTile,
- LoopBodyBuilder bodyBuilder) {
+scf::ForOp createLoopOverTileSlices(
+ PatternRewriter &rewriter, Location loc, Value initTile,
+ std::function<Value(OpBuilder &, Location, Value, Value)> callback) {
OpBuilder::InsertionGuard g(rewriter);
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
@@ -87,8 +87,10 @@ scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
- bodyBuilder(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
- /*currentTile=*/forOp.getRegionIterArg(0));
+ Value nextTile =
+ callback(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
+ /*currentTile=*/forOp.getRegionIterArg(0));
+ rewriter.create<scf::YieldOp>(loc, nextTile);
return forOp;
}
More information about the Mlir-commits
mailing list