[Mlir-commits] [mlir] [MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as single integer. (PR #111400)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 7 09:43:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Benoit Jacob (bjacob)
<details>
<summary>Changes</summary>
Found by inspecting AMDGPU assembly: the arithmetic ops created by this helper were definitely making their way into the target ISA. An `LLVM::BitcastOp` seems equivalent, and evaporates as expected in the target asm.
Along the way, I thought that the helper function `mfmaConcatIfNeeded` could be renamed to `convertMFMAVectorOperand` to better convey its contract, so that I don't need to think about whether a bitcast is a legitimate "concat" :-)
---
Full diff: https://github.com/llvm/llvm-project/pull/111400.diff
1 Files Affected:
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+12-28)
``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2b33f3773dc7d1..0ccd4133d3761d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
} // namespace
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
- Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16())
return rewriter.create<LLVM::BitcastOp>(
loc, vectorType.clone(rewriter.getI16Type()), input);
-
- if (!vectorType.getElementType().isInteger(8))
- return input;
- int64_t numBytes = vectorType.getNumElements();
- Type destType = rewriter.getIntegerType(numBytes * 8);
- Value result = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, 0));
- for (int64_t i = 0; i < numBytes; ++i) {
- Value idxConst = createI32Constant(rewriter, loc, i);
- Value element =
- rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
- Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
- Value shiftConst = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, i * 8));
- Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
- result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+ if (vectorType.getElementType().isInteger(8)) {
+ return rewriter.create<LLVM::BitcastOp>(
+ loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
}
- return result;
}
return input;
}
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
- {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
- mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+ {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
createI32Constant(rewriter, loc, op.getAbid()),
createI32Constant(rewriter, loc, getBlgpField)});
``````````
</details>
https://github.com/llvm/llvm-project/pull/111400
More information about the Mlir-commits
mailing list