[Mlir-commits] [mlir] 5ef7cea - [mlir][Vector] Significantly improve VectorToGPU.cpp
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Feb 14 16:49:42 PST 2023
Author: Nicolas Vasilache
Date: 2023-02-14T16:49:36-08:00
New Revision: 5ef7ceae5769e30aab78a9f1e7bf1e4e68e8b859
URL: https://github.com/llvm/llvm-project/commit/5ef7ceae5769e30aab78a9f1e7bf1e4e68e8b859
DIFF: https://github.com/llvm/llvm-project/commit/5ef7ceae5769e30aab78a9f1e7bf1e4e68e8b859.diff
LOG: [mlir][Vector] Significantly improve VectorToGPU.cpp
This revision performs a number of cleanups and ensures previously free-flowing IR mutations are tracked.
APIs are systematized around RewriterBase and relevant debug messages are added.
Deliberate use of OpBuilder::InsertionGuard is added where needed.
Differential Revision: https://reviews.llvm.org/D143738
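For readers skimming the diff: the conversion helpers below all converge on one shape. A minimal sketch of that shape (illustrative only, not code from the patch; `convertFooOp` is a made-up name):

  static LogicalResult
  convertFooOp(RewriterBase &rewriter, Operation *op,
               llvm::DenseMap<Value, Value> &valueMapping) {
    // Scope the insertion-point change to this helper.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(op);
    auto it = valueMapping.find(op->getOperand(0));
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no mapping\n");
      return rewriter.notifyMatchFailure(op, "no mapping");
    }
    // ... create replacement ops through `rewriter` so every IR mutation is
    // tracked, then record the results in `valueMapping`.
    return success();
  }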
Added:
Modified:
mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
index 6899134ec6b10..d8231fc5a10b5 100644
--- a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
+++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
@@ -12,6 +12,7 @@
#include "mlir/IR/PatternMatch.h"
namespace mlir {
+class LogicalResult;
class MLIRContext;
class Pass;
class RewritePatternSet;
@@ -29,13 +30,14 @@ void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
/// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
/// convert the slice of operations that can be legally converted to MMA
/// operations.
/// The rest of the vector operations are left untouched.
-void convertVectorToMMAOps(Operation *rootOp);
+LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp);
/// Convert vector ops nested under `rootOp` to vector and GPU operations
/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
/// of operations that can be legally lowered on this path while the rest of
/// the vector operations are left untouched.
-LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
+ Operation *rootOp);
/// Convert from vector to GPU ops.
std::unique_ptr<Pass> createConvertVectorToGPUPass(bool useNvGpu = false);
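A caller-side sketch of the updated entry points (assuming `funcOp` is the root operation to convert; this mirrors the pass driver further down in this patch):

  IRRewriter rewriter(funcOp->getContext());
  // Best-effort: a failure means some ops could not be converted.
  if (failed(convertVectorToMMAOps(rewriter, funcOp)))
    llvm::errs() << "some vector ops were not converted to MMA ops\n";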
diff --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
index fac99dc048dba..5880b09161c7c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -69,7 +69,7 @@ getMmaSyncRegisterType(const WarpMatrixInfo &type);
/// please see NVIDIA's PTX documentation:
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
FailureOr<AffineMap>
-getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
const WarpMatrixInfo &fragmentType);
/// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to
@@ -90,7 +90,7 @@ FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
/// to two results representing offsets within the matrix operand that should
/// be the pointer locations a thread should pass to the ldmatrix instruction.
FailureOr<AffineMap>
-getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
const LdMatrixParams ¶ms);
/// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
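With the (builder, loc, ...) argument order, call sites read like the following sketch (values assumed in scope; the same pattern appears verbatim in VectorToGPU.cpp below):

  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
      rewriter, loc, *warpMatrixInfo);
  if (failed(coords))
    return rewriter.notifyMatchFailure(op, "no coords");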
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index b0fa50d799160..0266ba139ec7d 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -26,11 +26,19 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#define DEBUG_TYPE "vector-to-gpu"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
@@ -45,7 +53,7 @@ using namespace mlir;
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
-static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
+static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
AffineMap offsetMap, ArrayRef<Value> dimValues,
SmallVector<Value, 4> &indices) {
indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
@@ -56,9 +64,9 @@ static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
Value prevIdx = indices[dim.getPosition()];
SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
dims.push_back(prevIdx);
- AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
+ AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
indices[dim.getPosition()] = makeComposedAffineApply(
- b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
+ rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
continue;
}
}
@@ -94,8 +102,10 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
-static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) {
- MLIRContext *ctx = b.getContext();
+static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
+ MLIRContext *ctx = permutationMap.getContext();
+ // A local OpBuilder is fine here; we only build attributes.
+ OpBuilder b(ctx);
auto nDim = permutationMap.getNumDims();
AffineExpr zero = b.getAffineConstantExpr(0);
if (nDim < 2) {
@@ -148,15 +158,16 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
return false;
AffineMap map = readOp.getPermutationMap();
- OpBuilder b(readOp.getContext());
- AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
- AffineExpr zero = b.getAffineConstantExpr(0);
- auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
- readOp.getContext());
+
+ MLIRContext *ctx = readOp.getContext();
+ AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
+ AffineExpr zero = getAffineConstantExpr(0, ctx);
+ auto broadcastInnerDim =
+ AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
if (!useNvGpu) {
bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
- isTransposeMatrixLoadMap(b, map);
+ isTransposeMatrixLoadMap(map);
return result;
}
@@ -383,14 +394,13 @@ struct PrepareContractToGPUMMA
if (!(vector::isParallelIterator(iteratorTypes[0]) &&
vector::isParallelIterator(iteratorTypes[1]) &&
vector::isReductionIterator(iteratorTypes[2])))
- return failure();
+ return rewriter.notifyMatchFailure(op, "not a gemm contraction");
//
// Two outer parallel, one inner reduction (matmat flavor).
//
- if (maps == infer({{m, k}, {k, n}, {m, n}})) {
- // This is the classical row-major matmul, nothing to do.
- return failure();
- }
+ // This is the classical row-major matmul, nothing to do.
+ if (maps == infer({{m, k}, {k, n}, {m, n}}))
+ return rewriter.notifyMatchFailure(op, "contraction already prepared");
if (maps == infer({{m, k}, {n, k}, {m, n}})) {
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
@@ -411,7 +421,8 @@ struct PrepareContractToGPUMMA
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
std::swap(lhs, rhs);
} else {
- return failure();
+ // TODO: llvm_unreachable ?
+ return rewriter.notifyMatchFailure(op, "unexpected contraction case");
}
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
op, lhs, rhs, res,
@@ -445,14 +456,15 @@ struct CombineTransferReadOpTranspose final
auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
- return failure();
+ return rewriter.notifyMatchFailure(op, "no transfer read");
// TODO: support 0-d corner case.
if (transferReadOp.getTransferRank() == 0)
- return failure();
+ return rewriter.notifyMatchFailure(op, "0-D transfer read");
if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
- return failure();
+ return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
+
SmallVector<int64_t, 2> perm;
op.getTransp(perm);
SmallVector<unsigned, 2> permU;
@@ -508,17 +520,24 @@ static const char *inferFragType(Operation *op) {
return "COp";
}
-static void convertTransferReadOp(vector::TransferReadOp op,
- llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
std::optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
+ if (!stride.has_value()) {
+ LLVM_DEBUG(DBGS() << "no stride\n");
+ return rewriter.notifyMatchFailure(op, "no stride");
+ }
AffineMap map = op.getPermutationMap();
- OpBuilder b(op);
- bool isTranspose = isTransposeMatrixLoadMap(b, map);
+ bool isTranspose = isTransposeMatrixLoadMap(map);
// Handle broadcast by setting the stride to 0.
if (auto cstExpr =
@@ -526,7 +545,7 @@ static void convertTransferReadOp(vector::TransferReadOp op,
assert(cstExpr.getValue() == 0);
stride = 0;
}
- assert(stride);
+
Value mappingResult = op.getResult();
auto elType = op.getVectorType().getElementType();
const char *fragType = inferFragType(op);
@@ -544,24 +563,47 @@ static void convertTransferReadOp(vector::TransferReadOp op,
}
gpu::MMAMatrixType type =
gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
- Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+ Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
op.getLoc(), type, op.getSource(), op.getIndices(),
- b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
+ rewriter.getIndexAttr(*stride),
+ isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
+
+ LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+ return success();
}
-static void convertTransferWriteOp(vector::TransferWriteOp op,
- llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
assert(transferWriteSupportsMMAMatrixType(op));
std::optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
- assert(stride);
- OpBuilder b(op);
- Value matrix = valueMapping.find(op.getVector())->second;
- b.create<gpu::SubgroupMmaStoreMatrixOp>(
+ if (!stride.has_value()) {
+ LLVM_DEBUG(DBGS() << "no stride\n");
+ return rewriter.notifyMatchFailure(op, "no stride");
+ }
+
+ auto it = valueMapping.find(op.getVector());
+ if (it == valueMapping.end()) {
+ LLVM_DEBUG(DBGS() << "no mapping\n");
+ return rewriter.notifyMatchFailure(op, "no mapping");
+ }
+
+ Value matrix = it->second;
+ auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
op.getLoc(), matrix, op.getSource(), op.getIndices(),
- b.getIndexAttr(*stride), /*transpose=*/UnitAttr());
- op.erase();
+ rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
+ (void)store;
+
+ LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+
+ LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ rewriter.eraseOp(op);
+ return success();
}
/// Returns the vector type which represents a matrix fragment.
@@ -577,24 +619,33 @@ getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) {
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
-convertConstantOpMmaSync(arith::ConstantOp op,
+convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
- OpBuilder b(op);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
- if (failed(warpMatrixInfo))
- return failure();
+ if (failed(warpMatrixInfo)) {
+ LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
+ }
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
- if (failed(regInfo))
- return failure();
+ if (failed(regInfo)) {
+ LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ return rewriter.notifyMatchFailure(op, "not mma sync reg info");
+ }
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
- if (!dense)
- return failure();
- Value result = b.create<arith::ConstantOp>(
+ if (!dense) {
+ LLVM_DEBUG(DBGS() << "not a splat\n");
+ return rewriter.notifyMatchFailure(op, "not a splat");
+ }
+
+ Value result = rewriter.create<arith::ConstantOp>(
op.getLoc(), vectorType,
DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
valueMapping[op.getResult()] = result;
@@ -602,43 +653,54 @@ convertConstantOpMmaSync(arith::ConstantOp op,
}
static LogicalResult
-creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
+creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
Location loc = op->getLoc();
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
- if (failed(warpMatrixInfo))
- return failure();
+ if (failed(warpMatrixInfo)) {
+ LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
+ }
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
- if (failed(regInfo))
- return failure();
+ if (failed(regInfo)) {
+ LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ return rewriter.notifyMatchFailure(op, "not mma sync reg info");
+ }
FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
*warpMatrixInfo,
/*transpose=*/!op.getPermutationMap().isMinorIdentity());
if (failed(params)) {
- return op->emitError()
- << "failed to convert vector.transfer_read to ldmatrix; this op "
- "likely "
- "should not be converted to a nvgpu.ldmatrix call.";
+ LLVM_DEBUG(
+ DBGS()
+ << "failed to convert vector.transfer_read to ldmatrix. "
+ << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert vector.transfer_read to ldmatrix; this op "
+ "likely should not be converted to a nvgpu.ldmatrix call.");
}
// Adjust the load offset.
- auto laneId = builder.create<gpu::LaneIdOp>(loc);
+ auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
FailureOr<AffineMap> offsets =
- nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
- if (failed(offsets))
- return failure();
+ nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
+ if (failed(offsets)) {
+ LLVM_DEBUG(DBGS() << "no offsets\n");
+ return rewriter.notifyMatchFailure(op, "no offsets");
+ }
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
SmallVector<Value, 4> indices;
- getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
+ getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
indices);
- nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
+ nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
loc, vectorType, op.getSource(), indices,
!op.getPermutationMap().isMinorIdentity(), params->numTiles);
valueMapping[op] = newOp->getResult(0);
@@ -646,32 +708,36 @@ creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
}
static LogicalResult
-createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
+createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
Location loc = op.getLoc();
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- op->emitError() << "Failed to deduce register fragment type during "
- "conversion to distributed non-ldmatrix compatible load";
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "Failed to deduce register fragment type during "
+ "conversion to distributed non-ldmatrix compatible load");
}
- Value laneId = builder.create<gpu::LaneIdOp>(loc);
+ Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
SmallVector<Value, 4> elements;
// This is the individual element type.
Type loadedElType = regInfo->registerLLVMType;
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
- Value fill = builder.create<arith::ConstantOp>(
+ Value fill = rewriter.create<arith::ConstantOp>(
op.getLoc(), vectorType.getElementType(),
- builder.getZeroAttr(vectorType.getElementType()));
- Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+ rewriter.getZeroAttr(vectorType.getElementType()));
+ Value result =
+ rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
@@ -684,20 +750,21 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
for (int i = 0; i < vectorType.getShape()[0]; i++) {
FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
- op.getLoc(), builder, *warpMatrixInfo);
+ rewriter, op.getLoc(), *warpMatrixInfo);
if (failed(coords))
- return failure();
- Value logicalValueId = builder.create<arith::ConstantOp>(
- loc, builder.getIndexType(),
- builder.getIndexAttr(i * regInfo->elementsPerRegister));
+ return rewriter.notifyMatchFailure(op, "no coords");
+
+ Value logicalValueId = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferReadOp>(
- builder, op, *coords, {laneId, logicalValueId}, newIndices);
+ rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
- Value el = builder.create<vector::LoadOp>(loc, loadedElType,
- op.getSource(), newIndices);
- result = builder.create<vector::InsertOp>(loc, el, result,
- builder.getI64ArrayAttr(i));
+ Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
+ op.getSource(), newIndices);
+ result = rewriter.create<vector::InsertOp>(loc, el, result,
+ rewriter.getI64ArrayAttr(i));
}
} else {
if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
@@ -707,21 +774,21 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
innerIdx++) {
- Value logicalValueId = builder.create<arith::ConstantOp>(
- loc, builder.getIndexType(),
- builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
+ Value logicalValueId = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
- op.getLoc(), builder, *warpMatrixInfo);
+ rewriter, op.getLoc(), *warpMatrixInfo);
if (failed(coords))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no coords");
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferReadOp>(
- builder, op, *coords, {laneId, logicalValueId}, newIndices);
- Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
- op.getSource(), newIndices);
- result = builder.create<vector::InsertOp>(
- op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
+ rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
+ Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
+ op.getSource(), newIndices);
+ result = rewriter.create<vector::InsertOp>(
+ op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
}
}
}
@@ -744,14 +811,15 @@ static bool isSharedMemory(MemRefType type) {
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
-convertTransferReadToLoads(vector::TransferReadOp op,
+convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
- OpBuilder b(op);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
bool isLdMatrixCompatible =
isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
@@ -769,46 +837,54 @@ convertTransferReadToLoads(vector::TransferReadOp op,
isLdMatrixCompatible = false;
if (!isLdMatrixCompatible)
- return createNonLdMatrixLoads(op, b, valueMapping);
+ return createNonLdMatrixLoads(rewriter, op, valueMapping);
- return creatLdMatrixCompatibleLoads(op, b, valueMapping);
+ return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}
static LogicalResult
-convertTransferWriteToStores(vector::TransferWriteOp op,
+convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
- OpBuilder b(op);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
Location loc = op->getLoc();
- Value matrix = valueMapping.find(op.getVector())->second;
+ auto it = valueMapping.find(op.getVector());
+ if (it == valueMapping.end())
+ return rewriter.notifyMatchFailure(op, "no mapping");
+ Value matrix = it->second;
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "not mma sync reg info");
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
- Value laneId = b.create<gpu::LaneIdOp>(loc);
+ Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
- Value logicalValueId = b.create<arith::ConstantOp>(
- loc, b.getIndexType(),
- b.getIndexAttr(i * regInfo->elementsPerRegister));
+ Value logicalValueId = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
- op.getLoc(), b, *warpMatrixInfo);
+ rewriter, op.getLoc(), *warpMatrixInfo);
if (failed(coords))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no coords");
- Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+ Value el =
+ rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
SmallVector<Value, 4> newIndices;
getXferIndices<vector::TransferWriteOp>(
- b, op, *coords, {laneId, logicalValueId}, newIndices);
- b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
+ rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
+ rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
}
- op->erase();
+
+ LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ rewriter.eraseOp(op);
return success();
}
@@ -819,35 +895,37 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
}
static LogicalResult
-convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
+convertExtractStridedSlice(RewriterBase &rewriter,
+ vector::ExtractStridedSliceOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
- OpBuilder b(op);
Location loc = op->getLoc();
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(mmaSyncFragmentInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
// Find the vector.transfer_read whose result vector is being sliced.
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
- return failure();
+ return rewriter.notifyMatchFailure(op, "no transfer read");
warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
if (failed(warpMatrixInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(ldFragmentInfo))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
assert(
(mmaSyncFragmentInfo->elementsPerRegister ==
@@ -860,7 +938,10 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
std::array<int64_t, 2> sliceShape = {
mmaSyncFragmentInfo->numRegistersPerFragment,
mmaSyncFragmentInfo->elementsPerRegister};
- auto sourceVector = valueMapping.find(transferReadOp)->second;
+ auto it = valueMapping.find(transferReadOp);
+ if (it == valueMapping.end())
+ return rewriter.notifyMatchFailure(op, "no mapping");
+ auto sourceVector = it->second;
// offsets and sizes at the warp level of ownership.
SmallVector<int64_t> offsets;
@@ -882,86 +963,114 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
else if (offsets[1])
sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
- Value newOp = b.create<vector::ExtractStridedSliceOp>(
+ Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
loc, sourceVector, sliceOffset, sliceShape, strides);
valueMapping[op] = newOp;
return success();
}
-static void convertContractOp(vector::ContractionOp op,
- llvm::DenseMap<Value, Value> &valueMapping) {
- OpBuilder b(op);
- Value opA = valueMapping.find(op.getLhs())->second;
- Value opB = valueMapping.find(op.getRhs())->second;
- Value opC = valueMapping.find(op.getAcc())->second;
- Value matmul = b.create<gpu::SubgroupMmaComputeOp>(
+static LogicalResult
+convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
+ auto itA = valueMapping.find(op.getLhs());
+ auto itB = valueMapping.find(op.getRhs());
+ auto itC = valueMapping.find(op.getAcc());
+ if (itA == valueMapping.end() || itB == valueMapping.end() ||
+ itC == valueMapping.end())
+ return rewriter.notifyMatchFailure(op, "no mapping");
+ Value opA = itA->second, opB = itB->second, opC = itC->second;
+ Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
/*b_transpose=*/UnitAttr());
valueMapping[op.getResult()] = matmul;
+ return success();
}
static LogicalResult
-convertContractOpToMmaSync(vector::ContractionOp op,
+convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
- OpBuilder b(op);
- Value opA = valueMapping.find(op.getLhs())->second;
- Value opB = valueMapping.find(op.getRhs())->second;
- Value opC = valueMapping.find(op.getAcc())->second;
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
+ auto itA = valueMapping.find(op.getLhs());
+ auto itB = valueMapping.find(op.getRhs());
+ auto itC = valueMapping.find(op.getAcc());
+ if (itA == valueMapping.end() || itB == valueMapping.end() ||
+ itC == valueMapping.end())
+ return rewriter.notifyMatchFailure(op, "no mapping");
+ Value opA = itA->second, opB = itB->second, opC = itC->second;
int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
- Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
- b.getI64ArrayAttr({m, n, k}));
+ Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
+ op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
valueMapping[op.getResult()] = matmul;
return success();
}
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
-static void convertConstantOp(arith::ConstantOp op,
- llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
assert(constantSupportsMMAMatrixType(op));
- OpBuilder b(op);
+
auto splat =
op.getValue().cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
auto scalarConstant =
- b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
+ rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
const char *fragType = inferFragType(op);
auto vecType = op.getType().cast<VectorType>();
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
- auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
- scalarConstant);
+ auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
+ op.getLoc(), type, scalarConstant);
valueMapping[op.getResult()] = matrix;
+ return success();
}
/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
-static void convertBroadcastOp(vector::BroadcastOp op,
- llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
assert(broadcastSupportsMMAMatrixType(op));
- OpBuilder b(op);
+
const char *fragType = inferFragType(op);
auto vecType = op.getVectorType();
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
- auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
- op.getSource());
+ auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
+ op.getLoc(), type, op.getSource());
valueMapping[op.getResult()] = matrix;
+ return success();
}
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
-static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
+static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
+ scf::ForOp loop,
ValueRange newIterOperands) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loop);
+
// Create a new loop before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(loop);
+ rewriter.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getIterOperands());
operands.append(newIterOperands.begin(), newIterOperands.end());
- scf::ForOp newLoop =
- b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
- loop.getUpperBound(), loop.getStep(), operands);
+ scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+ loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+ operands);
newLoop.getBody()->erase();
+
newLoop.getLoopBody().getBlocks().splice(
newLoop.getLoopBody().getBlocks().begin(),
loop.getLoopBody().getBlocks());
@@ -970,25 +1079,35 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
loop.getNumResults())))
- std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
- loop.erase();
+ rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+
+ LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
+ LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
+ LLVM_DEBUG(DBGS() << "erase: " << loop);
+
+ rewriter.eraseOp(loop);
return newLoop;
}
-static void convertForOp(scf::ForOp op,
- llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
SmallVector<Value> newOperands;
SmallVector<std::pair<size_t, size_t>> argMapping;
for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
auto it = valueMapping.find(operand.value());
- if (it == valueMapping.end())
+ if (it == valueMapping.end()) {
+ LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
continue;
+ }
argMapping.push_back(std::make_pair(
operand.index(), op.getNumIterOperands() + newOperands.size()));
newOperands.push_back(it->second);
}
- OpBuilder b(op);
- scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
+
+ scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
Block &loopBody = *newForOp.getBody();
for (auto mapping : argMapping) {
valueMapping[newForOp.getResult(mapping.first)] =
@@ -997,11 +1116,17 @@ static void convertForOp(scf::ForOp op,
newForOp.getNumInductionVars())] =
loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
}
+
+ LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+ return success();
}
-static void convertYieldOp(scf::YieldOp op,
- llvm::DenseMap<Value, Value> &valueMapping) {
- OpBuilder b(op);
+static LogicalResult
+convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
auto loop = cast<scf::ForOp>(op->getParentOp());
auto yieldOperands = llvm::to_vector<4>(op.getOperands());
for (const auto &operand : llvm::enumerate(op.getOperands())) {
@@ -1013,20 +1138,32 @@ static void convertYieldOp(scf::YieldOp op,
yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
yieldOperands.push_back(it->second);
}
- b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
- op.erase();
+ rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
+
+ LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ rewriter.eraseOp(op);
+ return success();
}
/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
-static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
- llvm::DenseMap<Value, Value> &valueMapping) {
- OpBuilder b(op);
+static LogicalResult
+convertElementwiseOp(RewriterBase &rewriter, Operation *op,
+ gpu::MMAElementwiseOp opType,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+
SmallVector<Value> matrixOperands;
- for (Value operand : op->getOperands())
- matrixOperands.push_back(valueMapping.find(operand)->second);
- Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
+ for (Value operand : op->getOperands()) {
+ auto it = valueMapping.find(operand);
+ if (it == valueMapping.end())
+ return rewriter.notifyMatchFailure(op, "no mapping");
+ matrixOperands.push_back(it->second);
+ }
+ Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
valueMapping[op->getResult(0)] = newOp;
+ return success();
}
void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
@@ -1041,67 +1178,75 @@ void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
patterns.getContext());
}
-void mlir::convertVectorToMMAOps(Operation *rootOp) {
+LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
+ Operation *rootOp) {
SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
llvm::DenseMap<Value, Value> valueMapping;
+
+ auto globalRes = LogicalResult::success();
for (Operation *op : ops) {
+ LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+ // Apparently callers do not want to early exit on failure here.
+ auto res = LogicalResult::success();
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
- convertTransferReadOp(transferRead, valueMapping);
+ res = convertTransferReadOp(rewriter, transferRead, valueMapping);
} else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
- convertTransferWriteOp(transferWrite, valueMapping);
+ res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
} else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
- convertContractOp(contractOp, valueMapping);
+ res = convertContractOp(rewriter, contractOp, valueMapping);
} else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
- convertConstantOp(constantOp, valueMapping);
+ res = convertConstantOp(rewriter, constantOp, valueMapping);
} else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
- convertBroadcastOp(broadcastOp, valueMapping);
+ res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
- convertForOp(forOp, valueMapping);
- } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
- convertYieldOp(yiledOp, valueMapping);
+ res = convertForOp(rewriter, forOp, valueMapping);
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+ res = convertYieldOp(rewriter, yieldOp, valueMapping);
} else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
- convertElementwiseOp(op, *elementwiseType, valueMapping);
+ res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
}
+ if (failed(res))
+ globalRes = failure();
}
+ return globalRes;
}
-LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
+LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
+ Operation *rootOp) {
SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
llvm::DenseMap<Value, Value> valueMapping;
for (Operation *op : ops) {
if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
.Case([&](vector::TransferReadOp transferReadOp) {
- return convertTransferReadToLoads(transferReadOp, valueMapping);
+ return convertTransferReadToLoads(rewriter, transferReadOp,
+ valueMapping);
})
.Case([&](vector::TransferWriteOp transferWriteOp) {
- return convertTransferWriteToStores(transferWriteOp,
+ return convertTransferWriteToStores(rewriter, transferWriteOp,
valueMapping);
})
.Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
- return convertExtractStridedSlice(extractStridedSliceOp,
+ return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
valueMapping);
})
.Case([&](vector::ContractionOp contractionOp) {
- return convertContractOpToMmaSync(contractionOp, valueMapping);
+ return convertContractOpToMmaSync(rewriter, contractionOp,
+ valueMapping);
})
.Case([&](scf::ForOp forOp) {
- convertForOp(forOp, valueMapping);
- return success();
+ return convertForOp(rewriter, forOp, valueMapping);
})
.Case([&](scf::YieldOp yieldOp) {
- convertYieldOp(yieldOp, valueMapping);
- return success();
+ return convertYieldOp(rewriter, yieldOp, valueMapping);
})
.Case([&](arith::ConstantOp constOp) {
- return convertConstantOpMmaSync(constOp, valueMapping);
+ return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
})
.Default([&](Operation *op) {
- op->emitError() << "unhandled vector to mma type: " << *op;
- return failure();
+ return op->emitError() << "unhandled vector to mma type: " << *op;
})
.failed()) {
- op->emitError() << "Failed to convert op " << *op;
- return failure();
+ return op->emitError() << "Failed to convert op " << *op;
}
}
return success();
@@ -1123,12 +1268,13 @@ struct ConvertVectorToGPUPass
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
+ IRRewriter rewriter(&getContext());
if (useNvGpu.getValue()) {
- if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
+ if (failed(
+ convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
return signalPassFailure();
}
-
- (void)convertVectorToMMAOps(getOperation());
+ (void)convertVectorToMMAOps(rewriter, getOperation());
}
};
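Note the deliberately best-effort semantics of convertVectorToMMAOps above: per-op failures are remembered but do not abort the walk. Distilled (illustrative sketch; `convertOneOp` stands in for the dispatch chain, and `ops`, `rewriter`, and `valueMapping` are assumed in scope):

  LogicalResult globalRes = success();
  for (Operation *op : ops) {
    // Keep converting after a failure; report the failure at the end.
    if (failed(convertOneOp(rewriter, op, valueMapping)))
      globalRes = failure();
  }
  return globalRes;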
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 6fdaaad746ec1..44f9b6d4ea012 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -170,7 +170,7 @@ static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
}
FailureOr<AffineMap>
-nvgpu::getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
const WarpMatrixInfo &fragmentType) {
Type elementType = fragmentType.vectorType.getElementType();
ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
@@ -235,7 +235,7 @@ nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
}
FailureOr<AffineMap>
-nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
const LdMatrixParams ¶ms) {
// One thread per 128b row.
const int bitsPerElement = static_cast<int>(
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index c742150401d8e..0ba9eb40483bd 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" --split-input-file | FileCheck %s
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -25,6 +25,15 @@ func.func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: mem
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_cst
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -43,6 +52,15 @@ func.func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2:
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_broadcast
// CHECK-SAME: (%{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %[[F:.*]]: f16)
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[F]] : !gpu.mma_matrix<16x16xf16, "COp">
@@ -61,6 +79,15 @@ func.func @matmul_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_loop
// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
@@ -86,6 +113,15 @@ func.func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_fused_elementwise
// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f16
@@ -109,6 +145,15 @@ func.func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x1
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_fused_broadcast
// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -134,6 +179,15 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_3Dmemref
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -153,6 +207,15 @@ func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %a
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_memref_strided
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 32 : index} : memref<2x16x16xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -172,6 +235,15 @@ func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1,
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_transposed
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -190,6 +262,15 @@ func.func @matmul_transposed(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_transposed_broadcasted_1d
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index, transpose} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -208,6 +289,15 @@ func.func @matmul_transposed_broadcasted_1d(%arg0: memref<16xf16>, %arg1: memref
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_transposed_broadcasted_2d
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index, transpose} : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index} : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -226,12 +316,25 @@ func.func @matmul_transposed_broadcasted_2d(%arg0: memref<32x32xf16>, %arg1: mem
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-DAG: #[[$map:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
// Do not convert to subgroup_mma ops with integer types if signedness cannot be inferred.
// CHECK-LABEL: func @matmul_no_extend_int8
// CHECK-DAG: %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
// CHECK-DAG: %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
// CHECK-DAG: %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
-// CHECK: %[[D:.+]] = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
+// CHECK: %[[D:.+]] = vector.contract {indexing_maps = [#[[$map]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
%cst_0 = arith.constant dense<0> : vector<16x16xi8>
@@ -246,6 +349,15 @@ func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_int8
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
@@ -267,6 +379,15 @@ func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2:
return
}
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
// CHECK-LABEL: func @matmul_mixed_signedness_int8
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xui8, "AOp">
// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">