[Mlir-commits] [mlir] [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (PR #98074)
Rob Suderman
llvmlistbot at llvm.org
Mon Jul 8 16:25:59 PDT 2024
================
@@ -81,28 +78,37 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- if (!isa<VectorType>(in.getType())) {
+ auto inType = dyn_cast<VectorType>(in.getType());
+ if (!inType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
return rewriter.replaceOp(op, result);
}
- VectorType inType = cast<VectorType>(in.getType());
int64_t numElements = inType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result =
- rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
if (inType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
- ArrayRef<int64_t>{});
+ Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
+
+ VectorType flatTy =
+ VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+ if (inType.getShape().size() > 1) {
----------------
rsuderman wrote:
Fixed
https://github.com/llvm/llvm-project/pull/98074
More information about the Mlir-commits mailing list