[Mlir-commits] [mlir] 7f08503 - Introduce `arith.scaling_extf` and `arith.scaling_truncf` (#141965)

llvmlistbot at llvm.org
Mon Jun 9 11:13:34 PDT 2025


Author: Umang Yadav
Date: 2025-06-09T13:13:31-05:00
New Revision: 7f08503a3bf3acdd2a58ac712d5e95682ce583dd

URL: https://github.com/llvm/llvm-project/commit/7f08503a3bf3acdd2a58ac712d5e95682ce583dd
DIFF: https://github.com/llvm/llvm-project/commit/7f08503a3bf3acdd2a58ac712d5e95682ce583dd.diff

LOG: Introduce `arith.scaling_extf` and `arith.scaling_truncf` (#141965)

This PR adds the `arith.scaling_truncf` and `arith.scaling_extf` operations,
which support block quantization following the OCP MXFP spec listed here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

The OCP MXFP spec comes with a reference implementation here:
https://github.com/microsoft/microxcaling/tree/main

The interesting piece of reference code is the `_quantize_mx` method:
https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L173

Both `arith.scaling_truncf` and `arith.scaling_extf` are designed to be
elementwise operations. Please see their descriptions in `ArithOps.td` for
more details.

Internally,

`arith.scaling_truncf` computes
`arith.truncf(arith.divf(input, 2^scale))`. `scale` should have any
necessary broadcasting, clamping, normalization, and NaN propagation done
before calling into `arith.scaling_truncf`.

`arith.scaling_extf` computes `arith.mulf(2^scale, input)` after taking
care of the necessary data type conversions.
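
A corresponding sketch for the extension op (again, names are illustrative;
the expansion mirrors the `ScalingExtFOpConverter` pattern below):

```mlir
func.func @scaling_extf_example(%in: f4E2M1FN, %scale: f8E8M0FNU) -> f32 {
  // Extend both the scale (2^scale, encoded as f8E8M0FNU) and the input to
  // the wider result type, then multiply; NaNs in either operand propagate.
  %0 = arith.scaling_extf %in, %scale : f4E2M1FN, f8E8M0FNU to f32
  return %0 : f32
}
// Expanded form (roughly):
//   %s   = arith.extf %scale : f8E8M0FNU to f32
//   %i   = arith.extf %in : f4E2M1FN to f32
//   %out = arith.mulf %i, %s : f32
```

Both ops are expanded by the `ArithExpandOpsPass` via the new
`populateExpandScalingExtTruncPatterns` entry point.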


CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar
@tgymnich

---------

Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
    mlir/include/mlir/IR/Builders.h
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/lib/IR/Builders.cpp
    mlir/test/Dialect/Arith/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 599b3b982ec7f..adc27ae6bdafb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1215,6 +1215,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling ExtFOp
+//===----------------------------------------------------------------------===//
+def Arith_ScalingExtFOp
+    : Arith_Op<
+          "scaling_extf", [Pure, SameInputOutputTensorDims,
+                           DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                           DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary = "Upcasts input floats using provided scales values following "
+                "OCP MXFP Spec";
+  let description = [{
+  This operation upcasts input floating-point values using provided scale 
+  values. It expects both scales and the input operand to be of the same shape, 
+  making the operation elementwise. Scales are usually calculated per block 
+  following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+  If scales are calculated per block where blockSize != 1, then scales may 
+  require broadcasting to make this operation elementwise. For example, let's 
+  say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and 
+  assuming quantization happens on the last axis, the input can be reshaped to 
+  `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated 
+  per block on the last axis. Therefore, scales will be of shape 
+  `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other 
+  shape as long as it is broadcast compatible with the input, e.g., 
+  `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+  In this example, before calling into `arith.scaling_extf`, scales must be 
+  broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
+  that there could be multiple quantization axes. Internally, 
+  `arith.scaling_extf` would perform the following:
+ 
+    ```
+    resultTy = get_type(result) 
+    scaleTy  = get_type(scale)
+    inputTy = get_type(input)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
+    input.extf = arith.extf(input) : inputTy to resultTy
+    result = arith.mulf(scale.extf, input.extf)
+    ```
+    It propagates NaN values. Therefore, if either scale or the input element 
+    contains NaN, then the output element value will also be a NaN.
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat =
+      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` 
+      type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1280,6 +1332,63 @@ def Arith_TruncFOp :
                           attr-dict `:` type($in) `to` type($out) }];
 }
 
+//===----------------------------------------------------------------------===//
+// Scaling TruncFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingTruncFOp
+    : Arith_Op<"scaling_truncf",
+               [Pure, SameInputOutputTensorDims,
+                DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+                DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary = "Downcasts input floating point values using provided scales "
+                "values following OCP MXFP Spec";
+  let description = [{
+    This operation downcasts input using the provided scale values. It expects 
+    both scales and the input operand to be of the same shape and, therefore, 
+    makes the operation elementwise. Scales are usually calculated per block 
+    following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+    Users are required to normalize and clamp the scales as necessary before calling
+    passing them to this operation.  OCP MXFP spec also does the flushing of denorms
+    on the input operand, which should be handled during lowering by passing appropriate 
+    fastMath flag to this operation. 
+
+    If scales are calculated per block where blockSize != 1, scales may require 
+    broadcasting to make this operation elementwise. For example, let's say the 
+    input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and 
+    assuming quantization happens on the last axis, the input can be reshaped to 
+    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated 
+    per block on the last axis. Therefore, scales will be of shape 
+    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other 
+    shape as long as it is broadcast compatible with the input, e.g., 
+    `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+    In this example, before calling into `arith.scaling_truncf`, scales must be 
+    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
+    that there could be multiple quantization axes. Internally, 
+    `arith.scaling_truncf` would perform the following:
+
+    ```
+    scaleTy = get_type(scale)
+    inputTy = get_type(input)
+    resultTy = get_type(result)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
+    result = arith.divf(input, scale.extf)
+    result.cast = arith.truncf(result, resultTy)
+    ```
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat =
+      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` 
+      type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 5aaac8d8e3dc5..e0a4567d6f406 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
 void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
+void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3f7b3268dd085..d68dbdb1efeef 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getF8E8M0Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 41f2d0f3425e2..9e53e195274aa 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 
 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
 
+//===----------------------------------------------------------------------===//
+// ScalingExtFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
+                                             TypeRange outputs) {
+  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingExtFOp::verify() {
+  return verifyExtOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
   return verifyTruncateOp<FloatType>(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// ScalingTruncFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
+                                               TypeRange outputs) {
+  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingTruncFOp::verify() {
+  return verifyTruncateOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 95546bb09e765..534aff9562b7a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
     return rewriter.create<arith::ConstantOp>(
         loc, DenseElementsAttr::get(shapedTy, attr));
   }
-
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
     Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
     if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
+      result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
+                                         op.getFastmathAttr());
     } else if (resultETy.getIntOrFloatBitWidth() > 32) {
-      result = b.create<arith::ExtFOp>(resultTy, result);
+      result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
     if (operandETy.getIntOrFloatBitWidth() < 32) {
-      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+      operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
     } else if (operandETy.getIntOrFloatBitWidth() > 32) {
-      operand = b.create<arith::TruncFOp>(f32Ty, operand);
+      operand = b.create<arith::TruncFOp>(
+          f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
     }
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
+    Type scaleTy = scaleOperand.getType();
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // allow implicit exponent extraction from 16/32 bits floats
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleTy = cloneToShapedType(scaleTy, scaleETy);
+      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
+                                               op.getFastmathAttr());
+    }
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling_extf is using scales of type which can not be converted "
+              "to f8E8M0FNU");
+    }
+    Type resultTy = op.getType();
+    // extf on the scale will essentially create a floating-point number
+    // of type resultTy that is 2^scale and will also propagate NaNs
+    Value scaleExt =
+        b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
+    Value inputExt =
+        b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
+    Value result =
+        b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+Expands arith.ScalingTruncFOp(in, scale) into
+  scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
+  result = arith.truncf(in / (2^scale))
+ */
+struct ScalingTruncFOpConverter
+    : public OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
+    Type scaleTy = scaleOperand.getType();
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // allow implicit exponent extraction from 16/32 bits floats
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleTy = cloneToShapedType(scaleTy, scaleETy);
+      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
+                                               op.getFastmathAttr());
+    }
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling_truncf is using scales type which can not be converted "
+              "to f8E8M0FNU");
+    }
+    Type resultTy = op.getType();
+    Type inputTy = inputOperand.getType();
+    // this will create a floating point number of type
+    // inputTy that is 2^scale and will also propagate NaNs
+    scaleOperand =
+        b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
+    Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
+                                           op.getFastmathAttr());
+    Value resultCast = b.create<arith::TruncFOp>(
+        resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
+    rewriter.replaceOp(op, resultCast);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
       arith::MaximumFOp,
       arith::MinimumFOp,
       arith::MaxNumFOp,
-      arith::MinNumFOp
+      arith::MinNumFOp,
+      arith::ScalingExtFOp,
+      arith::ScalingTruncFOp
     >();
 
     if (includeBf16) {
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandScalingExtTruncPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
+  populateExpandScalingExtTruncPatterns(patterns);
   // clang-format off
   patterns.add<
     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
     MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
     MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
     MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
-    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
+    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> 
    >(patterns.getContext());
   // clang-format on
 }

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 89102115cdc40..5f7bc50afc418 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
+
 FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
 
 FloatType Builder::getF16Type() { return Float16Type::get(context); }

diff  --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 5b6badf13d763..db1349feaff3a 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s 
+// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -253,7 +254,7 @@ func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
     %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
     return %0 : f8E8M0FNU
 }
-// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
@@ -267,7 +268,7 @@ func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
     %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
     return %0 : f8E8M0FNU
 }
-// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
 // CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
@@ -305,9 +306,76 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf
 
 // CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
 // CHECK-NOT: arith.truncf
+// CHECK: return
 
+// -----
+
+func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
+    %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
+// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
+// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+
+// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
+// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
+// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
 
 // -----
+
+func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
+    return %0 : vector<4xf6E3M2FN>
+}
+// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
+// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
+// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
+// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
+
+// -----
+
+func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
+    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
+// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
+// SCHECK: return
+
+// -----
+func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
+    %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
+    return %0 : vector<4xf4E2M1FN>
+}
+// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
+// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: return
+
+// -----
+
+func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
+    // expected-error@+1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}}
+    %0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN
+    return %0 : f4E2M1FN
+}
+
+// -----
+
 func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
     %0 = arith.extf %arg0 : f8E8M0FNU to f32
     return %0 : f32
@@ -332,7 +400,7 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
     return %0 : f16
 }
 
-// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
 // CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
 // CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
@@ -374,7 +442,109 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<
 
 // CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
 // CHECK-NOT: arith.extf
+// CHECK: return
+
+// -----
+
+func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
+    return %0 : f32 
+}
+
+// SCHECK-LABEL: @scaling_extf_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
+    return %0 : f32 
+}
+
+// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
+    // expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
+    %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
+    return %0 : f32
+}
+
+// -----
+
+func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> 
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16> 
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_bf16
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16> 
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
+// SCHECK: return %[[RESULT]]
+
+// -----
+
+func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
+    %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
 
+// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
+// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
+// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
+// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
+// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
+// SCHECK: return %[[RESULT]]
 
 // -----
 


        


More information about the Mlir-commits mailing list