[Mlir-commits] [mlir] 7f08503 - Introduce `arith.scaling_extf` and `arith.scaling_truncf` (#141965)
llvmlistbot at llvm.org
Mon Jun 9 11:13:34 PDT 2025
Author: Umang Yadav
Date: 2025-06-09T13:13:31-05:00
New Revision: 7f08503a3bf3acdd2a58ac712d5e95682ce583dd
URL: https://github.com/llvm/llvm-project/commit/7f08503a3bf3acdd2a58ac712d5e95682ce583dd
DIFF: https://github.com/llvm/llvm-project/commit/7f08503a3bf3acdd2a58ac712d5e95682ce583dd.diff
LOG: Introduce `arith.scaling_extf` and `arith.scaling_truncf` (#141965)
This PR adds the `arith.scaling_truncf` and `arith.scaling_extf` operations,
which support block quantization following the OCP MXFP spec listed
here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
The OCP MXFP spec comes with a reference implementation here:
https://github.com/microsoft/microxcaling/tree/main
An interesting piece of the reference code is the `_quantize_mx` method:
https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L173.
Both `arith.scaling_truncf` and `arith.scaling_extf` are designed to be
elementwise operations. Please see their descriptions in the
`ArithOps.td` file for more details.
Internally,
`arith.scaling_truncf` computes
`arith.truncf(arith.divf(input, 2^scale))`. Any necessary broadcasting,
clamping, normalization, and NaN propagation of `scale` should be done
before calling into `arith.scaling_truncf`.
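For illustration, here is a minimal sketch of the op and the arith sequence
it is expanded to by `-arith-expand`; the function name is made up for this
example, and the concrete f32/f8E8M0FNU/f4E2M1FN types are just one choice
taken from the tests added below:
```
func.func @example_scaling_truncf(%in : f32, %scale : f8E8M0FNU) -> f4E2M1FN {
  // Divide by 2^scale, then truncate to the narrow float type.
  %0 = arith.scaling_truncf %in, %scale : f32, f8E8M0FNU to f4E2M1FN
  return %0 : f4E2M1FN
}

// Expansion produced by -arith-expand (roughly):
//   %s   = arith.extf %scale : f8E8M0FNU to f32
//   %div = arith.divf %in, %s : f32
//   %res = arith.truncf %div : f32 to f4E2M1FN
```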
`arith.scaling_extf` computes `arith.mulf(2^scale, input)` after taking
care of the necessary data type conversions.
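Similarly, a minimal sketch of `arith.scaling_extf` and its expansion, again
with an illustrative function name and types matching one of the new tests:
```
func.func @example_scaling_extf(%in : f4E2M1FN, %scale : f8E8M0FNU) -> f32 {
  // Multiply the upcast input by 2^scale.
  %0 = arith.scaling_extf %in, %scale : f4E2M1FN, f8E8M0FNU to f32
  return %0 : f32
}

// Expansion produced by -arith-expand (roughly):
//   %s   = arith.extf %scale : f8E8M0FNU to f32
//   %i   = arith.extf %in : f4E2M1FN to f32
//   %res = arith.mulf %i, %s : f32
```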
CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar
@tgymnich
---------
Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/IR/Builders.h
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
mlir/lib/IR/Builders.cpp
mlir/test/Dialect/Arith/expand-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 599b3b982ec7f..adc27ae6bdafb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1215,6 +1215,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
attr-dict `:` type($in) `to` type($out) }];
}
+//===----------------------------------------------------------------------===//
+// Scaling ExtFOp
+//===----------------------------------------------------------------------===//
+def Arith_ScalingExtFOp
+ : Arith_Op<
+ "scaling_extf", [Pure, SameInputOutputTensorDims,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in, FloatLike:$scale,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+ Results<(outs FloatLike:$out)> {
+ let summary = "Upcasts input floats using provided scale values following "
+ "OCP MXFP Spec";
+ let description = [{
+ This operation upcasts input floating-point values using provided scale
+ values. It expects both scales and the input operand to be of the same shape,
+ making the operation elementwise. Scales are usually calculated per block
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+ If scales are calculated per block where blockSize != 1, then scales may
+ require broadcasting to make this operation elementwise. For example, let's
+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+ In this example, before calling into `arith.scaling_extf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
+ `arith.scaling_extf` would perform the following:
+
+ ```
+ resultTy = get_type(result)
+ scaleTy = get_type(scale)
+ inputTy = get_type(input)
+ scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+ scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
+ input.extf = arith.extf(input) : inputTy to resultTy
+ result = arith.mulf(scale.extf, input.extf)
+ ```
+ It propagates NaN values. Therefore, if either scale or the input element
+ contains NaN, then the output element value will also be a NaN.
+ }];
+ let hasVerifier = 1;
+ let assemblyFormat =
+ [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
+ type($in) `,` type($scale) `to` type($out)}];
+}
+
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1280,6 +1332,63 @@ def Arith_TruncFOp :
attr-dict `:` type($in) `to` type($out) }];
}
+//===----------------------------------------------------------------------===//
+// Scaling TruncFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingTruncFOp
+ : Arith_Op<"scaling_truncf",
+ [Pure, SameInputOutputTensorDims,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in, FloatLike:$scale,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+ Results<(outs FloatLike:$out)> {
+ let summary = "Downcasts input floating point values using provided scale "
+ "values following OCP MXFP Spec";
+ let description = [{
+ This operation downcasts the input using the provided scale values. It expects
+ both scales and the input operand to be of the same shape, making the
+ operation elementwise. Scales are usually calculated per block
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+ Users are required to normalize and clamp the scales as necessary before
+ passing them to this operation. The OCP MXFP spec also flushes denorms on the
+ input operand, which should be handled during lowering by passing the
+ appropriate fastMath flag to this operation.
+
+ If scales are calculated per block where blockSize != 1, scales may require
+ broadcasting to make this operation elementwise. For example, let's say the
+ input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+ In this example, before calling into `arith.scaling_truncf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
+ `arith.scaling_truncf` would perform the following:
+
+ ```
+ scaleTy = get_type(scale)
+ inputTy = get_type(input)
+ resultTy = get_type(result)
+ scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+ scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
+ result = arith.divf(input, scale.extf)
+ result.cast = arith.truncf(result, resultTy)
+ ```
+ }];
+ let hasVerifier = 1;
+ let assemblyFormat =
+ [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
+ type($in) `,` type($scale) `to` type($out)}];
+}
+
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 5aaac8d8e3dc5..e0a4567d6f406 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
+void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3f7b3268dd085..d68dbdb1efeef 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
Attribute metadata = Attribute());
// Types.
+ FloatType getF8E8M0Type();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 41f2d0f3425e2..9e53e195274aa 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
+//===----------------------------------------------------------------------===//
+// ScalingExtFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
+ TypeRange outputs) {
+ return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingExtFOp::verify() {
+ return verifyExtOp<FloatType>(*this);
+}
+
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}
+//===----------------------------------------------------------------------===//
+// ScalingTruncFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
+ TypeRange outputs) {
+ return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingTruncFOp::verify() {
+ return verifyTruncateOp<FloatType>(*this);
+}
+
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 95546bb09e765..534aff9562b7a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(shapedTy, attr));
}
-
return rewriter.create<arith::ConstantOp>(loc, attr);
}
@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
if (resultETy.getIntOrFloatBitWidth() < 32) {
- result = b.create<arith::TruncFOp>(resultTy, result);
+ result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
+ op.getFastmathAttr());
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
- result = b.create<arith::ExtFOp>(resultTy, result);
+ result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
}
rewriter.replaceOp(op, result);
return success();
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
if (operandETy.getIntOrFloatBitWidth() < 32) {
- operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
- operand = b.create<arith::TruncFOp>(f32Ty, operand);
+ operand = b.create<arith::TruncFOp>(
+ f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
}
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value inputOperand = op.getIn();
+ Value scaleOperand = op.getScale();
+ Type scaleTy = scaleOperand.getType();
+ Type scaleETy = getElementTypeOrSelf(scaleOperand);
+ // allow implicit exponent extraction from 16/32-bit floats
+ if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+ scaleETy = b.getF8E8M0Type();
+ scaleTy = cloneToShapedType(scaleTy, scaleETy);
+ scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
+ op.getFastmathAttr());
+ }
+ if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+ return rewriter.notifyMatchFailure(
+ op, "scaling_extf is using a scale type which cannot be converted "
+ "to f8E8M0FNU");
+ }
+ Type resultTy = op.getType();
+ // extf on scale will essentially create a floating-point number
+ // of type resultTy that is 2^scale and will also propagate NaNs
+ Value scaleExt =
+ b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
+ Value inputExt =
+ b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
+ Value result =
+ b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/*
+Expands arith.ScalingTruncFOp(in, scale) into
+ scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
+ result = arith.truncf(in / (2^scale))
+ */
+struct ScalingTruncFOpConverter
+ : public OpRewritePattern<arith::ScalingTruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value inputOperand = op.getIn();
+ Value scaleOperand = op.getScale();
+ Type scaleTy = scaleOperand.getType();
+ Type scaleETy = getElementTypeOrSelf(scaleOperand);
+ // allow implicit exponent extraction from 16/32-bit floats
+ if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+ scaleETy = b.getF8E8M0Type();
+ scaleTy = cloneToShapedType(scaleTy, scaleETy);
+ scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
+ op.getFastmathAttr());
+ }
+ if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+ return rewriter.notifyMatchFailure(
+ op, "scaling_truncf is using a scale type which cannot be converted "
+ "to f8E8M0FNU");
+ }
+ Type resultTy = op.getType();
+ Type inputTy = inputOperand.getType();
+ // this will create a floating point number of type
+ // inputTy that is 2^scale and will also propagate NaNs
+ scaleOperand =
+ b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
+ Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
+ op.getFastmathAttr());
+ Value resultCast = b.create<arith::TruncFOp>(
+ resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
+ rewriter.replaceOp(op, resultCast);
+ return success();
+ }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
arith::MaximumFOp,
arith::MinimumFOp,
arith::MaxNumFOp,
- arith::MinNumFOp
+ arith::MinNumFOp,
+ arith::ScalingExtFOp,
+ arith::ScalingTruncFOp
>();
if (includeBf16) {
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}
+void mlir::arith::populateExpandScalingExtTruncPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
+ patterns.getContext());
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
+ populateExpandScalingExtTruncPatterns(patterns);
// clang-format off
patterns.add<
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
- MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
+ MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 89102115cdc40..5f7bc50afc418 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
+FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
+
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
FloatType Builder::getF16Type() { return Float16Type::get(context); }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 5b6badf13d763..db1349feaff3a 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
// Test ceil divide with signed integer
// CHECK-LABEL: func @ceildivi
@@ -253,7 +254,7 @@ func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
%0 = arith.truncf %arg0 : f32 to f8E8M0FNU
return %0 : f8E8M0FNU
}
-// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
@@ -267,7 +268,7 @@ func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
%0 = arith.truncf %arg0 : f16 to f8E8M0FNU
return %0 : f8E8M0FNU
}
-// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
@@ -305,9 +306,76 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
// CHECK-NOT: arith.truncf
+// CHECK: return
+// -----
+
+func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
+ %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
+ return %0 : f4E2M1FN
+}
+
+// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
+// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+ %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+ return %0 : vector<4xf6E3M2FN>
+}
+
+// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
+// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
+// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
// -----
+
+func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
+ %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
+ return %0 : vector<4xf6E3M2FN>
+}
+// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
+// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
+// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
+// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+ %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+ return %0 : f4E2M1FN
+}
+// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
+// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// SCHECK: return
+
+// -----
+func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
+ %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
+ return %0 : vector<4xf4E2M1FN>
+}
+// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
+// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: return
+
+// -----
+
+func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
+ // expected-error at +1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
+ %0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN
+ return %0 : f4E2M1FN
+}
+
+// -----
+
func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
%0 = arith.extf %arg0 : f8E8M0FNU to f32
return %0 : f32
@@ -332,7 +400,7 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
return %0 : f16
}
-// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
@@ -374,7 +442,109 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
// CHECK-NOT: arith.extf
+// CHECK: return
+
+// -----
+
+func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
+ %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
+ return %0 : f32
+}
+
+// SCHECK-LABEL: @scaling_extf_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
+ %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
+ return %0 : f32
+}
+
+// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
+ // expected-error at +1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+ %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16>
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_bf16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16>
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+ %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+ %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
+// SCHECK: return %[[RESULT]]
// -----