[Mlir-commits] [mlir] [mlir][tosa] Optimize block scaled cast sequences (PR #188018)

Ian Tayler Lessa llvmlistbot at llvm.org
Tue Mar 24 06:59:51 PDT 2026


================
@@ -935,6 +935,59 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<NonNarrowingCastsOptimization>(context);
 }
 
+struct CancellingBlockScaledCastsOptimization
+    : public OpRewritePattern<tosa::CastToBlockScaledOp> {
+  using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
+                                PatternRewriter &rewriter) const override {
+    const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
+    auto castFromBlockScaledOp =
+        castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
+    if (!castFromBlockScaledOp)
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp,
+          "input must be cast_from_block_scaled operation");
+
+    const Value innerData = castFromBlockScaledOp.getInputData();
+    const Value innerScale = castFromBlockScaledOp.getInputScale();
+    const auto innerDataTy =
+        cast<ShapedType>(innerData.getType()).getElementType();
+    const auto innerScaleTy =
+        cast<ShapedType>(innerScale.getType()).getElementType();
+
+    const Value outerData = castToBlockScaledOp.getOutputData();
+    const Value outerScale = castToBlockScaledOp.getOutputScale();
+    const auto outerDataTy =
+        cast<ShapedType>(outerData.getType()).getElementType();
+    const auto outerScaleTy =
+        cast<ShapedType>(outerScale.getType()).getElementType();
+
+    if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp,
+          "input types of cast_from_block_scaled operation must match output "
+          "types of cast_to_block_scaled operation");
+    }
+
+    if (castFromBlockScaledOp.getBlockSize() !=
+        castToBlockScaledOp.getBlockSize()) {
+      return rewriter.notifyMatchFailure(
+          castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
----------------
IanTaylerLessa-arm wrote:

Currently `BlockSize` has only one value, so I don't think there's a way to exercise this check in a test. It's mostly future-proofing: if anyone adds another value, this optimisation would otherwise silently become incorrect whenever the two casts use different block sizes.
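To make the future-proofing argument concrete, here is a minimal C++ sketch of the legality check under a simplified model. `BlockScaledCastInfo`, `numScales` and `canCancel` are hypothetical helpers, not TOSA or MLIR API; element types are plain strings, and the block size is modelled as an integer so that a second value can exist.

```cpp
#include <string>

// Hypothetical, simplified model of what the rewrite pattern compares.
// NOT the MLIR/TOSA API: element types are strings and the block size is an
// int, so that a block size other than the current single value can be shown.
struct BlockScaledCastInfo {
  std::string dataElemTy;   // element type of the block-scaled data tensor
  std::string scaleElemTy;  // element type of the per-block scale tensor
  int blockSize;            // number of data elements sharing one scale
};

// Number of scale values for a dimension of length `n`.
// Assumes `n` is a multiple of `blockSize`.
int numScales(int n, int blockSize) { return n / blockSize; }

// Mirrors the checks in the pattern: the original (data, scale) pair can be
// forwarded only if re-encoding would produce the same element types and the
// same blocking.
bool canCancel(const BlockScaledCastInfo &from, const BlockScaledCastInfo &to) {
  return from.dataElemTy == to.dataElemTy &&
         from.scaleElemTy == to.scaleElemTy &&
         from.blockSize == to.blockSize;
}
```

For a 64-element row, a block size of 32 gives 2 scale values while a block size of 16 gives 4, so without the block-size comparison the pattern could forward a scale tensor with the wrong shape and grouping.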

https://github.com/llvm/llvm-project/pull/188018

