[Mlir-commits] [mlir] [MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as single integer. (PR #111400)

Benoit Jacob llvmlistbot at llvm.org
Mon Oct 7 11:14:14 PDT 2024


================
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
 ///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
-                                Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                      Location loc, Value input) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-
-    if (!vectorType.getElementType().isInteger(8))
-      return input;
-    int64_t numBytes = vectorType.getNumElements();
-    Type destType = rewriter.getIntegerType(numBytes * 8);
-    Value result = rewriter.create<LLVM::ConstantOp>(
-        loc, destType, rewriter.getIntegerAttr(destType, 0));
-    for (int64_t i = 0; i < numBytes; ++i) {
-      Value idxConst = createI32Constant(rewriter, loc, i);
-      Value element =
-          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
-      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
-      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
-          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
-      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
-      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    if (vectorType.getElementType().isInteger(8)) {
+      return rewriter.create<LLVM::BitcastOp>(
+          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
     }
-    return result;
   }
   return input;
----------------
bjacob wrote:

Good point! What I mean is that we sometimes want to bitcast and sometimes not: we don't bitcast f32 or f16 vectors. So even after the bf16 simplification you mentioned, we still need some logic based on the element type. I'd rather leave any further simplification to you as a follow-up, since you were aware of things like this bf16 simplification and I wasn't.

https://github.com/llvm/llvm-project/pull/111400

