[Mlir-commits] [mlir] [mlir][spirv] Implement vector type legalization in function signatures (PR #98337)
Angel Zhang
llvmlistbot at llvm.org
Wed Jul 10 12:54:23 PDT 2024
https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/98337
>From b33d3726d916dc03927d01616db195f292ccf410 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 27 Jun 2024 15:44:27 +0000
Subject: [PATCH 01/14] [mlir][spirv] Implement vector type legalization in
function signatures
---
mlir/include/mlir/Conversion/Passes.td | 5 +-
.../Vector/Transforms/VectorRewritePatterns.h | 4 +
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 145 ++++++++++++-
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Vector/Transforms/VectorUnroll.cpp | 201 ++++++++++++++++++
5 files changed, 353 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..8d83343f5b736 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -40,7 +40,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let description = [{
This is a generic pass to convert to SPIR-V.
}];
- let dependentDialects = ["spirv::SPIRVDialect"];
+ let dependentDialects = [
+ "spirv::SPIRVDialect",
+ "vector::VectorDialect",
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 8e6d36f0b5f09..5c06d6d4d6ad3 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -293,6 +293,10 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
+/// Collect a pattern that rewrites `func.func` ops so that each vector-typed
+/// input whose shape can be evenly tiled by the native shape from `options`
+/// is replaced by several smaller vector inputs; the original vector value is
+/// rebuilt inside the body with `vector.insert_strided_slice`. Function
+/// results are not unrolled yet.
+void populateVectorUnrollFuncSignaturePatterns(RewritePatternSet &patterns,
+ const UnrollVectorOptions &options,
+ PatternBenefit benefit = 1);
+
/// Collect a set of vector.shape_cast folding patterns.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b5be4654bcb25..54152c5be26fa 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -34,6 +34,66 @@ namespace mlir {
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Vector Lowering
+//===----------------------------------------------------------------------===//
+
+/// Returns the largest of 4, 3 and 2 that evenly divides `size`, or 1 if none
+/// does. Used to pick the innermost unroll size for native vector shapes.
+/// File-local (static) so it cannot collide with same-named helpers in other
+/// libraries linked alongside MLIR.
+static int getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+/// Native shape for vector.multi_reduction: the source vector shape with
+/// every reduction dimension unrolled to size 1.
+static SmallVector<int64_t>
+getNativeVectorShapeImpl(vector::MultiDimReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  auto nativeSize = llvm::to_vector(srcVectorType.getShape());
+  for (int64_t dim : op.getReductionDims())
+    nativeSize[dim] = 1;
+  return nativeSize;
+}
+
+/// Native shape for vector.reduction: the 1-D source vector is unrolled in
+/// chunks of the compute vector size.
+static SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
+  int64_t vectorSize = getComputeVectorSize(srcVectorType.getDimSize(0));
+  return {vectorSize};
+}
+
+/// Native shape for vector.transpose: every result dimension is unrolled to
+/// 1 except the innermost one, which uses the compute vector size.
+static SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op) {
+  VectorType vectorType = op.getResultVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+/// Native shape for vector.gather: every dimension of the gathered vector
+/// type is unrolled to 1 except the innermost one, which uses the compute
+/// vector size.
+static SmallVector<int64_t> getNativeVectorShapeImpl(vector::GatherOp op) {
+  VectorType vectorType = op.getVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+/// Returns the shape `op` should be unrolled to, or std::nullopt when `op` is
+/// not an unrolling candidate. Single-result elementwise ops on vectors are
+/// unrolled along the innermost dimension; a few vector ops have dedicated
+/// rules (see the getNativeVectorShapeImpl overloads).
+static std::optional<SmallVector<int64_t>>
+getNativeVectorShape(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0]);
+    // 0-d vectors have no innermost dimension; indexing `back()` of an empty
+    // shape would be undefined, so leave them alone.
+    if (vecType && vecType.getRank() > 0) {
+      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() = getComputeVectorSize(vecType.getShape().back());
+      return nativeSize;
+    }
+  }
+
+  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+      .Case<vector::MultiDimReductionOp, vector::ReductionOp,
+            vector::TransposeOp, vector::GatherOp>(
+          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+      .Default([](Operation *) { return std::nullopt; });
+}
+
namespace {
/// A pass to perform the SPIR-V conversion.
@@ -47,13 +107,94 @@ struct ConvertToSPIRVPass final
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
SPIRVTypeConverter typeConverter(targetAttr);
+ // Unroll vectors in function signature to native vector size.
+ {
+ llvm::errs() << "Start unrolling function signature\n";
+ RewritePatternSet patterns(context);
+ // TODO: This is hardcoded to unroll with size 1. Change this later
+ SmallVector<int64_t> nativeShape(1, 1);
+ auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
+ populateVectorUnrollFuncSignaturePatterns(patterns, options);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ return signalPassFailure();
+ llvm::errs() << "Finish unrolling function signature\n";
+ }
+
+ // Unroll vectors to native vector size.
+ {
+ RewritePatternSet patterns(context);
+ auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+ [=](auto op) { return getNativeVectorShape(op); });
+ populateVectorUnrollPatterns(patterns, options);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Next run canonicalization to cast away leading size-1 dimensions.
+ {
+ RewritePatternSet patterns(context);
+
+ // We need to pull in casting away leading one dims to allow cancelling
+ // some read/write ops.
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+
+ // We may have vector.insert_strided_slice inserting 1-D native vectors
+ // into n-D larger vectors with the above. Break that down too. This is a
+ // companion transformation of unrolling.
+ vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+ patterns);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+ // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
+ // them up.
+ vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+ vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+ vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
+ vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Convert vector.extract_strided_slice into a chain of vector.extract and
+ // then a chain of vector.insert ops. This helps to cancel with previous
+ // vector.insert/extract ops, especially for f16 cases where we have
+ // mismatched vector size for transfer and compute.
+ {
+ RewritePatternSet patterns(context);
+ vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ patterns, [](vector::ExtractStridedSliceOp op) {
+ return op.getSourceVectorType().getNumElements() > 4;
+ });
+ vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Run all sorts of canonicalization patterns to clean up again.
+ {
+ RewritePatternSet patterns(context);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+ vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
+ vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
+ vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
- // Populate patterns.
+ // Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
- populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+ // populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 723b2f62d65d4..1538c7eed6e76 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -43,6 +43,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRMemRefDialect
MLIRMemRefUtils
MLIRSCFDialect
+ MLIRSPIRVDialect
MLIRSideEffectInterfaces
MLIRSubsetOpInterface
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b3f558c3bac12..b63cb502b76e8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -11,12 +11,26 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>
@@ -65,6 +79,32 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
resultTypes, op->getAttrs());
}
+/// Returns the shape the vector-typed argument `vecType` of `funcOp` should
+/// be unrolled to, as given by `options.nativeShape`. Returns std::nullopt
+/// when no native shape is provided, when the vector shape is not an integral
+/// multiple of it, or when no unrolling is needed (all ratios are 1).
+static std::optional<SmallVector<int64_t>>
+getTargetShape(const vector::UnrollVectorOptions &options, func::FuncOp funcOp,
+               VectorType vecType) {
+  assert(options.nativeShape &&
+         "vector unrolling expects the native shape or native "
+         "shape call back function to be set");
+  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(funcOp);
+  if (!targetShape) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
+    return std::nullopt;
+  }
+  std::optional<SmallVector<int64_t>> maybeShapeRatio =
+      computeShapeRatio(vecType.getShape(), *targetShape);
+  if (!maybeShapeRatio) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+    return std::nullopt;
+  }
+  LLVM_DEBUG(llvm::dbgs()
+             << "--found an integral shape ratio to unroll to -> SUCCESS\n");
+  return targetShape;
+}
+
/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
@@ -617,6 +657,160 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+/// Rewrites a `func.func` whose vector-typed inputs should be unrolled (per
+/// `options`) into a new function in which each such input is replaced by
+/// several smaller vector inputs of the native shape. Inside the body the
+/// original vector value is rebuilt once from the new arguments with a chain
+/// of `vector.insert_strided_slice` ops, so all existing uses stay valid.
+///
+/// TODO: Unroll vector-typed results as well; they are currently kept as-is.
+struct UnrollFuncSignaturePattern : OpRewritePattern<func::FuncOp> {
+  UnrollFuncSignaturePattern(MLIRContext *context,
+                             const vector::UnrollVectorOptions &options,
+                             PatternBenefit benefit = 1)
+      : OpRewritePattern<func::FuncOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    // Declarations have no entry block whose arguments could be rewritten.
+    if (funcOp.isExternal())
+      return failure();
+
+    FunctionType fnType = funcOp.getFunctionType();
+
+    // Decide the new signature before touching the IR, so that a match
+    // failure never leaves partially rewritten IR behind.
+    SmallVector<Type> newInputTypes;
+    SmallVector<std::optional<SmallVector<int64_t>>> targetShapes;
+    bool needsUnrolling = false;
+    for (Type type : fnType.getInputs()) {
+      std::optional<SmallVector<int64_t>> targetShape;
+      if (auto vecType = dyn_cast<VectorType>(type))
+        targetShape = getTargetShape(options, funcOp, vecType);
+      if (targetShape) {
+        needsUnrolling = true;
+        auto vecType = cast<VectorType>(type);
+        auto unrolledType =
+            VectorType::get(*targetShape, vecType.getElementType());
+        // One new argument per tile of the original vector.
+        for (SmallVector<int64_t> offsets :
+             StaticTileOffsetRange(vecType.getShape(), *targetShape)) {
+          (void)offsets;
+          newInputTypes.push_back(unrolledType);
+        }
+      } else {
+        newInputTypes.push_back(type);
+      }
+      targetShapes.push_back(std::move(targetShape));
+    }
+    // Nothing to unroll: report no match instead of rebuilding the function.
+    if (!needsUnrolling)
+      return failure();
+
+    // Create the replacement function (keeping all results, including none
+    // for void functions) and move the whole body, with all of its blocks,
+    // into it.
+    // TODO: Copy over attributes other than the function name and type.
+    auto newFuncOp = rewriter.create<func::FuncOp>(
+        funcOp.getLoc(), funcOp.getName(),
+        FunctionType::get(rewriter.getContext(), newInputTypes,
+                          fnType.getResults()));
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+
+    // Append the new arguments after the original ones. The original
+    // arguments are erased only once all of their uses have been redirected.
+    Block &entryBlock = newFuncOp.getBody().front();
+    unsigned numOrigArgs = entryBlock.getNumArguments();
+    SmallVector<Location> argLocs(newInputTypes.size(), funcOp.getLoc());
+    rewriter.modifyOpInPlace(newFuncOp, [&] {
+      entryBlock.addArguments(newInputTypes, argLocs);
+    });
+
+    rewriter.setInsertionPointToStart(&entryBlock);
+    unsigned newArgNo = numOrigArgs;
+    for (unsigned origArgNo = 0; origArgNo < numOrigArgs; ++origArgNo) {
+      BlockArgument origArg = entryBlock.getArgument(origArgNo);
+      const std::optional<SmallVector<int64_t>> &targetShape =
+          targetShapes[origArgNo];
+      if (!targetShape) {
+        // Not unrolled: forward every use to the matching new argument.
+        rewriter.replaceAllUsesWith(origArg,
+                                    entryBlock.getArgument(newArgNo++));
+        continue;
+      }
+      // Unrolled: rebuild the original vector once from its slices; all
+      // uses (including ones in nested regions) then read that value.
+      auto vecType = cast<VectorType>(origArg.getType());
+      Location loc = origArg.getLoc();
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, vecType, rewriter.getZeroAttr(vecType));
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(vecType.getShape(), *targetShape)) {
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, entryBlock.getArgument(newArgNo++), result, offsets, strides);
+      }
+      rewriter.replaceAllUsesWith(origArg, result);
+    }
+
+    rewriter.modifyOpInPlace(newFuncOp, [&] {
+      entryBlock.eraseArguments(0, numOrigArgs);
+    });
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -628,3 +822,10 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollTransposePattern, UnrollGatherPattern>(
patterns.getContext(), options, benefit);
}
+
+/// Registers UnrollFuncSignaturePattern, configured with `options` and
+/// `benefit`, into `patterns`.
+void mlir::vector::populateVectorUnrollFuncSignaturePatterns(
+    RewritePatternSet &patterns, const UnrollVectorOptions &options,
+    PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<UnrollFuncSignaturePattern>(ctx, options, benefit);
+}
\ No newline at end of file
>From a5913ba5744fdd56c29284d6b1c2daecd6d7b73d Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Fri, 5 Jul 2024 15:45:34 +0000
Subject: [PATCH 02/14] Function input vector unrolling working and moved
pattern to SPIRV
---
.../SPIRV/Transforms/SPIRVConversion.h | 4 +
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 34 ++--
.../Dialect/SPIRV/Transforms/CMakeLists.txt | 1 +
.../SPIRV/Transforms/SPIRVConversion.cpp | 161 ++++++++++++++++++
4 files changed, 188 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 09eecafc0c8a5..1206603edcb6d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -17,8 +17,10 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
namespace mlir {
@@ -134,6 +136,8 @@ class SPIRVConversionTarget : public ConversionTarget {
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Appends a pattern that rewrites `func.func` ops so that vector-typed inputs
+/// are split into several smaller vector arguments, which are reassembled at
+/// the start of the body with `vector.insert_strided_slice`. The unroll shape
+/// is currently hardcoded to single-element vectors.
+void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+
namespace spirv {
class AccessChainOp;
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 54152c5be26fa..adb903d3f448c 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -23,6 +23,7 @@
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
#include <memory>
#define DEBUG_TYPE "convert-to-spirv"
@@ -105,23 +106,23 @@ struct ConvertToSPIRVPass final
Operation *op = getOperation();
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
- SPIRVTypeConverter typeConverter(targetAttr);
+ std::unique_ptr<ConversionTarget> target =
+ SPIRVConversionTarget::get(targetAttr);
- // Unroll vectors in function signature to native vector size.
+ // Unroll vectors in function inputs to native vector size.
{
- llvm::errs() << "Start unrolling function signature\n";
+ llvm::errs() << "Start unrolling function inputs\n";
RewritePatternSet patterns(context);
- // TODO: This is hardcoded to unroll with size 1. Change this later
- SmallVector<int64_t> nativeShape(1, 1);
- auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
- populateVectorUnrollFuncSignaturePatterns(patterns, options);
+ populateFuncOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
- llvm::errs() << "Finish unrolling function signature\n";
+ llvm::errs() << "Finish unrolling function inputs\n";
}
+ SPIRVTypeConverter typeConverter(targetAttr);
+
// Unroll vectors to native vector size.
{
RewritePatternSet patterns(context);
@@ -132,6 +133,9 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
+ llvm::errs() << "After unrolling vectors to native vector size\n";
+ op->dump();
+
// Next run canonicalization to cast away leading size-1 dimensions.
{
RewritePatternSet patterns(context);
@@ -159,6 +163,9 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
+ llvm::errs() << "After running canonicalization to cast away leading size-1 dimensions\n";
+ op->dump();
+
// Convert vector.extract_strided_slice into a chain of vector.extract and
// then a chain of vector.insert ops. This helps to cancel with previous
// vector.insert/extract ops, especially for fP16 cases where we have
@@ -175,6 +182,9 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
+ llvm::errs() << "After converting vector.extract_strided_slice into a chain of vector.extract and then a chain of vector.insert ops\n";
+ op->dump();
+
// Run all sorts of canonicalization patterns to clean up again.
{
RewritePatternSet patterns(context);
@@ -188,22 +198,22 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
+ llvm::errs() << "After running canonicalization patterns to clean up again\n";
+ op->dump();
+
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
- // populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
- std::unique_ptr<ConversionTarget> target =
- SPIRVConversionTarget::get(targetAttr);
-
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 821f82ebc0796..11af020b6c188 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
MLIRFuncDialect
MLIRSPIRVDialect
MLIRTransformUtils
+ MLIRVectorTransforms
)
add_mlir_dialect_library(MLIRSPIRVTransforms
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4072608dc8f87..616eb6104b705 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -17,8 +17,15 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
@@ -813,6 +820,160 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite pattern that legalizes a `func.func` signature by splitting each
+/// vector-typed argument selected by getTargetShape() into several smaller
+/// vector arguments.
+struct FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+/// Returns the shape a function argument of type `vecType` should be unrolled
+/// to. Returns std::nullopt when the vector shape is not an integral multiple
+/// of the target shape, or when no unrolling is needed (all ratios are 1).
+///
+/// TODO: The target shape is hardcoded to single-element (vector<1xT>) slices;
+/// derive it from the target environment instead.
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+  SmallVector<int64_t> targetShape(1, 1);
+  std::optional<SmallVector<int64_t>> maybeShapeRatio =
+      computeShapeRatio(vecType.getShape(), targetShape);
+  if (!maybeShapeRatio) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+    return std::nullopt;
+  }
+  LLVM_DEBUG(llvm::dbgs()
+             << "--found an integral shape ratio to unroll to -> SUCCESS\n");
+  return targetShape;
+}
+
+/// Replaces `funcOp` with an equivalent function whose vector-typed inputs
+/// selected by getTargetShape() are split into several smaller vector
+/// arguments. In the body, each original vector value is rebuilt from its
+/// slices with `vector.insert_strided_slice`, so existing uses stay valid.
+/// Fails without modifying the IR for declarations and for functions that
+/// have nothing to unroll.
+LogicalResult
+FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
+                                             PatternRewriter &rewriter) const {
+  // Declarations have no entry block whose arguments could be rewritten.
+  if (funcOp.isExternal())
+    return failure();
+
+  FunctionType fnType = funcOp.getFunctionType();
+
+  // Decide the new signature before modifying anything, so that a match
+  // failure never leaves partially rewritten IR behind.
+  SmallVector<Type> newInputTypes;
+  SmallVector<std::optional<SmallVector<int64_t>>> targetShapes;
+  bool needsUnrolling = false;
+  for (Type origType : fnType.getInputs()) {
+    std::optional<SmallVector<int64_t>> targetShape;
+    if (auto origVecType = dyn_cast<VectorType>(origType))
+      targetShape = getTargetShape(origVecType);
+    if (targetShape) {
+      needsUnrolling = true;
+      auto origVecType = cast<VectorType>(origType);
+      auto unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+      // One new argument per tile of the original vector.
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(origVecType.getShape(), *targetShape)) {
+        (void)offsets;
+        newInputTypes.push_back(unrolledType);
+      }
+    } else {
+      newInputTypes.push_back(origType);
+    }
+    targetShapes.push_back(std::move(targetShape));
+  }
+  // Without this check the pattern would rebuild every function and report
+  // success even when its signature is already legal.
+  if (!needsUnrolling)
+    return failure();
+
+  // Create the new function with the legalized signature (all results are
+  // kept, including none for void functions) and move the body into it.
+  // TODO: Copy over attributes other than the function name and type.
+  auto newFuncOp = rewriter.create<func::FuncOp>(
+      funcOp.getLoc(), funcOp.getName(),
+      FunctionType::get(rewriter.getContext(), newInputTypes,
+                        fnType.getResults()));
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+
+  // Append the new arguments after the original ones. The original arguments
+  // are erased only after all of their uses have been redirected, because
+  // erasing a block argument that still has uses is invalid.
+  Block &entryBlock = newFuncOp.getBody().front();
+  unsigned numOrigArgs = entryBlock.getNumArguments();
+  SmallVector<Location> argLocs(newInputTypes.size(), newFuncOp.getLoc());
+  rewriter.modifyOpInPlace(newFuncOp, [&] {
+    entryBlock.addArguments(newInputTypes, argLocs);
+  });
+
+  rewriter.setInsertionPointToStart(&entryBlock);
+  unsigned newArgNo = numOrigArgs;
+  for (unsigned origArgNo = 0; origArgNo < numOrigArgs; ++origArgNo) {
+    BlockArgument origArg = entryBlock.getArgument(origArgNo);
+    const std::optional<SmallVector<int64_t>> &targetShape =
+        targetShapes[origArgNo];
+    if (!targetShape) {
+      // Not unrolled: forward every use to the matching new argument.
+      rewriter.replaceAllUsesWith(origArg, entryBlock.getArgument(newArgNo++));
+      continue;
+    }
+    // Rebuild the original vector from its slices. The slices are wired to
+    // the new arguments directly, so insert_strided_slice ops that already
+    // exist in the body are never touched.
+    auto origVecType = cast<VectorType>(origArg.getType());
+    Location loc = origArg.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, origVecType, rewriter.getZeroAttr(origVecType));
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(origVecType.getShape(), *targetShape)) {
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, entryBlock.getArgument(newArgNo++), result, offsets, strides);
+    }
+    rewriter.replaceAllUsesWith(origArg, result);
+  }
+
+  rewriter.modifyOpInPlace(newFuncOp, [&] {
+    entryBlock.eraseArguments(0, numOrigArgs);
+  });
+  rewriter.eraseOp(funcOp);
+  return success();
+}
+
+/// Registers FuncOpVectorTypesConversion into `patterns`.
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FuncOpVectorTypesConversion>(ctx);
+}
+
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
>From 63744d1555a745d21a3691d9fa6e0319d4b2374b Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 8 Jul 2024 15:29:52 +0000
Subject: [PATCH 03/14] Implement function result and ReturnOp vector unrolling
---
.../SPIRV/Transforms/SPIRVConversion.h | 2 +
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 12 ++
.../SPIRV/Transforms/SPIRVConversion.cpp | 132 +++++++++++++++---
3 files changed, 128 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 1206603edcb6d..112c404527927 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -138,6 +138,8 @@ void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+/// Appends patterns that unroll vectors in function outputs; used by the
+/// convert-to-spirv pass right after the function-input unrolling step.
+/// NOTE(review): the definition is not part of this view; presumably it
+/// rewrites `func.return` operands and function result types; confirm.
+void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
+
namespace spirv {
class AccessChainOp;
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index adb903d3f448c..cd1344569e503 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -121,6 +121,18 @@ struct ConvertToSPIRVPass final
llvm::errs() << "Finish unrolling function inputs\n";
}
+ // Unroll vectors in function outputs to native vector size.
+ {
+ llvm::errs() << "Start unrolling function outputs\n";
+ RewritePatternSet patterns(context);
+ populateReturnOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ return signalPassFailure();
+ llvm::errs() << "Finish unrolling function inputs\n";
+ }
+
SPIRVTypeConverter typeConverter(targetAttr);
// Unroll vectors to native vector size.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 616eb6104b705..37a8071cbf9b6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -22,8 +22,10 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
@@ -863,17 +865,18 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const {
auto fnType = funcOp.getFunctionType();
+ // First create a new func op and copy the function body.
auto newFuncOp =
rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
-
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
-
+ llvm::errs() << "After creating new func op and copying the function body\n";
newFuncOp.dump();
- OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
Location loc = newFuncOp.getBody().getLoc();
- rewriter.setInsertionPointToStart(&newFuncOp.getBody().getBlocks().front());
+ Block &entryBlock = newFuncOp.getBlocks().front();
+ rewriter.setInsertionPointToStart(&entryBlock);
+ OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
SmallVector<size_t> unrolledInputNums;
size_t newInputNo = 0;
@@ -928,21 +931,17 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
}
llvm::errs() << "After enumerating through the arguments\n";
- newFuncOp->dump();
-
- // Assume there is a single result for now.
- Type originalResultType = fnType.getResult(0);
+ newFuncOp.dump();
- // Change function signature
+ // Change function signature.
auto newFnType = FunctionType::get(
rewriter.getContext(), TypeRange(oneToNTypeMapping.getConvertedTypes()),
- TypeRange(originalResultType));
+ TypeRange(fnType.getResults()));
rewriter.modifyOpInPlace(newFuncOp,
[&] { newFuncOp.setFunctionType(newFnType); });
- llvm::errs() << "After changing function signature\n";
- newFuncOp->dump();
- Block &entryBlock = newFuncOp.getBlocks().front();
+ llvm::errs() << "After changing function signature\n";
+ newFuncOp.dump();
// Update the arguments in the entry block.
entryBlock.eraseArguments(0, fnType.getNumInputs());
@@ -950,18 +949,19 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
newFuncOp.getLoc());
entryBlock.addArguments(oneToNTypeMapping.getConvertedTypes(), locs);
- llvm::errs() << "After modifying the entry block\n";
- newFuncOp->dump();
+ llvm::errs() << "After updating the arguments in the entry block\n";
+ newFuncOp.dump();
- size_t i = 0;
// Relace the dummy values with actual arguments.
+ size_t i = 0;
for (auto &op : entryBlock.getOperations()) {
op.dump();
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
if (vecOp) {
size_t unrolledInputNo = unrolledInputNums[i];
- rewriter.modifyOpInPlace(
- &op, [&] { op.setOperand(0, newFuncOp.getArgument(unrolledInputNo)); });
+ rewriter.modifyOpInPlace(&op, [&] {
+ op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ });
i++;
}
}
@@ -974,6 +974,102 @@ void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
patterns.add<FuncOpVectorTypesConversion>(patterns.getContext());
}
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+class ReturnOpVectorTypesConversion : public OpRewritePattern<func::ReturnOp> {
+public:
+ using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
+ func::ReturnOp returnOp, PatternRewriter &rewriter) const {
+
+ func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+ if (!funcOp)
+ return failure();
+
+ auto fnType = funcOp.getFunctionType();
+ OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+ Location loc = returnOp.getLoc();
+ SmallVector<Value> newOperands;
+
+ // Enumerate through the results.
+ for (const auto &argType : enumerate(fnType.getResults())) {
+ size_t origResultNo = argType.index();
+ Type origType = argType.value();
+ auto origVecType = llvm::dyn_cast<VectorType>(origType);
+ if (!origVecType) {
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ llvm::errs() << "Try vector unrolling\n";
+ SmallVector<int64_t> nativeShape(1, 1);
+ auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
+ auto targetShape = getTargetShape(origVecType);
+ if (!targetShape) {
+ llvm::errs() << "No target shape\n";
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ llvm::errs() << "Got target shape\n";
+ VectorType unrolledType =
+ VectorType::get(*targetShape, origVecType.getElementType());
+ llvm::errs() << "Unrolled type is ";
+ unrolledType.dump();
+ SmallVector<int64_t> originalShape =
+ llvm::to_vector<4>(origVecType.getShape());
+ SmallVector<Type> newTypes;
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ Value returnValue = returnOp.getOperand(0);
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ auto result = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, returnValue, offsets, *targetShape, strides);
+ result.dump();
+ newOperands.push_back(result);
+ newTypes.push_back(unrolledType);
+ }
+ oneToNTypeMapping.addInputs(origResultNo, newTypes);
+ }
+
+ llvm::errs() << "After enumerating through the arguments\n";
+ funcOp.dump();
+
+ for (auto operand : newOperands)
+ operand.dump();
+
+ // Change function signature.
+ auto newFnType =
+ FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+ TypeRange(oneToNTypeMapping.getConvertedTypes()));
+ rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setFunctionType(newFnType); });
+ llvm::errs() << "After changing function signature\n";
+ funcOp.dump();
+
+ // Replace the return op using the new operands.
+ rewriter.replaceOp(returnOp,
+ rewriter.create<func::ReturnOp>(loc, newOperands));
+ llvm::errs() << "After replacing return op\n";
+ funcOp.dump();
+
+ return success();
+}
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<ReturnOpVectorTypesConversion>(patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
>From ba04f4f8b60cbbe9cfb8769aa877f598467de5a0 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 8 Jul 2024 18:44:29 +0000
Subject: [PATCH 04/14] Compute the target shape based on original vector shape
---
.../SPIRV/Transforms/SPIRVConversion.cpp | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 37a8071cbf9b6..2af947e94a475 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -838,11 +838,20 @@ class FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
};
} // namespace
+static int getComputeVectorSize(int64_t size) {
+ for (int i : {4, 3, 2}) {
+ if (size % i == 0)
+ return i;
+ }
+ return 1;
+}
+
static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
llvm::errs() << "Get target shape\n";
SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
// TODO: This is hardcoded to unroll with size 1. Change this later
- std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(1, 1);
+ std::optional<SmallVector<int64_t>> targetShape =
+ SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
if (!targetShape) {
llvm::errs() << "--no unrolling target shape defined\n";
return std::nullopt;
@@ -870,12 +879,14 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
+ rewriter.eraseOp(funcOp);
llvm::errs() << "After creating new func op and copying the function body\n";
newFuncOp.dump();
Location loc = newFuncOp.getBody().getLoc();
Block &entryBlock = newFuncOp.getBlocks().front();
rewriter.setInsertionPointToStart(&entryBlock);
+
OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
SmallVector<size_t> unrolledInputNums;
size_t newInputNo = 0;
@@ -891,8 +902,6 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
continue;
}
llvm::errs() << "Try vector unrolling\n";
- SmallVector<int64_t> nativeShape(1, 1);
- auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
llvm::errs() << "No target shape\n";
@@ -966,7 +975,6 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
}
}
- rewriter.eraseOp(funcOp);
return success();
}
@@ -1013,8 +1021,6 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
continue;
}
llvm::errs() << "Try vector unrolling\n";
- SmallVector<int64_t> nativeShape(1, 1);
- auto options = vector::UnrollVectorOptions().setNativeShape(nativeShape);
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
llvm::errs() << "No target shape\n";
>From 6e99d24e626aa8282340df52fc9e54605e08ca4d Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 8 Jul 2024 20:58:14 +0000
Subject: [PATCH 05/14] Fix bug in function output unrolling
---
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2af947e94a475..99de9b8195d35 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1037,7 +1037,7 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
llvm::to_vector<4>(origVecType.getShape());
SmallVector<Type> newTypes;
SmallVector<int64_t> strides(targetShape->size(), 1);
- Value returnValue = returnOp.getOperand(0);
+ Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
auto result = rewriter.create<vector::ExtractStridedSliceOp>(
>From 49b5a4b399f8e60eaafc631f4fda52f1b0469182 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Tue, 9 Jul 2024 15:19:07 +0000
Subject: [PATCH 06/14] Working for signatures with legal and illegal types
---
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 21 ++++++----
.../SPIRV/Transforms/SPIRVConversion.cpp | 42 ++++++++++++++++++-
2 files changed, 53 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index cd1344569e503..1c11076c4b5b9 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,7 +39,7 @@ using namespace mlir;
// Vector Lowering
//===----------------------------------------------------------------------===//
-int getComputeVectorSize(int64_t size) {
+static int getComputeVectorSize(int64_t size) {
for (int i : {4, 3, 2}) {
if (size % i == 0)
return i;
@@ -110,28 +110,29 @@ struct ConvertToSPIRVPass final
SPIRVConversionTarget::get(targetAttr);
// Unroll vectors in function inputs to native vector size.
+ llvm::errs() << "Start unrolling function inputs\n";
{
- llvm::errs() << "Start unrolling function inputs\n";
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
- llvm::errs() << "Finish unrolling function inputs\n";
}
+ llvm::errs() << "Finish unrolling function inputs\n";
+ op->dump();
// Unroll vectors in function outputs to native vector size.
+ llvm::errs() << "Start unrolling function outputs\n";
{
- llvm::errs() << "Start unrolling function outputs\n";
RewritePatternSet patterns(context);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
- llvm::errs() << "Finish unrolling function inputs\n";
}
+ llvm::errs() << "Finish unrolling function outputs\n";
SPIRVTypeConverter typeConverter(targetAttr);
@@ -175,7 +176,8 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
- llvm::errs() << "After running canonicalization to cast away leading size-1 dimensions\n";
+ llvm::errs() << "After running canonicalization to cast away leading "
+ "size-1 dimensions\n";
op->dump();
// Convert vector.extract_strided_slice into a chain of vector.extract and
@@ -194,7 +196,9 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
- llvm::errs() << "After converting vector.extract_strided_slice into a chain of vector.extract and then a chain of vector.insert ops\n";
+ llvm::errs()
+ << "After converting vector.extract_strided_slice into a chain of "
+ "vector.extract and then a chain of vector.insert ops\n";
op->dump();
// Run all sorts of canonicalization patterns to clean up again.
@@ -210,7 +214,8 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
- llvm::errs() << "After running canonicalization patterns to clean up again\n";
+ llvm::errs()
+ << "After running canonicalization patterns to clean up again\n";
op->dump();
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 99de9b8195d35..304e1f7756580 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
@@ -21,19 +22,24 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
+#include <cctype>
#include <functional>
#include <optional>
+#include <unordered_set>
#define DEBUG_TYPE "mlir-spirv-conversion"
@@ -891,22 +897,35 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
SmallVector<size_t> unrolledInputNums;
size_t newInputNo = 0;
+ std::unordered_map<Operation *, size_t> tmpOps;
+ size_t newOpCount = 0;
+
// Enumerate through the arguments.
for (const auto &argType : enumerate(fnType.getInputs())) {
size_t origInputNo = argType.index();
Type origType = argType.value();
auto origVecType = llvm::dyn_cast<VectorType>(origType);
if (!origVecType) {
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
oneToNTypeMapping.addInputs(origInputNo, origType);
newInputNo++;
+ newOpCount++;
continue;
}
llvm::errs() << "Try vector unrolling\n";
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
llvm::errs() << "No target shape\n";
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
oneToNTypeMapping.addInputs(origInputNo, origType);
newInputNo++;
+ newOpCount++;
continue;
}
llvm::errs() << "Got target shape\n";
@@ -921,10 +940,12 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
Value result = rewriter.create<arith::ConstantOp>(
loc, origVecType, rewriter.getZeroAttr(origVecType));
result.dump();
+ newOpCount++;
// Prepare the placeholder
Value dummy = rewriter.create<arith::ConstantOp>(
loc, unrolledType, rewriter.getZeroAttr(unrolledType));
dummy.dump();
+ newOpCount++;
SmallVector<int64_t> strides(targetShape->size(), 1);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
@@ -934,6 +955,7 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
newTypes.push_back(unrolledType);
unrolledInputNums.push_back(newInputNo);
newInputNo++;
+ newOpCount++;
}
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
oneToNTypeMapping.addInputs(origInputNo, newTypes);
@@ -961,10 +983,25 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
llvm::errs() << "After updating the arguments in the entry block\n";
newFuncOp.dump();
- // Relace the dummy values with actual arguments.
+ // Replace the dummy values with actual arguments.
size_t i = 0;
- for (auto &op : entryBlock.getOperations()) {
+ for (auto pair : llvm::enumerate(entryBlock.getOperations())) {
+ size_t count = pair.index();
+ Operation &op = pair.value();
op.dump();
+ for (auto pair : llvm::enumerate(op.getOperands())) {
+ Operation *operandOp = pair.value().getDefiningOp();
+ if (tmpOps.find(operandOp) != tmpOps.end()) {
+ rewriter.modifyOpInPlace(&op, [&] {
+ op.setOperand(pair.index(), newFuncOp.getArgument(tmpOps[operandOp]));
+ });
+ rewriter.eraseOp(operandOp);
+ count++;
+ continue;
+ }
+ }
+ if (count == newOpCount)
+ continue;
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
if (vecOp) {
size_t unrolledInputNo = unrolledInputNums[i];
@@ -973,6 +1010,7 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
});
i++;
}
+ count++;
}
return success();
>From b0fc3ab9439bc5b6092b1470d4642801c626dfaa Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Tue, 9 Jul 2024 20:13:56 +0000
Subject: [PATCH 07/14] Only keep the signature conversion, and refactor code
---
.../Vector/Transforms/VectorRewritePatterns.h | 4 -
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 152 +------------
.../SPIRV/Transforms/SPIRVConversion.cpp | 196 +++++++++--------
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 -
.../Vector/Transforms/VectorUnroll.cpp | 201 ------------------
5 files changed, 108 insertions(+), 446 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 5c06d6d4d6ad3..8e6d36f0b5f09 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -293,10 +293,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
-void populateVectorUnrollFuncSignaturePatterns(RewritePatternSet &patterns,
- const UnrollVectorOptions &options,
- PatternBenefit benefit = 1);
-
/// Collect a set of vector.shape_cast folding patterns.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 1c11076c4b5b9..ddfbb0a76ad11 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -23,7 +23,6 @@
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
#include <memory>
#define DEBUG_TYPE "convert-to-spirv"
@@ -35,66 +34,6 @@ namespace mlir {
using namespace mlir;
-//===----------------------------------------------------------------------===//
-// Vector Lowering
-//===----------------------------------------------------------------------===//
-
-static int getComputeVectorSize(int64_t size) {
- for (int i : {4, 3, 2}) {
- if (size % i == 0)
- return i;
- }
- return 1;
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::MultiDimReductionOp op) {
- // Unroll all reduction dimensions by size 1 for vector.multi_reduction.
- VectorType srcVectorType = op.getSourceVectorType();
- auto nativeSize = llvm::to_vector(srcVectorType.getShape());
- auto dims = op.getReductionDims().getAsValueRange<IntegerAttr>();
- for (const auto &dimAttr : dims) {
- nativeSize[dimAttr.getZExtValue()] = 1;
- }
- return nativeSize;
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op) {
- VectorType srcVectorType = op.getSourceVectorType();
- assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
- int64_t vectorSize = getComputeVectorSize(srcVectorType.getDimSize(0));
- return {vectorSize};
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op) {
- VectorType vectorType = op.getResultVectorType();
- SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
- nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
- return nativeSize;
-}
-
-SmallVector<int64_t> getNativeVectorShapeImpl(vector::GatherOp op) {
- VectorType vectorType = op.getVectorType();
- SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
- nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
- return nativeSize;
-}
-
-std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op) {
- if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
- if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
- SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
- nativeSize.back() = getComputeVectorSize(vecType.getShape().back());
- return nativeSize;
- }
- }
-
- return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
- .Case<vector::MultiDimReductionOp, vector::ReductionOp,
- vector::TransposeOp, vector::GatherOp>(
- [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
- .Default([](Operation *) { return std::nullopt; });
-}
-
namespace {
/// A pass to perform the SPIR-V conversion.
@@ -105,10 +44,6 @@ struct ConvertToSPIRVPass final
MLIRContext *context = &getContext();
Operation *op = getOperation();
- spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
- std::unique_ptr<ConversionTarget> target =
- SPIRVConversionTarget::get(targetAttr);
-
// Unroll vectors in function inputs to native vector size.
llvm::errs() << "Start unrolling function inputs\n";
{
@@ -120,7 +55,6 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
llvm::errs() << "Finish unrolling function inputs\n";
- op->dump();
// Unroll vectors in function outputs to native vector size.
llvm::errs() << "Start unrolling function outputs\n";
@@ -134,90 +68,10 @@ struct ConvertToSPIRVPass final
}
llvm::errs() << "Finish unrolling function outputs\n";
+ spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+ std::unique_ptr<ConversionTarget> target =
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
-
- // Unroll vectors to native vector size.
- {
- RewritePatternSet patterns(context);
- auto options = vector::UnrollVectorOptions().setNativeShapeFn(
- [=](auto op) { return getNativeVectorShape(op); });
- populateVectorUnrollPatterns(patterns, options);
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
- return signalPassFailure();
- }
-
- llvm::errs() << "After unrolling vectors to native vector size\n";
- op->dump();
-
- // Next run canonicalization to cast away leading size-1 dimensions.
- {
- RewritePatternSet patterns(context);
-
- // We need to pull in casting way leading one dims to allow cancelling
- // some read/write ops.
- vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-
- // We may have vector.insert_strided_slice inserting 1-D native vectors
- // into n-D larger vectors with the above. Break that down too. This is a
- // companion transformation of unrolling.
- vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
- patterns);
- vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-
- // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
- // them up.
- vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
- vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
-
- vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
- vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
-
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
- return signalPassFailure();
- }
-
- llvm::errs() << "After running canonicalization to cast away leading "
- "size-1 dimensions\n";
- op->dump();
-
- // Convert vector.extract_strided_slice into a chain of vector.extract and
- // then a chain of vector.insert ops. This helps to cancel with previous
- // vector.insert/extract ops, especially for fP16 cases where we have
- // mismatched vector size for transfer and compute.
- {
- RewritePatternSet patterns(context);
- vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
- patterns, [](vector::ExtractStridedSliceOp op) {
- return op.getSourceVectorType().getNumElements() > 4;
- });
- vector::InsertOp::getCanonicalizationPatterns(patterns, context);
- vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
- return signalPassFailure();
- }
-
- llvm::errs()
- << "After converting vector.extract_strided_slice into a chain of "
- "vector.extract and then a chain of vector.insert ops\n";
- op->dump();
-
- // Run all sorts of canonicalization patterns to clean up again.
- {
- RewritePatternSet patterns(context);
- vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
- vector::InsertOp::getCanonicalizationPatterns(patterns, context);
- vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
- vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
- vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
- vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
- return signalPassFailure();
- }
-
- llvm::errs()
- << "After running canonicalization patterns to clean up again\n";
- op->dump();
-
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 304e1f7756580..6e793573f0262 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -27,19 +27,14 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
-#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cctype>
-#include <functional>
#include <optional>
-#include <unordered_set>
#define DEBUG_TYPE "mlir-spirv-conversion"
@@ -49,6 +44,36 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+static int getComputeVectorSize(int64_t size) {
+  // Widest of 4, 3, 2 that evenly divides `size`; 1 when none of them does.
+  int width = 4;
+  while (width > 1 && size % width != 0)
+    --width;
+  return width;
+}
+
+/// Returns the shape to unroll `vecType` into, or std::nullopt if no unrolling
+/// applies. The innermost size is the widest of 4/3/2 dividing it, else 1.
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+  // 0-D vectors have an empty shape; calling `back()` on it would be UB.
+  if (vecType.getRank() == 0)
+    return std::nullopt;
+  ArrayRef<int64_t> shape = vecType.getShape();
+  SmallVector<int64_t> targetShape(1, getComputeVectorSize(shape.back()));
+  // NOTE(review): the target shape is always 1-D; confirm n-D callers cope.
+  std::optional<SmallVector<int64_t>> ratio =
+      computeShapeRatio(shape, targetShape);
+  if (!ratio) {
+    LLVM_DEBUG(llvm::dbgs() << "--could not compute integral shape ratio\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*ratio, [](int64_t v) { return v == 1; })) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
+    return std::nullopt;
+  }
+  return targetShape;
+}
+
/// Checks that `candidates` extension requirements are possible to be satisfied
/// with the given `targetEnv`.
///
@@ -835,7 +860,7 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
namespace {
/// A pattern for rewriting function signature to convert vector arguments of
/// functions to be of valid types
-class FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
public:
using OpRewritePattern<func::FuncOp>::OpRewritePattern;
@@ -844,48 +869,17 @@ class FuncOpVectorTypesConversion : public OpRewritePattern<func::FuncOp> {
};
} // namespace
-static int getComputeVectorSize(int64_t size) {
- for (int i : {4, 3, 2}) {
- if (size % i == 0)
- return i;
- }
- return 1;
-}
-
-static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
- llvm::errs() << "Get target shape\n";
- SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
- // TODO: This is hardcoded to unroll with size 1. Change this later
- std::optional<SmallVector<int64_t>> targetShape =
- SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
- if (!targetShape) {
- llvm::errs() << "--no unrolling target shape defined\n";
- return std::nullopt;
- }
- auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
- if (!maybeShapeRatio) {
- llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
- return std::nullopt;
- }
- if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
- llvm::errs() << "--no unrolling needed -> SKIP\n";
- return std::nullopt;
- }
- llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
- return targetShape;
-}
-
LogicalResult
-FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
- PatternRewriter &rewriter) const {
+FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
+ PatternRewriter &rewriter) const {
auto fnType = funcOp.getFunctionType();
- // First create a new func op and copy the function body.
+  // Create a new func op with the original type and move the function body.
auto newFuncOp =
rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
- rewriter.eraseOp(funcOp);
+
llvm::errs() << "After creating new func op and copying the function body\n";
newFuncOp.dump();
@@ -894,18 +888,30 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
rewriter.setInsertionPointToStart(&entryBlock);
OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+ // For arguments that are of illegal types and require unrolling.
+ // `unrolledInputNums` stores the indices of arguments that result from
+ // unrolling in the new function signature. `newInputNo` is a counter.
SmallVector<size_t> unrolledInputNums;
size_t newInputNo = 0;
- std::unordered_map<Operation *, size_t> tmpOps;
+  // For arguments that are of legal types and do not require unrolling.
+  // `tmpOps` maps each temporary placeholder operation to the index of the
+  // new argument that replaces it; uses of the placeholders in the entry block
+  // are redirected to those arguments once the argument list is updated.
+ DenseMap<Operation *, size_t> tmpOps;
+
+ // This counts the number of new operations created.
size_t newOpCount = 0;
// Enumerate through the arguments.
for (const auto &argType : enumerate(fnType.getInputs())) {
size_t origInputNo = argType.index();
Type origType = argType.value();
+ // Check whether the argument is of vector type.
auto origVecType = llvm::dyn_cast<VectorType>(origType);
if (!origVecType) {
+ // We need a placeholder for the old argument that will be erased later.
Value result = rewriter.create<arith::ConstantOp>(
loc, origType, rewriter.getZeroAttr(origType));
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
@@ -915,10 +921,10 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
newOpCount++;
continue;
}
- llvm::errs() << "Try vector unrolling\n";
+ // Check whether the vector needs unrolling.
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
- llvm::errs() << "No target shape\n";
+ // We need a placeholder for the old argument that will be erased later.
Value result = rewriter.create<arith::ConstantOp>(
loc, origType, rewriter.getZeroAttr(origType));
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
@@ -935,23 +941,23 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
unrolledType.dump();
SmallVector<int64_t> originalShape =
llvm::to_vector<4>(origVecType.getShape());
- SmallVector<Type> newTypes;
- // Prepare the result vector
+
+ // Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
loc, origVecType, rewriter.getZeroAttr(origVecType));
- result.dump();
newOpCount++;
- // Prepare the placeholder
+ // Prepare the placeholder for the new arguments that will be added later.
Value dummy = rewriter.create<arith::ConstantOp>(
loc, unrolledType, rewriter.getZeroAttr(unrolledType));
- dummy.dump();
newOpCount++;
+
+ // Create the `vector.insert_strided_slice` ops.
SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
result = rewriter.create<vector::InsertStridedSliceOp>(loc, dummy, result,
offsets, strides);
- result.dump();
newTypes.push_back(unrolledType);
unrolledInputNums.push_back(newInputNo);
newInputNo++;
@@ -964,10 +970,11 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
llvm::errs() << "After enumerating through the arguments\n";
newFuncOp.dump();
- // Change function signature.
- auto newFnType = FunctionType::get(
- rewriter.getContext(), TypeRange(oneToNTypeMapping.getConvertedTypes()),
- TypeRange(fnType.getResults()));
+ // Change the function signature.
+ auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+ auto newFnType =
+ FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
+ TypeRange(fnType.getResults()));
rewriter.modifyOpInPlace(newFuncOp,
[&] { newFuncOp.setFunctionType(newFnType); });
@@ -976,48 +983,52 @@ FuncOpVectorTypesConversion::matchAndRewrite(func::FuncOp funcOp,
// Update the arguments in the entry block.
entryBlock.eraseArguments(0, fnType.getNumInputs());
- SmallVector<Location> locs(oneToNTypeMapping.getConvertedTypes().size(),
- newFuncOp.getLoc());
- entryBlock.addArguments(oneToNTypeMapping.getConvertedTypes(), locs);
+ SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+ entryBlock.addArguments(convertedTypes, locs);
llvm::errs() << "After updating the arguments in the entry block\n";
newFuncOp.dump();
- // Replace the dummy values with actual arguments.
- size_t i = 0;
- for (auto pair : llvm::enumerate(entryBlock.getOperations())) {
- size_t count = pair.index();
- Operation &op = pair.value();
- op.dump();
- for (auto pair : llvm::enumerate(op.getOperands())) {
- Operation *operandOp = pair.value().getDefiningOp();
- if (tmpOps.find(operandOp) != tmpOps.end()) {
+ // Replace the placeholder values with the new arguments. We assume there is
+ // only one block for now.
+ size_t idx = 0;
+ for (auto opPair : llvm::enumerate(entryBlock.getOperations())) {
+ size_t count = opPair.index();
+ Operation &op = opPair.value();
+ // We first look for operands that are placeholders for initially legal
+ // arguments.
+ for (auto operandPair : llvm::enumerate(op.getOperands())) {
+ Operation *operandOp = operandPair.value().getDefiningOp();
+ if (tmpOps.find(operandOp) != tmpOps.end())
rewriter.modifyOpInPlace(&op, [&] {
- op.setOperand(pair.index(), newFuncOp.getArgument(tmpOps[operandOp]));
+ op.setOperand(operandPair.index(),
+ newFuncOp.getArgument(tmpOps[operandOp]));
});
- rewriter.eraseOp(operandOp);
- count++;
- continue;
- }
}
- if (count == newOpCount)
+ // Since all newly created operations are in the beginning, reaching the end
+ // of them means that any later `vector.insert_strided_slice` should not be
+ // touched.
+ if (count >= newOpCount)
continue;
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
if (vecOp) {
- size_t unrolledInputNo = unrolledInputNums[i];
+ size_t unrolledInputNo = unrolledInputNums[idx];
rewriter.modifyOpInPlace(&op, [&] {
op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
- i++;
+ idx++;
}
count++;
}
+  // Erase the original funcOp. The `tmpOps` placeholders are not erased here;
+  // once their uses are redirected they are expected to be dead (left for DCE).
+ rewriter.eraseOp(funcOp);
return success();
}
void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
- patterns.add<FuncOpVectorTypesConversion>(patterns.getContext());
+ patterns.add<FuncOpVectorUnroll>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
@@ -1027,7 +1038,7 @@ void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
namespace {
/// A pattern for rewriting function signature and the return op to convert
/// vectors to be of valid types.
-class ReturnOpVectorTypesConversion : public OpRewritePattern<func::ReturnOp> {
+class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
public:
using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
@@ -1036,9 +1047,11 @@ class ReturnOpVectorTypesConversion : public OpRewritePattern<func::ReturnOp> {
};
} // namespace
-LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
- func::ReturnOp returnOp, PatternRewriter &rewriter) const {
+LogicalResult
+ReturnOpVectorUnroll::matchAndRewrite(func::ReturnOp returnOp,
+ PatternRewriter &rewriter) const {
+ // Check whether the parent funcOp is valid.
func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
if (!funcOp)
return failure();
@@ -1046,6 +1059,8 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
auto fnType = funcOp.getFunctionType();
OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
Location loc = returnOp.getLoc();
+
+  // Operands of the new return op.
SmallVector<Value> newOperands;
// Enumerate through the results.
@@ -1053,15 +1068,16 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
size_t origResultNo = argType.index();
Type origType = argType.value();
auto origVecType = llvm::dyn_cast<VectorType>(origType);
+    // Check whether the result is of vector type.
if (!origVecType) {
oneToNTypeMapping.addInputs(origResultNo, origType);
newOperands.push_back(returnOp.getOperand(origResultNo));
continue;
}
- llvm::errs() << "Try vector unrolling\n";
+ // Check whether the vector needs unrolling.
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
- llvm::errs() << "No target shape\n";
+      // The original operand can be returned unchanged.
oneToNTypeMapping.addInputs(origResultNo, origType);
newOperands.push_back(returnOp.getOperand(origResultNo));
continue;
@@ -1071,16 +1087,18 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
VectorType::get(*targetShape, origVecType.getElementType());
llvm::errs() << "Unrolled type is ";
unrolledType.dump();
+
+ // Create `vector.extract_strided_slice` ops to form legal vectors from the
+ // original operand of illegal type.
SmallVector<int64_t> originalShape =
llvm::to_vector<4>(origVecType.getShape());
- SmallVector<Type> newTypes;
SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
- auto result = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value result = rewriter.create<vector::ExtractStridedSliceOp>(
loc, returnValue, offsets, *targetShape, strides);
- result.dump();
newOperands.push_back(result);
newTypes.push_back(unrolledType);
}
@@ -1090,10 +1108,7 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
llvm::errs() << "After enumerating through the arguments\n";
funcOp.dump();
- for (auto operand : newOperands)
- operand.dump();
-
- // Change function signature.
+ // Change the function signature.
auto newFnType =
FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
TypeRange(oneToNTypeMapping.getConvertedTypes()));
@@ -1101,17 +1116,16 @@ LogicalResult ReturnOpVectorTypesConversion::matchAndRewrite(
llvm::errs() << "After changing function signature\n";
funcOp.dump();
- // Replace the return op using the new operands.
+ // Replace the return op using the new operands. This will automatically
+ // update the entry block as well.
rewriter.replaceOp(returnOp,
rewriter.create<func::ReturnOp>(loc, newOperands));
- llvm::errs() << "After replacing return op\n";
- funcOp.dump();
return success();
}
void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
- patterns.add<ReturnOpVectorTypesConversion>(patterns.getContext());
+ patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 1538c7eed6e76..723b2f62d65d4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -43,7 +43,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRMemRefDialect
MLIRMemRefUtils
MLIRSCFDialect
- MLIRSPIRVDialect
MLIRSideEffectInterfaces
MLIRSubsetOpInterface
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b63cb502b76e8..b3f558c3bac12 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -11,26 +11,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/Block.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
-#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/iterator_range.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>
@@ -79,32 +65,6 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
resultTypes, op->getAttrs());
}
-static std::optional<SmallVector<int64_t>>
-getTargetShape(const vector::UnrollVectorOptions &options, func::FuncOp funcOp,
- VectorType vecType) {
- assert(options.nativeShape &&
- "vector unrolling expects the native shape or native"
- "shape call back function to be set");
- llvm::errs() << "Get target shape\n";
- SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
- std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(funcOp);
- if (!targetShape) {
- llvm::errs() << "--no unrolling target shape defined\n";
- return std::nullopt;
- }
- auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
- if (!maybeShapeRatio) {
- llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
- return std::nullopt;
- }
- if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
- llvm::errs() << "--no unrolling needed -> SKIP\n";
- return std::nullopt;
- }
- llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
- return targetShape;
-}
-
/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
@@ -657,160 +617,6 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
-struct UnrollFuncSignaturePattern : OpRewritePattern<func::FuncOp> {
- UnrollFuncSignaturePattern(MLIRContext *context,
- const vector::UnrollVectorOptions &options,
- PatternBenefit benefit = 1)
- : OpRewritePattern<func::FuncOp>(context, benefit), options(options) {}
-
- LogicalResult matchAndRewrite(func::FuncOp funcOp,
- PatternRewriter &rewriter) const override {
- llvm::errs() << "Run unroll function signature pattern\n";
-
- auto fnType = funcOp.getFunctionType();
-
- // Check function inputs.
- Location loc = funcOp.getFunctionBody()
- .getBlocks()
- .begin()
- ->getOperations()
- .begin()
- ->getLoc();
- size_t newArgIndex = 0;
- std::vector<Type> newSignature;
- std::vector<std::vector<size_t>> newArgMap(fnType.getNumInputs());
-
- for (const auto &argType : enumerate(fnType.getInputs())) {
- size_t index = argType.index();
- Type type = argType.value();
- auto vecType = llvm::dyn_cast<VectorType>(type);
- if (!vecType) {
- newSignature.push_back(type);
- newArgMap[index].push_back(newArgIndex);
- newArgIndex++;
- continue;
- }
- // Try vector unrolling
- llvm::errs() << "Try vector unrolling\n";
- SmallVector<int64_t> originalShape =
- llvm::to_vector<4>(vecType.getShape());
- auto targetShape = getTargetShape(options, funcOp, vecType);
- if (!targetShape) {
- llvm::errs() << "No target shape\n";
- newSignature.push_back(type);
- newArgMap[index].push_back(newArgIndex);
- newArgIndex++;
- continue;
- }
- llvm::errs() << "Got target shape\n";
- VectorType unrolledType =
- VectorType::get(*targetShape, vecType.getElementType());
- llvm::errs() << "Unrolled type is ";
- unrolledType.dump();
-
- for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalShape, *targetShape)) {
- newSignature.push_back(unrolledType);
- newArgMap[index].push_back(newArgIndex);
- newArgIndex++;
- }
- }
-
- // Assume there is a single result for now.
- Type originalResultType = fnType.getResult(0);
-
- // TODO: Handle illegal vector types in results as well.
- // SmallVector<Type> resultTypes;
- // auto vecType = llvm::dyn_cast<VectorType>(originalResultType);
-
- // if (vecType) {
- // // Try vector unrolling
- // SmallVector<int64_t> originalShape =
- // llvm::to_vector<4>(vecType.getShape()); auto targetShape =
- // getTargetShape(options, funcOp, vecType); VectorType unrolledType =
- // VectorType::get(*targetShape, vecType.getElementType());
- // if (targetShape)
- // for (SmallVector<int64_t> offsets :
- // StaticTileOffsetRange(originalShape, *targetShape))
- // resultTypes.push_back(unrolledType);
- // }
-
- // Create the converted func op
- auto newFuncOp = rewriter.create<func::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
- FunctionType::get(rewriter.getContext(), TypeRange(newSignature),
- TypeRange(originalResultType)));
-
- newFuncOp.addEntryBlock();
-
- llvm::errs() << "Created new func op\n";
- newFuncOp.dump();
- llvm::errs() << newFuncOp.getArguments().size() << "\n";
-
- // TODO: Copy over all attributes other than the function name and type
-
- // Clone operations (assuming one block for now)
- // TODO: The uses for operands that are SSA values are not cloned properly.
- loc = newFuncOp.getBody().getLoc();
- rewriter.setInsertionPointToStart(&newFuncOp.getBody().getBlocks().front());
-
- for (auto &op : funcOp.getBlocks().front().getOperations()) {
- op.dump();
- SmallVector<Value> newOperands(op.getNumOperands());
- for (size_t i = 0; i < op.getOperands().size(); ++i) {
- Value operand = op.getOperand(i);
- auto blockArg = llvm::dyn_cast<BlockArgument>(operand);
- if (!blockArg) {
- newOperands[i] = operand;
- continue;
- }
- // Not unrolled
- unsigned int argNum = blockArg.getArgNumber();
- if (newArgMap[argNum].size() == 1) {
- newOperands[i] = newFuncOp.getArgument(newArgMap[argNum][0]);
- continue;
- }
- // Unrolled
- // TODO: Store previously created vector.insert_strided_slice ops.
- auto vecType = dyn_cast<VectorType>(blockArg.getType());
- SmallVector<int64_t> originalShape =
- llvm::to_vector<4>(vecType.getShape());
- auto targetShape = getTargetShape(options, funcOp, vecType);
- VectorType unrolledType =
- VectorType::get(*targetShape, vecType.getElementType());
- llvm::errs() << "Unrolled type is ";
- unrolledType.dump();
- // Prepare the result vector.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, vecType, rewriter.getZeroAttr(vecType));
- result.dump();
- SmallVector<int64_t> strides(targetShape->size(), 1);
- // Create the vector.insert_strided_slice ops.
- unsigned int j = 0;
- for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalShape, *targetShape)) {
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, newFuncOp.getArgument(newArgMap[argNum][j]), result, offsets,
- strides);
- result.dump();
- j++;
- }
- newOperands[i] = result;
- }
- Operation *newOp =
- rewriter.create(loc, op.getName().getIdentifier(), newOperands,
- op.getResultTypes(), op.getAttrs());
- llvm::errs() << "newOp is ";
- newOp->dump();
- }
- rewriter.eraseOp(funcOp);
- return success();
- }
-
-private:
- vector::UnrollVectorOptions options;
-};
-
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -822,10 +628,3 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollTransposePattern, UnrollGatherPattern>(
patterns.getContext(), options, benefit);
}
-
-void mlir::vector::populateVectorUnrollFuncSignaturePatterns(
- RewritePatternSet &patterns, const UnrollVectorOptions &options,
- PatternBenefit benefit) {
- patterns.add<UnrollFuncSignaturePattern>(patterns.getContext(), options,
- benefit);
-}
\ No newline at end of file
>From a055178e070d458158808d17a1c50d5c59952030 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 13:07:28 +0000
Subject: [PATCH 08/14] Add an option for testing signature conversion
---
mlir/include/mlir/Conversion/Passes.td | 5 ++
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 49 ++++++++++---------
2 files changed, 31 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8d83343f5b736..598bba63a2a82 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -44,6 +44,11 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
"spirv::SPIRVDialect",
"vector::VectorDialect",
];
+ let options = [
+ Option<"runSignatureConversion", "run-signature-conversion", "bool",
+ /*default=*/"false",
+ "Run function signature conversion to convert vector types">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index ddfbb0a76ad11..21a5a44ece92a 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -37,36 +37,39 @@ using namespace mlir;
namespace {
/// A pass to perform the SPIR-V conversion.
-struct ConvertToSPIRVPass final
- : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+class ConvertToSPIRVPass
+ : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+ using impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
- // Unroll vectors in function inputs to native vector size.
- llvm::errs() << "Start unrolling function inputs\n";
- {
- RewritePatternSet patterns(context);
- populateFuncOpVectorRewritePatterns(patterns);
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
- return signalPassFailure();
- }
- llvm::errs() << "Finish unrolling function inputs\n";
+ if (runSignatureConversion) {
+ // Unroll vectors in function inputs to native vector size.
+ llvm::errs() << "Start unrolling function inputs\n";
+ {
+ RewritePatternSet patterns(context);
+ populateFuncOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ return signalPassFailure();
+ }
+ llvm::errs() << "Finish unrolling function inputs\n";
- // Unroll vectors in function outputs to native vector size.
- llvm::errs() << "Start unrolling function outputs\n";
- {
- RewritePatternSet patterns(context);
- populateReturnOpVectorRewritePatterns(patterns);
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
- return signalPassFailure();
+ // Unroll vectors in function outputs to native vector size.
+ llvm::errs() << "Start unrolling function outputs\n";
+ {
+ RewritePatternSet patterns(context);
+ populateReturnOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ return signalPassFailure();
+ }
+ llvm::errs() << "Finish unrolling function outputs\n";
}
- llvm::errs() << "Finish unrolling function outputs\n";
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
>From 61cf2559a0f7f71b5bf39956b13deb57cf6aadca Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 14:55:25 +0000
Subject: [PATCH 09/14] Add unit tests
---
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 11 +-
.../func-signature-vector-unroll.mlir | 132 ++++++++++++++++++
2 files changed, 140 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 21a5a44ece92a..88d7590c1daae 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,7 +39,8 @@ namespace {
/// A pass to perform the SPIR-V conversion.
class ConvertToSPIRVPass
: public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
- using impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+ using impl::ConvertToSPIRVPassBase<
+ ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -53,7 +54,8 @@ class ConvertToSPIRVPass
populateFuncOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ if (failed(
+ applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
llvm::errs() << "Finish unrolling function inputs\n";
@@ -65,10 +67,13 @@ class ConvertToSPIRVPass
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ if (failed(
+ applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
llvm::errs() << "Finish unrolling function outputs\n";
+
+ return;
}
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
new file mode 100644
index 0000000000000..d5c777908d7e2
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion" -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @simple_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+func.func @simple_scalar(%arg0 : i32) -> i32 {
+ // CHECK: return %[[ARG0]] : i32
+ return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_4
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
+func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
+ // CHECK: return %[[ARG0]] : vector<4xi32>
+ return %arg0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
+func.func @simple_vector_5(%arg0 : vector<5xi32>) -> vector<5xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<5xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[INSERT4:.*]] = vector.insert_strided_slice %[[ARG4]], %[[INSERT3]] {offsets = [4], strides = [1]} : vector<1xi32> into vector<5xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [1], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [2], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [3], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: %[[EXTRACT4:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [4], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]], %[[EXTRACT4]] : vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>
+ return %arg0 : vector<5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_6
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
+func.func @simple_vector_6(%arg0 : vector<6xi32>) -> vector<6xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<3xi32>
+ return %arg0 : vector<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>)
+func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xi32>, vector<4xi32>
+ return %arg0 : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_6and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi32>, vector<3xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<6xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_3and8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
+func.func @vector_3and8(%arg0 : vector<3xi32>, %arg1 : vector<8xi32>) -> (vector<3xi32>, vector<8xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[ARG0]], %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<3xi32>, vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @scalar_vector
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: i32)
+func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i32) -> (vector<8xi32>, vector<3xi32>, i32) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
+ // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>, vector<4xi32>, vector<3xi32>, i32
+ return %arg0, %arg1, %arg2 : vector<8xi32>, vector<3xi32>, i32
+}
+
+// -----
+
+// CHECK-LABEL: @reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
+func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+ // CHECK: %[[ADDI:.*]] = arith.addi %[[INSERT1]], %[[INSERT3]] : vector<8xi32>
+ // CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, %[[ADDI]] : vector<8xi32> into i32
+ // CHECK: %[[RET:.*]] = arith.addi %[[REDUCTION]], %[[ARG4]] : i32
+ // CHECK: return %[[RET]] : i32
+ %0 = arith.addi %arg0, %arg1 : vector<8xi32>
+ %1 = vector.reduction <add>, %0 : vector<8xi32> into i32
+ %2 = arith.addi %1, %arg2 : i32
+ return %2 : i32
+}
>From a37422c99a5ef6e072a44723a546d155b0462eb5 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 15:50:34 +0000
Subject: [PATCH 10/14] Code formatting
---
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 112c404527927..9ad3d5fc85dd3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -19,8 +19,8 @@
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/SmallSet.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/SmallSet.h"
namespace mlir {
>From fc237908b8f6e0dfd3feef2cfcc0a3353cb3ac28 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 10 Jul 2024 13:21:33 -0400
Subject: [PATCH 11/14] Update
mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 88d7590c1daae..57dc11c434176 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,8 +39,7 @@ namespace {
/// A pass to perform the SPIR-V conversion.
class ConvertToSPIRVPass
: public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
- using impl::ConvertToSPIRVPassBase<
- ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
+ using impl::ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
>From a194ff067bf5ef4d0f44b6b7052490a4dca96b88 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 10 Jul 2024 13:21:42 -0400
Subject: [PATCH 12/14] Update
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 6e793573f0262..1e2eb336372a1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -909,7 +909,7 @@ FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
size_t origInputNo = argType.index();
Type origType = argType.value();
// Check whether the argument is of vector type.
- auto origVecType = llvm::dyn_cast<VectorType>(origType);
+ auto origVecType = dyn_cast<VectorType>(origType);
if (!origVecType) {
// We need a placeholder for the old argument that will be erased later.
Value result = rewriter.create<arith::ConstantOp>(
>From 21077fc29a2861626540d9b2a58c5e2694fa3d71 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 18:45:42 +0000
Subject: [PATCH 13/14] Run both patterns at the same time
---
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 20 +++----------------
1 file changed, 3 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 57dc11c434176..3ffaff76e566d 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,30 +39,18 @@ namespace {
/// A pass to perform the SPIR-V conversion.
class ConvertToSPIRVPass
: public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
- using impl::ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
+ using impl::ConvertToSPIRVPassBase<
+ ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
if (runSignatureConversion) {
- // Unroll vectors in function inputs to native vector size.
- llvm::errs() << "Start unrolling function inputs\n";
+ // Unroll vectors in function signatures to native vector size.
{
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- if (failed(
- applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
- return signalPassFailure();
- }
- llvm::errs() << "Finish unrolling function inputs\n";
-
- // Unroll vectors in function outputs to native vector size.
- llvm::errs() << "Start unrolling function outputs\n";
- {
- RewritePatternSet patterns(context);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
@@ -70,8 +58,6 @@ class ConvertToSPIRVPass
applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
- llvm::errs() << "Finish unrolling function outputs\n";
-
return;
}
>From 139ea575815c30d7478db1fa40f7c3448984f321 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 10 Jul 2024 19:50:36 +0000
Subject: [PATCH 14/14] Code refactoring and formatting
---
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 2 +-
.../SPIRV/Transforms/SPIRVConversion.cpp | 428 ++++++++----------
2 files changed, 197 insertions(+), 233 deletions(-)
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 3ffaff76e566d..9d1f9d0f85ea1 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -37,7 +37,7 @@ using namespace mlir;
namespace {
/// A pass to perform the SPIR-V conversion.
-class ConvertToSPIRVPass
+struct ConvertToSPIRVPass
: public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
using impl::ConvertToSPIRVPassBase<
ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 1e2eb336372a1..148df27c47df8 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -53,24 +53,26 @@ static int getComputeVectorSize(int64_t size) {
}
static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
- llvm::errs() << "Get target shape\n";
+ LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
std::optional<SmallVector<int64_t>> targetShape =
SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
if (!targetShape) {
- llvm::errs() << "--no unrolling target shape defined\n";
+ LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
return std::nullopt;
}
auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
if (!maybeShapeRatio) {
- llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
+ LLVM_DEBUG(llvm::dbgs()
+ << "--could not compute integral shape ratio -> BAIL\n");
return std::nullopt;
}
if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
- llvm::errs() << "--no unrolling needed -> SKIP\n";
+ LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
return std::nullopt;
}
- llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
+ LLVM_DEBUG(llvm::dbgs()
+ << "--found an integral shape ratio to unroll to -> SUCCESS\n");
return targetShape;
}
@@ -865,167 +867,143 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern<func::FuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(func::FuncOp funcOp,
- PatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
- PatternRewriter &rewriter) const {
- auto fnType = funcOp.getFunctionType();
-
- // Create a new func op with the original type and copy the function body.
- auto newFuncOp =
- rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
-
- llvm::errs() << "After creating new func op and copying the function body\n";
- newFuncOp.dump();
-
- Location loc = newFuncOp.getBody().getLoc();
- Block &entryBlock = newFuncOp.getBlocks().front();
- rewriter.setInsertionPointToStart(&entryBlock);
-
- OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
-
- // For arguments that are of illegal types and require unrolling.
- // `unrolledInputNums` stores the indices of arguments that result from
- // unrolling in the new function signature. `newInputNo` is a counter.
- SmallVector<size_t> unrolledInputNums;
- size_t newInputNo = 0;
-
- // For arguments that are of legal types and do not require unrolling.
- // `tmpOps` stores a mapping from temporary operations that serve as
- // placeholders for new arguments that will be added later. These operations
- // will be erased once the entry block's argument list is updated.
- DenseMap<Operation *, size_t> tmpOps;
-
- // This counts the number of new operations created.
- size_t newOpCount = 0;
+ PatternRewriter &rewriter) const override {
+ FunctionType fnType = funcOp.getFunctionType();
+
+ // Create a new func op with the original type and copy the function body.
+ auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
+ funcOp.getName(), fnType);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+
+ Location loc = newFuncOp.getBody().getLoc();
+ Block &entryBlock = newFuncOp.getBlocks().front();
+ rewriter.setInsertionPointToStart(&entryBlock);
+
+ OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+ // For arguments that are of illegal types and require unrolling.
+ // `unrolledInputNums` stores the indices of arguments that result from
+ // unrolling in the new function signature. `newInputNo` is a counter.
+ SmallVector<size_t> unrolledInputNums;
+ size_t newInputNo = 0;
+
+ // For arguments that are of legal types and do not require unrolling.
+ // `tmpOps` stores a mapping from temporary operations that serve as
+ // placeholders for new arguments that will be added later. These operations
+ // will be erased once the entry block's argument list is updated.
+ DenseMap<Operation *, size_t> tmpOps;
+
+ // This counts the number of new operations created.
+ size_t newOpCount = 0;
+
+ // Enumerate through the arguments.
+ for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
+ // Check whether the argument is of vector type.
+ auto origVecType = dyn_cast<VectorType>(origType);
+ if (!origVecType) {
+ // We need a placeholder for the old argument that will be erased later.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
+ oneToNTypeMapping.addInputs(origInputNo, origType);
+ newInputNo++;
+ newOpCount++;
+ continue;
+ }
+ // Check whether the vector needs unrolling.
+ auto targetShape = getTargetShape(origVecType);
+ if (!targetShape) {
+ // We need a placeholder for the old argument that will be erased later.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
+ oneToNTypeMapping.addInputs(origInputNo, origType);
+ newInputNo++;
+ newOpCount++;
+ continue;
+ }
+ VectorType unrolledType =
+ VectorType::get(*targetShape, origVecType.getElementType());
+ SmallVector<int64_t> originalShape =
+ llvm::to_vector<4>(origVecType.getShape());
- // Enumerate through the arguments.
- for (const auto &argType : enumerate(fnType.getInputs())) {
- size_t origInputNo = argType.index();
- Type origType = argType.value();
- // Check whether the argument is of vector type.
- auto origVecType = dyn_cast<VectorType>(origType);
- if (!origVecType) {
- // We need a placeholder for the old argument that will be erased later.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, origType, rewriter.getZeroAttr(origType));
- rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
- tmpOps.insert({result.getDefiningOp(), newInputNo});
- oneToNTypeMapping.addInputs(origInputNo, origType);
- newInputNo++;
- newOpCount++;
- continue;
- }
- // Check whether the vector needs unrolling.
- auto targetShape = getTargetShape(origVecType);
- if (!targetShape) {
- // We need a placeholder for the old argument that will be erased later.
+ // Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
- loc, origType, rewriter.getZeroAttr(origType));
- rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
- tmpOps.insert({result.getDefiningOp(), newInputNo});
- oneToNTypeMapping.addInputs(origInputNo, origType);
- newInputNo++;
+ loc, origVecType, rewriter.getZeroAttr(origVecType));
newOpCount++;
- continue;
- }
- llvm::errs() << "Got target shape\n";
- VectorType unrolledType =
- VectorType::get(*targetShape, origVecType.getElementType());
- llvm::errs() << "Unrolled type is ";
- unrolledType.dump();
- SmallVector<int64_t> originalShape =
- llvm::to_vector<4>(origVecType.getShape());
-
- // Prepare the result vector.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, origVecType, rewriter.getZeroAttr(origVecType));
- newOpCount++;
- // Prepare the placeholder for the new arguments that will be added later.
- Value dummy = rewriter.create<arith::ConstantOp>(
- loc, unrolledType, rewriter.getZeroAttr(unrolledType));
- newOpCount++;
-
- // Create the `vector.insert_strided_slice` ops.
- SmallVector<int64_t> strides(targetShape->size(), 1);
- SmallVector<Type> newTypes;
- for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalShape, *targetShape)) {
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, dummy, result,
- offsets, strides);
- newTypes.push_back(unrolledType);
- unrolledInputNums.push_back(newInputNo);
- newInputNo++;
+ // Prepare the placeholder for the new arguments that will be added later.
+ Value dummy = rewriter.create<arith::ConstantOp>(
+ loc, unrolledType, rewriter.getZeroAttr(unrolledType));
newOpCount++;
+
+ // Create the `vector.insert_strided_slice` ops.
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, dummy, result, offsets, strides);
+ newTypes.push_back(unrolledType);
+ unrolledInputNums.push_back(newInputNo);
+ newInputNo++;
+ newOpCount++;
+ }
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ oneToNTypeMapping.addInputs(origInputNo, newTypes);
}
- rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
- oneToNTypeMapping.addInputs(origInputNo, newTypes);
- }
-
- llvm::errs() << "After enumerating through the arguments\n";
- newFuncOp.dump();
-
- // Change the function signature.
- auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
- auto newFnType =
- FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
- TypeRange(fnType.getResults()));
- rewriter.modifyOpInPlace(newFuncOp,
- [&] { newFuncOp.setFunctionType(newFnType); });
-
- llvm::errs() << "After changing function signature\n";
- newFuncOp.dump();
-
- // Update the arguments in the entry block.
- entryBlock.eraseArguments(0, fnType.getNumInputs());
- SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
- entryBlock.addArguments(convertedTypes, locs);
-
- llvm::errs() << "After updating the arguments in the entry block\n";
- newFuncOp.dump();
-
- // Replace the placeholder values with the new arguments. We assume there is
- // only one block for now.
- size_t idx = 0;
- for (auto opPair : llvm::enumerate(entryBlock.getOperations())) {
- size_t count = opPair.index();
- Operation &op = opPair.value();
- // We first look for operands that are placeholders for initially legal
- // arguments.
- for (auto operandPair : llvm::enumerate(op.getOperands())) {
- Operation *operandOp = operandPair.value().getDefiningOp();
- if (tmpOps.find(operandOp) != tmpOps.end())
+
+ // Change the function signature.
+ auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+ auto newFnType =
+ FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
+ TypeRange(fnType.getResults()));
+ rewriter.modifyOpInPlace(newFuncOp,
+ [&] { newFuncOp.setFunctionType(newFnType); });
+
+ // Update the arguments in the entry block.
+ entryBlock.eraseArguments(0, fnType.getNumInputs());
+ SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+ entryBlock.addArguments(convertedTypes, locs);
+
+ // Replace the placeholder values with the new arguments. We assume there is
+ // only one block for now.
+ size_t idx = 0;
+ for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ // We first look for operands that are placeholders for initially legal
+ // arguments.
+ for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+ Operation *operandOp = operandVal.getDefiningOp();
+ if (tmpOps.find(operandOp) != tmpOps.end())
+ rewriter.modifyOpInPlace(&op, [&] {
+ op.setOperand(operandIdx, newFuncOp.getArgument(tmpOps[operandOp]));
+ });
+ }
+ // Since all newly created operations are in the beginning, reaching the
+ // end of them means that any later `vector.insert_strided_slice` should
+ // not be touched.
+ if (count >= newOpCount)
+ continue;
+ auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
+ if (vecOp) {
+ size_t unrolledInputNo = unrolledInputNums[idx];
rewriter.modifyOpInPlace(&op, [&] {
- op.setOperand(operandPair.index(),
- newFuncOp.getArgument(tmpOps[operandOp]));
+ op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
+ idx++;
+ }
+ count++;
}
- // Since all newly created operations are in the beginning, reaching the end
- // of them means that any later `vector.insert_strided_slice` should not be
- // touched.
- if (count >= newOpCount)
- continue;
- auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
- if (vecOp) {
- size_t unrolledInputNo = unrolledInputNums[idx];
- rewriter.modifyOpInPlace(&op, [&] {
- op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
- });
- idx++;
- }
- count++;
- }
- // Erase the original funcOp. The `tmpOps` do not need to be erased since
- // they have no uses and will be handled by dead-code elimination.
- rewriter.eraseOp(funcOp);
- return success();
-}
+ // Erase the original funcOp. The `tmpOps` do not need to be erased since
+ // they have no uses and will be handled by dead-code elimination.
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
+} // namespace
void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
patterns.add<FuncOpVectorUnroll>(patterns.getContext());
@@ -1043,86 +1021,72 @@ class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
LogicalResult matchAndRewrite(func::ReturnOp returnOp,
- PatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ReturnOpVectorUnroll::matchAndRewrite(func::ReturnOp returnOp,
- PatternRewriter &rewriter) const {
-
- // Check whether the parent funcOp is valid.
- func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
- if (!funcOp)
- return failure();
+ PatternRewriter &rewriter) const override {
+ // Check whether the parent funcOp is valid.
+ auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+ if (!funcOp)
+ return failure();
- auto fnType = funcOp.getFunctionType();
- OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
- Location loc = returnOp.getLoc();
-
- // For the new return op.
- SmallVector<Value> newOperands;
-
- // Enumerate through the results.
- for (const auto &argType : enumerate(fnType.getResults())) {
- size_t origResultNo = argType.index();
- Type origType = argType.value();
- auto origVecType = llvm::dyn_cast<VectorType>(origType);
- // Check whether the argument is of vector type.
- if (!origVecType) {
- oneToNTypeMapping.addInputs(origResultNo, origType);
- newOperands.push_back(returnOp.getOperand(origResultNo));
- continue;
- }
- // Check whether the vector needs unrolling.
- auto targetShape = getTargetShape(origVecType);
- if (!targetShape) {
- // The original argument can be used.
- oneToNTypeMapping.addInputs(origResultNo, origType);
- newOperands.push_back(returnOp.getOperand(origResultNo));
- continue;
- }
- llvm::errs() << "Got target shape\n";
- VectorType unrolledType =
- VectorType::get(*targetShape, origVecType.getElementType());
- llvm::errs() << "Unrolled type is ";
- unrolledType.dump();
-
- // Create `vector.extract_strided_slice` ops to form legal vectors from the
- // original operand of illegal type.
- SmallVector<int64_t> originalShape =
- llvm::to_vector<4>(origVecType.getShape());
- SmallVector<int64_t> strides(targetShape->size(), 1);
- SmallVector<Type> newTypes;
- Value returnValue = returnOp.getOperand(origResultNo);
- for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalShape, *targetShape)) {
- Value result = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, returnValue, offsets, *targetShape, strides);
- newOperands.push_back(result);
- newTypes.push_back(unrolledType);
+ FunctionType fnType = funcOp.getFunctionType();
+ OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+ Location loc = returnOp.getLoc();
+
+ // For the new return op.
+ SmallVector<Value> newOperands;
+
+ // Enumerate through the results.
+ for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
+ // Check whether the argument is of vector type.
+ auto origVecType = llvm::dyn_cast<VectorType>(origType);
+ if (!origVecType) {
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ // Check whether the vector needs unrolling.
+ auto targetShape = getTargetShape(origVecType);
+ if (!targetShape) {
+ // The original argument can be used.
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ VectorType unrolledType =
+ VectorType::get(*targetShape, origVecType.getElementType());
+
+ // Create `vector.extract_strided_slice` ops to form legal vectors from
+ // the original operand of illegal type.
+ SmallVector<int64_t> originalShape =
+ llvm::to_vector<4>(origVecType.getShape());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
+ Value returnValue = returnOp.getOperand(origResultNo);
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, returnValue, offsets, *targetShape, strides);
+ newOperands.push_back(result);
+ newTypes.push_back(unrolledType);
+ }
+ oneToNTypeMapping.addInputs(origResultNo, newTypes);
}
- oneToNTypeMapping.addInputs(origResultNo, newTypes);
- }
- llvm::errs() << "After enumerating through the arguments\n";
- funcOp.dump();
+ // Change the function signature.
+ auto newFnType =
+ FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+ TypeRange(oneToNTypeMapping.getConvertedTypes()));
+ rewriter.modifyOpInPlace(funcOp,
+ [&] { funcOp.setFunctionType(newFnType); });
- // Change the function signature.
- auto newFnType =
- FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
- TypeRange(oneToNTypeMapping.getConvertedTypes()));
- rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setFunctionType(newFnType); });
- llvm::errs() << "After changing function signature\n";
- funcOp.dump();
+ // Replace the return op using the new operands. This will automatically
+ // update the entry block as well.
+ rewriter.replaceOp(returnOp,
+ rewriter.create<func::ReturnOp>(loc, newOperands));
- // Replace the return op using the new operands. This will automatically
- // update the entry block as well.
- rewriter.replaceOp(returnOp,
- rewriter.create<func::ReturnOp>(loc, newOperands));
-
- return success();
-}
+ return success();
+ }
+};
+} // namespace
void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
More information about the Mlir-commits
mailing list