[Mlir-commits] [mlir] [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (PR #98074)
Rob Suderman
llvmlistbot at llvm.org
Mon Jul 8 16:26:23 PDT 2024
================
@@ -214,18 +224,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
int64_t numElements = outType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
if (outType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
- ArrayRef<int64_t>{});
+ Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
+ VectorType flatTy =
+ VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+ if (inVectorTy.getShape().size() > 1) {
+ inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
----------------
rsuderman wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/98074
More information about the Mlir-commits mailing list