[Mlir-commits] [mlir] [mlir][tosa] Optimize block scaled cast sequences (PR #188018)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 23 04:30:44 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa
Author: Ian Tayler Lessa (IanTaylerLessa-arm)
<details>
<summary>Changes</summary>
Add a canonicalization pattern that removes cast_from_block_scaled -> cast_to_block_scaled sequences when the input types of cast_from_block_scaled match the output types of cast_to_block_scaled and both operations use the same block size.
Change-Id: I769e4f756df303d7906c0ef38aa489abdd166da1
---
Full diff: https://github.com/llvm/llvm-project/pull/188018.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+53)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+44)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index cab2bccfc27b3..5fc462d6e8a36 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2653,6 +2653,7 @@ def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled", [P
Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
];
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..9fb928e1ffefb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -935,6 +935,59 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<NonNarrowingCastsOptimization>(context);
}
+/// Removes a cast_from_block_scaled -> cast_to_block_scaled round trip.
+/// The pair is replaced by the (data, scale) operands of
+/// cast_from_block_scaled when the block sizes match and the results of
+/// cast_to_block_scaled have exactly the same types as those operands.
+struct CancellingBlockScaledCastsOptimization
+    : public OpRewritePattern<tosa::CastToBlockScaledOp> {
+  using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
+                                PatternRewriter &rewriter) const override {
+    auto castFromBlockScaledOp =
+        castToBlockScaledOp.getInputData()
+            .getDefiningOp<tosa::CastFromBlockScaledOp>();
+    if (!castFromBlockScaledOp)
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp,
+          "input must be cast_from_block_scaled operation");
+
+    const Value innerData = castFromBlockScaledOp.getInputData();
+    const Value innerScale = castFromBlockScaledOp.getInputScale();
+    const Value outerData = castToBlockScaledOp.getOutputData();
+    const Value outerScale = castToBlockScaledOp.getOutputScale();
+
+    // Compare the full tensor types, not only the element types: the
+    // replacement values must be drop-in substitutes for every use, so any
+    // shape difference (e.g. static vs. dynamic dims) must block the fold.
+    if (innerData.getType() != outerData.getType() ||
+        innerScale.getType() != outerScale.getType()) {
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp,
+          "input types to cast_from_block_scaled operation must match output "
+          "types to cast_to_block_scaled");
+    }
+
+    if (castFromBlockScaledOp.getBlockSize() !=
+        castToBlockScaledOp.getBlockSize()) {
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
+                               "cast_to_block_scaled must match");
+    }
+
+    // Only cast_to_block_scaled is replaced here; cast_from_block_scaled is
+    // left in place (it may be removed later as dead code if unused).
+    rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});
+    return success();
+  }
+};
+
+void CastToBlockScaledOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<CancellingBlockScaledCastsOptimization>(context, /*benefit=*/1);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..c2137f08ee2a4 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1388,3 +1388,47 @@ func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21
%1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
return %1 : tensor<13x21x3xi16>
}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1
+// CHECK: return %arg0, %arg1 : tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>
+func.func @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1(%arg0: tensor<15x3x2x256xf4E2M1FN>, %arg1: tensor<15x3x2x8xf8E8M0FNU>) -> (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>) -> tensor<15x3x2x256xf32>
+  %data, %scales = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<15x3x2x256xf32>) -> (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>)
+  return %data, %scales : tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f8E5M2
+// CHECK: return %arg0, %arg1 : tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>
+func.func @test_canonicalize_cast_from_cast_to_block_scaled_f8E5M2(%arg0: tensor<160xf8E5M2>, %arg1: tensor<5xf8E8M0FNU>) -> (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>) -> tensor<160xf32>
+  %data, %scales = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<160xf32>) -> (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>)
+  return %data, %scales : tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f8E5M2_f6E2M3
+// CHECK: %[[values:.+]] = tosa.cast_from_block_scaled %arg0, %arg1
+// CHECK: %[[data:.+]], %[[scales:.+]] = tosa.cast_to_block_scaled %[[values]]
+// CHECK: return %[[data]], %[[scales]] : tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>
+func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f8E5M2_f6E2M3(%arg0: tensor<160xf8E5M2>, %arg1: tensor<5xf8E8M0FNU>) -> (tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>) -> tensor<160xf32>
+  %data, %scales = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<160xf32>) -> (tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>)
+  return %data, %scales : tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f6E2M3_f6E3M2
+// CHECK: %[[values:.+]] = tosa.cast_from_block_scaled %arg0, %arg1
+// CHECK: %[[data:.+]], %[[scales:.+]] = tosa.cast_to_block_scaled %[[values]]
+// CHECK: return %[[data]], %[[scales]] : tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>
+func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f6E2M3_f6E3M2(%arg0: tensor<32xf6E2M3FN>, %arg1: tensor<1xf8E8M0FNU>) -> (tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<32xf6E2M3FN>, tensor<1xf8E8M0FNU>) -> tensor<32xf32>
+  %data, %scales = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<32xf32>) -> (tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>)
+  return %data, %scales : tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/188018
More information about the Mlir-commits mailing list