[Mlir-commits] [mlir] [MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as a single integer. (PR #111400)

Benoit Jacob llvmlistbot at llvm.org
Mon Oct 7 09:39:51 PDT 2024


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/111400

>From f0909e4992a9a979b06abfaea63fecbfc134b90b Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Mon, 7 Oct 2024 11:11:13 -0500
Subject: [PATCH] bitcast

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 40 ++++++-------------
 1 file changed, 12 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2b33f3773dc7d1..0ccd4133d3761d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
 ///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
-                                Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If `input` is a vector of bfloat16, bitcast it to a vector of i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                      Location loc, Value input) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-
-    if (!vectorType.getElementType().isInteger(8))
-      return input;
-    int64_t numBytes = vectorType.getNumElements();
-    Type destType = rewriter.getIntegerType(numBytes * 8);
-    Value result = rewriter.create<LLVM::ConstantOp>(
-        loc, destType, rewriter.getIntegerAttr(destType, 0));
-    for (int64_t i = 0; i < numBytes; ++i) {
-      Value idxConst = createI32Constant(rewriter, loc, i);
-      Value element =
-          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
-      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
-      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
-          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
-      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
-      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    if (vectorType.getElementType().isInteger(8)) {
+      return rewriter.create<LLVM::BitcastOp>(
+          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
     }
-    return result;
   }
   return input;
 }
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
-        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
-         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
          createI32Constant(rewriter, loc, op.getAbid()),
          createI32Constant(rewriter, loc, getBlgpField)});



More information about the Mlir-commits mailing list