[Mlir-commits] [mlir] [mlir][AMDGPU] Add int4 intrinsics, mixed-type fp8 to handle gfx12 (PR #128963)

Daniel Hernandez-Juarez llvmlistbot at llvm.org
Thu Feb 27 09:05:14 PST 2025


================
@@ -427,25 +434,33 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
   // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
   // fp8/int8 information is lost during the conversion process.
   auto mlirInputType = cast<VectorType>(mlirInput.getType());
-  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
-  if (isInputInt8) {
+  bool isInputInteger = mlirInputType.getElementType().isInteger();
+  if (isInputInteger) {
     // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
     bool localIsUnsigned = isUnsigned;
-    if (elemType.isUnsignedInteger(8)) {
+    if (elemType.isUnsignedInteger()) {
       localIsUnsigned = true;
-    } else if (elemType.isSignedInteger(8)) {
+    } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
     Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
     operands.push_back(sign);
   }
 
-  int64_t numBytes = vectorType.getNumElements();
+  int64_t numBits =
+      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
   Type i32 = rewriter.getI32Type();
-  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
-  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+  Type intrinsicInType = numBits <= 32
+                             ? (Type)rewriter.getIntegerType(numBits)
+                             : (Type)VectorType::get(numBits / 32, i32);
+  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
   Value result = rewriter.createOrFold<LLVM::BitcastOp>(
----------------
dhernandez0 wrote:

nit: this value is an operand, not a result, so naming it "result" is confusing.

https://github.com/llvm/llvm-project/pull/128963


More information about the Mlir-commits mailing list