[Mlir-commits] [mlir] [MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as a single integer. (PR #111400)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 7 09:43:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Benoit Jacob (bjacob)

<details>
<summary>Changes</summary>

Found by inspecting AMDGPU assembly: the arithmetic ops created there were definitely making their way into the target ISA. An `LLVM::BitcastOp` seems equivalent, and evaporates as expected in the target asm.

Along the way, I thought that the helper function `mfmaConcatIfNeeded` could be renamed to `convertMFMAVectorOperand` to better convey its contract, so that I don't need to think about whether a bitcast is a legitimate "concat" :-)

---
Full diff: https://github.com/llvm/llvm-project/pull/111400.diff


1 Files Affected:

- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+12-28) 


``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2b33f3773dc7d1..0ccd4133d3761d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
 ///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
-                                Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                      Location loc, Value input) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-
-    if (!vectorType.getElementType().isInteger(8))
-      return input;
-    int64_t numBytes = vectorType.getNumElements();
-    Type destType = rewriter.getIntegerType(numBytes * 8);
-    Value result = rewriter.create<LLVM::ConstantOp>(
-        loc, destType, rewriter.getIntegerAttr(destType, 0));
-    for (int64_t i = 0; i < numBytes; ++i) {
-      Value idxConst = createI32Constant(rewriter, loc, i);
-      Value element =
-          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
-      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
-      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
-          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
-      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
-      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    if (vectorType.getElementType().isInteger(8)) {
+      return rewriter.create<LLVM::BitcastOp>(
+          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
     }
-    return result;
   }
   return input;
 }
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
-        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
-         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
          createI32Constant(rewriter, loc, op.getAbid()),
          createI32Constant(rewriter, loc, getBlgpField)});

``````````

</details>


https://github.com/llvm/llvm-project/pull/111400


More information about the Mlir-commits mailing list