[Mlir-commits] [mlir] [MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as single integer. (PR #111400)
Matt Arsenault
llvmlistbot at llvm.org
Mon Oct 7 10:45:42 PDT 2024
================
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
} // namespace
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
- Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16())
return rewriter.create<LLVM::BitcastOp>(
loc, vectorType.clone(rewriter.getI16Type()), input);
-
- if (!vectorType.getElementType().isInteger(8))
- return input;
- int64_t numBytes = vectorType.getNumElements();
- Type destType = rewriter.getIntegerType(numBytes * 8);
- Value result = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, 0));
- for (int64_t i = 0; i < numBytes; ++i) {
- Value idxConst = createI32Constant(rewriter, loc, i);
- Value element =
- rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
- Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
- Value shiftConst = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, i * 8));
- Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
- result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+ if (vectorType.getElementType().isInteger(8)) {
+ return rewriter.create<LLVM::BitcastOp>(
+ loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
}
- return result;
}
return input;
----------------
arsenm wrote:
This whole function can just create one bitcast? I don't see why you need to consider the element types, especially since bf16 should be natively consumed now.
https://github.com/llvm/llvm-project/pull/111400
More information about the Mlir-commits
mailing list