[Mlir-commits] [mlir] [mlir][amdgpu] Support for 8bit extf for 0d vector type (PR #126102)
Prashant Kumar
llvmlistbot at llvm.org
Thu Feb 6 10:00:34 PST 2025
https://github.com/pashu123 created https://github.com/llvm/llvm-project/pull/126102
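The ExtFOnFloat8RewritePattern rewrite crashes on 0-D vector types. This patch fixes the 0-D path by inserting the extended scalar into a 0-D splat vector of the result type (instead of into a scalar zero constant), and adds a 0-D test case.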
>From 6da4acde95f614d79e396881e78fb418f0aa4f90 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Thu, 6 Feb 2025 23:22:34 +0530
Subject: [PATCH] [mlir][amdgpu] Support for 8bit extf for 0d vector type
For 0d vector type the rewrite crashes.
---
.../lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 9 ++++++---
.../Conversion/ArithToAMDGPU/8-bit-floats.mlir | 14 +++++++++++++-
2 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 33370566996eee5..60a002c41bfb2f3 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -102,20 +102,23 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
return rewriter.replaceOp(op, result);
}
int64_t numElements = inType.getNumElements();
+
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ VectorType outType = cast<VectorType>(op.getOut().getType());
+
if (inType.getShape().empty()) {
+ Value zerodSplat =
+ rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
- // Recurse to send the 0-D vector case to the 1-D vector case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+ Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
- VectorType outType = cast<VectorType>(op.getOut().getType());
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index bd90facb6154408..985fb532ea74ad3 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -10,7 +10,19 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
return %w : f16
}
-// No 0-D test because arith.extf hasn't been extended to support it.
+// -----
+
+// CHECK-LABEL: func.func @vector_zero_d(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: vector<f8E5M2FNUZ>) -> vector<f32>
+// CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector<f8E5M2FNUZ>
+// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32
+// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector<f32>
+// CHECK: return %[[RESULT]] : vector<f32>
+func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
+ %w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
+ return %w : vector<f32>
+}
// -----
More information about the Mlir-commits mailing list