[Mlir-commits] [mlir] [MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as single integer. (PR #111400)

Benoit Jacob llvmlistbot at llvm.org
Mon Oct 7 10:47:39 PDT 2024


================
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
 ///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
-                                Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                      Location loc, Value input) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-
-    if (!vectorType.getElementType().isInteger(8))
-      return input;
-    int64_t numBytes = vectorType.getNumElements();
-    Type destType = rewriter.getIntegerType(numBytes * 8);
-    Value result = rewriter.create<LLVM::ConstantOp>(
-        loc, destType, rewriter.getIntegerAttr(destType, 0));
-    for (int64_t i = 0; i < numBytes; ++i) {
-      Value idxConst = createI32Constant(rewriter, loc, i);
-      Value element =
-          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
-      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
-      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
-          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
-      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
-      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    if (vectorType.getElementType().isInteger(8)) {
+      return rewriter.create<LLVM::BitcastOp>(
+          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
     }
-    return result;
   }
   return input;
----------------
bjacob wrote:

I considered that, but we probably want to avoid creating identity bitcasts unconditionally. And once the bitcast is created only when it is actually needed, the code ends up looking much like its current form.
Feel free to improve this code further in a follow-up!
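
For context on why the patch can swap the removed extract/zext/shl/or loop for a single bitcast, here is a minimal host-side sketch in plain C++. It is not MLIR code and not part of the patch. The function names are hypothetical, the vector length is fixed at 4 for illustration, and `std::memcpy` stands in for `LLVM::BitcastOp`. It assumes a little-endian host, mirroring the little-endian byte order the removed comment describes.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Models the removed lowering: zero-extend each byte, shift it to bit
// position 8 * i, and OR it into the accumulated result.
// (Hypothetical helper name, used only for this illustration.)
inline uint32_t concatBytesLittleEndian(const std::array<uint8_t, 4> &bytes) {
  uint32_t result = 0;
  for (std::size_t i = 0; i < bytes.size(); ++i)
    result |= static_cast<uint32_t>(bytes[i]) << (8 * i);
  return result;
}

// Models the new lowering: reinterpret the same 4 bytes as one integer.
// Agrees with concatBytesLittleEndian only on a little-endian host
// (assumption stated in the text above).
inline uint32_t bitcastBytes(const std::array<uint8_t, 4> &bytes) {
  uint32_t result;
  std::memcpy(&result, bytes.data(), sizeof(result));
  return result;
}

// Asserts that both models agree for one input.
inline void checkLoweringsAgree(const std::array<uint8_t, 4> &bytes) {
  assert(concatBytesLittleEndian(bytes) == bitcastBytes(bytes));
}
```

Calling `checkLoweringsAgree({0x11, 0x22, 0x33, 0x44})` on a little-endian machine passes, with both functions producing `0x44332211`: byte 0 lands in the least significant position either way.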

https://github.com/llvm/llvm-project/pull/111400

