[Mlir-commits] [mlir] aa5558d - [mlir][ArithToAMDGPU] limit scaling truncf/extf support to gfx950 (#155431)

llvmlistbot at llvm.org
Wed Sep 17 12:26:57 PDT 2025


Author: Muzammil
Date: 2025-09-17T14:26:52-05:00
New Revision: aa5558d12cbbb6a039955fcb174129b57d182642

URL: https://github.com/llvm/llvm-project/commit/aa5558d12cbbb6a039955fcb174129b57d182642
DIFF: https://github.com/llvm/llvm-project/commit/aa5558d12cbbb6a039955fcb174129b57d182642.diff

LOG: [mlir][ArithToAMDGPU] limit scaling truncf/extf support to gfx950 (#155431)

The current chip guard fails to prevent the scaling_extf/truncf patterns
from being applied on gfx1100, which does not have scaling support.
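
For out-of-tree callers of populateArithToAMDGPUConversionPatterns, here is a
minimal sketch of a call site updated for the extra boolean. The wrapper
function and the values of the other flags are illustrative assumptions, not
code taken from this commit; only the added supportsScaledExtTrunc parameter
comes from this change.

    // Hypothetical pattern-population helper; flag values are placeholders.
    #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
    #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    static void addArithToAMDGPUPatterns(RewritePatternSet &patterns,
                                         amdgpu::Chipset chipset) {
      // The scaled ext/trunc rewrites are only enabled on gfx950.
      bool supportsScaledExtTrunc = chipset == amdgpu::Chipset(9, 5, 0);
      arith::populateArithToAMDGPUConversionPatterns(
          patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
          /*allowPackedF16Rtz=*/false, supportsScaledExtTrunc, chipset);
    }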

---------

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
    mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
    mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
    mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index f4a9518839224..fd144edf77452 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -28,12 +28,10 @@ namespace arith {
 /// is set, values outside the range of the destination type are clamped
 /// to the largest value of that type instead of being rewritten to Inf (aka
 /// NaN).
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
-                                             bool convertFP8Arithmetic,
-                                             bool saturateFP8Truncf,
-                                             bool allowPackedF16Rtz,
-                                             amdgpu::Chipset chipset,
-                                             PatternBenefit benefit = 1);
+void populateArithToAMDGPUConversionPatterns(
+    RewritePatternSet &patterns, bool convertFP8Arithmetic,
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
+    amdgpu::Chipset chipset, PatternBenefit benefit = 1);
 } // namespace arith
 } // namespace mlir
 

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8230591123661..3d6f6cab42244 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -690,8 +690,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
-    PatternBenefit benefit) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
+    Chipset chipset, PatternBenefit benefit) {
 
   if (convertFP8Arithmetic) {
     patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
@@ -702,7 +702,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
 
-  if (chipset >= kGfx950) {
+  if (supportsScaledExtTrunc) {
     patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
     patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
   }
@@ -720,9 +720,10 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
 
   bool convertFP8Arithmetic =
       *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
+  bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
   arith::populateArithToAMDGPUConversionPatterns(
       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
-      *maybeChipset);
+      supportsScaledExtTrunc, *maybeChipset);
   if (failed(applyPatternsGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }

diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index 1d36be1108d26..a2b0aef594e61 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100
 
 // CHECK-LABEL: @conversion_f8_f32_fallback
 // CHECK:         %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
@@ -241,6 +242,9 @@ func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vec
 
 // -----
 
+// CHECK-GFX1100-LABEL: @conversion_scalar
+// CHECK-GFX1100: arith.scaling_extf
+
 // CHECK-LABEL: @conversion_scalar
 // CHECK:         %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
 // CHECK-NEXT:    %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f8E5M2 to vector<1xf8E5M2>

diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 90a86084ac93f..bc2c6a5aa0275 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100
 
 // CHECK-LABEL: @conversion_f8_fallback
 // CHECK-DAG:     %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf8E5M2>
@@ -163,6 +164,9 @@ func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector
 
 // -----
 
+// CHECK-GFX1100-LABEL: @conversion_scalar
+// CHECK-GFX1100: arith.scaling_truncf
+
 // CHECK-LABEL: @conversion_scalar
 // CHECK:         %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
 // CHECK-NEXT:    %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f32 to vector<1xf32>

