[Mlir-commits] [mlir] [mlir][tosa] Optimize block scaled cast sequences (PR #188018)
Ian Tayler Lessa
llvmlistbot at llvm.org
Mon Mar 23 04:30:07 PDT 2026
https://github.com/IanTaylerLessa-arm created https://github.com/llvm/llvm-project/pull/188018
Add a canonicalization pattern that will delete cast_from_block_scaled -> cast_to_block_scaled sequences when the input and output types and block sizes match.
Change-Id: I769e4f756df303d7906c0ef38aa489abdd166da1
>From b73363d6110d08eb525cd6d63833e1b1f9e68d1a Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Fri, 20 Mar 2026 16:03:55 +0000
Subject: [PATCH] [mlir][tosa] Optimize block scaled cast sequences
Add a canonicalization pattern that will delete cast_from_block_scaled
-> cast_to_block_scaled sequences when the input and output types and
block sizes match.
Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: I769e4f756df303d7906c0ef38aa489abdd166da1
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 53 +++++++++++++++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 44 +++++++++++++++
3 files changed, 98 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index cab2bccfc27b3..5fc462d6e8a36 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2653,6 +2653,7 @@ def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled", [P
Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
];
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..9fb928e1ffefb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -935,6 +935,59 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<NonNarrowingCastsOptimization>(context);
}
+/// Canonicalizes a tosa.cast_from_block_scaled -> tosa.cast_to_block_scaled
+/// round trip by forwarding the block-scaled operands of the first cast
+/// (data, scale) directly to the users of the second cast's results.
+///
+/// The rewrite only fires when the data and scale types of the two casts are
+/// identical (element type and shape) and both casts use the same block size.
+/// The cast_from_block_scaled op is not erased here; its result may still have
+/// other users.
+///
+/// NOTE(review): this treats the round trip as an identity on the block-scaled
+/// encoding. cast_to_block_scaled presumably recomputes each block's scale from
+/// the dequantized values, so the folded IR is likely only bit-identical to the
+/// unfolded IR when the incoming scale is the one that would be recomputed and
+/// the intermediate floating-point type neither overflows nor rounds; confirm
+/// that this value-level equivalence is acceptable for a canonicalization.
+struct CancellingBlockScaledCastsOptimization
+    : public OpRewritePattern<tosa::CastToBlockScaledOp> {
+  using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
+                                PatternRewriter &rewriter) const override {
+    const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
+    auto castFromBlockScaledOp =
+        castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
+    if (!castFromBlockScaledOp)
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp,
+          "input must be cast_from_block_scaled operation");
+
+    const Value innerData = castFromBlockScaledOp.getInputData();
+    const Value innerScale = castFromBlockScaledOp.getInputScale();
+    const Value outerData = castToBlockScaledOp.getOutputData();
+    const Value outerScale = castToBlockScaledOp.getOutputScale();
+
+    // replaceOp substitutes innerData/innerScale for every use of outerData/
+    // outerScale, so the full types (shape included) must agree, not just the
+    // element types: e.g. a static-shaped operand must not replace a
+    // dynamically-shaped result that a user such as func.return is typed
+    // against.
+    if (innerData.getType() != outerData.getType() ||
+        innerScale.getType() != outerScale.getType()) {
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp,
+          "inputs types to cast_from_block_scaled operation must match output "
+          "types to cast_to_block_scaled");
+    }
+
+    if (castFromBlockScaledOp.getBlockSize() !=
+        castToBlockScaledOp.getBlockSize()) {
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
+                               "cast_to_block_scaled must match");
+    }
+
+    rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});
+    return success();
+  }
+};
+
+/// Registers the canonicalization patterns of tosa.cast_to_block_scaled; the
+/// only one is CancellingBlockScaledCastsOptimization.
+void CastToBlockScaledOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<CancellingBlockScaledCastsOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..c2137f08ee2a4 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1388,3 +1388,47 @@ func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21
%1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
return %1 : tensor<13x21x3xi16>
}
+
+// -----
+
+// Round trip with matching f4E2M1FN data, f8E8M0FNU scale and block size:
+// after canonicalization the function returns its arguments directly.
+// CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1
+// CHECK: return %arg0, %arg1 : tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>
+func.func @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1(%arg0: tensor<15x3x2x256xf4E2M1FN>, %arg1: tensor<15x3x2x8xf8E8M0FNU>) -> (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>) -> tensor<15x3x2x256xf32>
+  %data, %scale = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<15x3x2x256xf32>) -> (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>)
+  return %data, %scale : tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>
+}
+
+// -----
+
+// Rank-1 round trip with matching f8E5M2 data, f8E8M0FNU scale and block size:
+// after canonicalization the function returns its arguments directly.
+// CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f8E5M2
+// CHECK: return %arg0, %arg1 : tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>
+func.func @test_canonicalize_cast_from_cast_to_block_scaled_f8E5M2(%arg0: tensor<160xf8E5M2>, %arg1: tensor<5xf8E8M0FNU>) -> (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>) -> tensor<160xf32>
+  %data, %scale = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<160xf32>) -> (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>)
+  return %data, %scale : tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>
+}
+
+// -----
+
+// Negative case: the data element type changes (f8E5M2 in, f6E2M3FN out), so
+// both casts must survive canonicalization.
+// CHECK-LABEL: @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f8E5M2_f6E2M3
+// CHECK: %[[values:.+]] = tosa.cast_from_block_scaled %arg0, %arg1
+// CHECK: %[[data:.+]], %[[scales:.+]] = tosa.cast_to_block_scaled %[[values]]
+// CHECK: return %[[data]], %[[scales]] : tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>
+func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f8E5M2_f6E2M3(%arg0: tensor<160xf8E5M2>, %arg1: tensor<5xf8E8M0FNU>) -> (tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<160xf8E5M2>, tensor<5xf8E8M0FNU>) -> tensor<160xf32>
+  %data, %scale = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<160xf32>) -> (tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>)
+  return %data, %scale : tensor<160xf6E2M3FN>, tensor<5xf8E8M0FNU>
+}
+
+// -----
+
+// Negative case: same bit width but a different fp6 format (f6E2M3FN in,
+// f6E3M2FN out), so both casts must survive canonicalization.
+// CHECK-LABEL: @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f6E2M3_f6E3M2
+// CHECK: %[[values:.+]] = tosa.cast_from_block_scaled %arg0, %arg1
+// CHECK: %[[data:.+]], %[[scales:.+]] = tosa.cast_to_block_scaled %[[values]]
+// CHECK: return %[[data]], %[[scales]] : tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>
+func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_different_types_f6E2M3_f6E3M2(%arg0: tensor<32xf6E2M3FN>, %arg1: tensor<1xf8E8M0FNU>) -> (tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>) {
+  %values = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = BLOCK_SIZE_32} : (tensor<32xf6E2M3FN>, tensor<1xf8E8M0FNU>) -> tensor<32xf32>
+  %data, %scale = tosa.cast_to_block_scaled %values {block_size = BLOCK_SIZE_32} : (tensor<32xf32>) -> (tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>)
+  return %data, %scale : tensor<32xf6E3M2FN>, tensor<1xf8E8M0FNU>
+}
More information about the Mlir-commits
mailing list