[Mlir-commits] [mlir] [mlir][linalg] Fix crash in linalg-specialize-generic-ops with scalar inputs (PR #189212)

Mehdi Amini llvmlistbot at llvm.org
Sun Mar 29 03:10:47 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189212

From c74b8ff6024d178cd90ec2b68bb74ad491096d5f Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 29 Mar 2026 00:00:13 -0700
Subject: [PATCH] [mlir][linalg] Fix crash in linalg-specialize-generic-ops
 with scalar inputs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

`DecomposeProjectedPermutation` used `cast<RankedTensorType>` on every
operand of a `linalg.generic`, but `linalg.generic` also accepts scalar
(non-tensor) inputs, which are indexed by affine maps with no results such
as `(d0, d1) -> ()`. When such an operand was present, the hard cast
triggered an assertion failure.

Fix by using `dyn_cast<RankedTensorType>` and returning failure (skip
decomposition) when any operand is not a ranked tensor type. The existing
`hasPureTensorSemantics()` guard does not exclude scalar inputs — it only
requires that no operand is a memref.
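
The bail-out logic can be illustrated outside of MLIR. The following is a
minimal, self-contained sketch and not the code in this patch: `ScalarType`,
`RankedTensor`, `OperandType`, `kDynamic`, `dynCastRanked` and `mustBailOut`
are hypothetical stand-ins, with `std::get_if` playing the role that
`dyn_cast` plays in the real pass (a null result instead of an assertion).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

// Hypothetical stand-ins for MLIR types; this is not the MLIR API.
struct ScalarType {};                 // models a scalar operand such as i32
struct RankedTensor {                 // models e.g. tensor<4x4xi32>
  std::vector<int64_t> shape;
};
using OperandType = std::variant<ScalarType, RankedTensor>;

// Assumed marker for a dynamic dimension in this sketch only.
constexpr int64_t kDynamic = -1;

// Analogue of dyn_cast: yields nullptr on a type mismatch instead of asserting.
const RankedTensor *dynCastRanked(const OperandType &type) {
  return std::get_if<RankedTensor>(&type);
}

// Returns true when decomposition should be skipped: some operand is not a
// ranked tensor, or some ranked tensor has a dynamic dimension.
bool mustBailOut(const std::vector<OperandType> &operands) {
  return std::any_of(
      operands.begin(), operands.end(), [](const OperandType &type) {
        const RankedTensor *ranked = dynCastRanked(type);
        if (!ranked)
          return true; // scalar operand: bail out before touching the shape
        return std::any_of(ranked->shape.begin(), ranked->shape.end(),
                           [](int64_t dim) { return dim == kDynamic; });
      });
}
```

With operands modeled on the new test (two rank-1 tensors, one scalar, one
rank-2 output), `mustBailOut` returns true, so the op would be left as is.
With only statically shaped ranked tensors it returns false.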

Fixes https://github.com/llvm/llvm-project/issues/122094

Assisted-by: Claude Code
---
 ...DecomposeGenericByUnfoldingPermutation.cpp | 12 ++++---
 ...ic-by-unfolding-projected-permutation.mlir | 33 +++++++++++++++++++
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 9015cbb096f88..a8fac0a1ba6f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -159,12 +159,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
       return failure();
   }
 
-  // Decomposing linalg.generic involves creating `tensor.empty`
-  // which can have dynamic shapes but then we would have to work
-  // out which operand can supply that runtime-value (tensor.dim).
-  // Leaving it as a future TODO.
+  // Bail out for operands that are not ranked tensors (e.g. scalar inputs)
+  // or that have dynamic shapes. Decomposing requires creating `tensor.empty`
+  // with static shapes; dynamic shapes would require finding which operand
+  // can supply the runtime value (tensor.dim) — a future TODO.
   if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
-        auto opType = cast<RankedTensorType>(oper.get().getType());
+        auto opType = dyn_cast<RankedTensorType>(oper.get().getType());
+        if (!opType)
+          return true; // scalar or unranked tensor: bail out
         return ShapedType::isDynamicShape(opType.getShape());
       }))
     return failure();
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index 38e406a13ec08..fc5dd4305b76a 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -69,3 +69,36 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y:  tensor<2x32xf32>, %z :
 // CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
 // CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK-NOT: linalg.generic
+
+// -----
+
+// Verify that linalg.generic with scalar (non-tensor) inputs is not decomposed
+// and does not crash. Scalar inputs have 0-D affine maps and are not
+// RankedTensorType; the pass must handle them gracefully by bailing out.
+// (GitHub issue #122094)
+
+// CHECK-LABEL: func @scalar_input
+// The op must survive unchanged: linalg.generic is preserved (not decomposed).
+// CHECK: linalg.generic
+// CHECK: } -> tensor<4x4xi32>
+// CHECK-NOT: linalg.broadcast
+// CHECK-NOT: linalg.transpose
+
+#map = affine_map<(d0, d1) -> (d0)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#map2 = affine_map<(d0, d1) -> ()>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @scalar_input(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: i32) -> tensor<4x4xi32> {
+  %0 = tensor.empty() : tensor<4x4xi32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1, #map2, #map3],
+                        iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1, %arg2 : tensor<4xi32>, tensor<4xi32>, i32)
+    outs(%0 : tensor<4x4xi32>) {
+  ^bb0(%in: i32, %in2: i32, %in3: i32, %out: i32):
+    %2 = arith.muli %in, %in2 : i32
+    %3 = arith.addi %in3, %2 : i32
+    linalg.yield %3 : i32
+  } -> tensor<4x4xi32>
+  return %1 : tensor<4x4xi32>
+}