[Mlir-commits] [mlir] f35318e - [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (#98074)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 9 15:00:01 PDT 2024
Author: Rob Suderman
Date: 2024-07-09T14:59:58-07:00
New Revision: f35318e828d166bee04a117c1a87c1c0e32a57da
URL: https://github.com/llvm/llvm-project/commit/f35318e828d166bee04a117c1a87c1c0e32a57da
DIFF: https://github.com/llvm/llvm-project/commit/f35318e828d166bee04a117c1a87c1c0e32a57da.diff
LOG: [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (#98074)
The existing `fp8` lowering from `arith` to `amdgpu` bails out on
multi-dimensional vectors. We can handle this case by collapsing the
input to 1-D with `vector.shape_cast` before extraction and then casting
the result back to the desired output shape.
Added:
Modified:
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 58764ad38e34f..b3798a3f7624b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -67,9 +67,6 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
if (auto inVecType = dyn_cast<VectorType>(inType)) {
if (inVecType.isScalable())
return failure();
- if (inVecType.getShape().size() > 1)
- // Multi-dimensional vectors are currently unsupported.
- return failure();
inType = inVecType.getElementType();
}
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
@@ -80,28 +77,38 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- if (!isa<VectorType>(in.getType())) {
+ auto inType = dyn_cast<VectorType>(in.getType());
+ if (!inType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
return rewriter.replaceOp(op, result);
}
- VectorType inType = cast<VectorType>(in.getType());
int64_t numElements = inType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result =
- rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
if (inType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
- ArrayRef<int64_t>{});
+ Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
+
+ VectorType outType = cast<VectorType>(op.getOut().getType());
+ VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
+ outType.getElementType());
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+ if (inType.getRank() > 1) {
+ inType = VectorType::get(SmallVector<int64_t>{numElements},
+ inType.getElementType());
+ in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+ }
+
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -113,6 +120,11 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
}
}
+
+ if (inType.getRank() != outType.getRank()) {
+ result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ }
+
rewriter.replaceOp(op, result);
}
@@ -181,9 +193,6 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (auto outVecType = dyn_cast<VectorType>(outType)) {
if (outVecType.isScalable())
return failure();
- if (outVecType.getShape().size() > 1)
- // Multi-dimensional vectors are currently unsupported.
- return failure();
outType = outVecType.getElementType();
}
auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
@@ -200,8 +209,9 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
+ auto inVectorTy = dyn_cast<VectorType>(in.getType());
VectorType truncResType = VectorType::get(4, outElemType);
- if (!isa<VectorType>(in.getType())) {
+ if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
@@ -213,18 +223,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
int64_t numElements = outType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
if (outType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
- ArrayRef<int64_t>{});
+ Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
+ VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
+ outType.getElementType());
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+ if (inVectorTy.getRank() > 1) {
+ inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
+ inVectorTy.getElementType());
+ in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ }
+
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
@@ -245,6 +264,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
result, i, 1);
}
+
+ if (inVectorTy.getRank() != outType.getRank()) {
+ result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ }
+
rewriter.replaceOp(op, result);
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 159a2f02f0560..26a222a4a788e 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -115,3 +115,61 @@ func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
%w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
return %w : vector<9xf8E4M3FNUZ>
}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long_2d
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FNUZ> to vector<1x9xf8E4M3FNUZ>
+// CHECK: return [[RE]]
+func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
+ %w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FNUZ>
+ return %w : vector<1x9xf8E4M3FNUZ>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long_2d
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: return [[CAST]]
+func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
+ %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
+ return %w : vector<1x9xf32>
+}
More information about the Mlir-commits
mailing list