[Mlir-commits] [mlir] [mlir][amdgpu] Support for 8bit extf for 0d vector type (PR #126102)

Prashant Kumar llvmlistbot at llvm.org
Thu Feb 6 10:00:34 PST 2025


https://github.com/pashu123 created https://github.com/llvm/llvm-project/pull/126102

For 0-D vector input types the rewrite crashes: the extended scalar was being inserted into the scalar zero constant instead of into a 0-D vector of the result type. Splat the zero to the 0-D result vector type before inserting.
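
A minimal reproducer (the same 0-D case added as a test in this patch); running the arith-to-amdgpu conversion on it previously crashed:

  func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
    %w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
    return %w : vector<f32>
  }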

>From 6da4acde95f614d79e396881e78fb418f0aa4f90 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Thu, 6 Feb 2025 23:22:34 +0530
Subject: [PATCH] [mlir][amdgpu] Support for 8bit extf for 0d vector type

For 0-D vector input types the rewrite crashes: the extended scalar was being inserted into the scalar zero constant instead of into a 0-D vector of the result type. Splat the zero to the 0-D result vector type before inserting.
---
 .../lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp |  9 ++++++---
 .../Conversion/ArithToAMDGPU/8-bit-floats.mlir     | 14 +++++++++++++-
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 33370566996eee5..60a002c41bfb2f3 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -102,20 +102,23 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
     return rewriter.replaceOp(op, result);
   }
   int64_t numElements = inType.getNumElements();
+
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  VectorType outType = cast<VectorType>(op.getOut().getType());
+
   if (inType.getShape().empty()) {
+    Value zerodSplat =
+        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
-    // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarExt =
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
                                                      ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
 
-  VectorType outType = cast<VectorType>(op.getOut().getType());
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outType.getElementType());
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index bd90facb6154408..985fb532ea74ad3 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -10,7 +10,19 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
   return %w : f16
 }
 
-// No 0-D test because arith.extf hasn't been extended to support it.
+// -----
+
+// CHECK-LABEL: func.func @vector_zero_d(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: vector<f8E5M2FNUZ>) -> vector<f32>
+// CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector<f8E5M2FNUZ>
+// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32
+// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector<f32>
+// CHECK: return %[[RESULT]] : vector<f32>
+func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
+  %w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
+  return %w : vector<f32>
+}
 
 // -----
 


