[Mlir-commits] [mlir] [Mlir] decompose generic by unfolding projected permutation crash fix (PR #122449)

llvmlistbot at llvm.org
Fri Jan 10 05:17:19 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: None (GrumpyPigSkin)

<details>
<summary>Changes</summary>

Fixes #122094.

@CoTinker, could you please review?

I added the check in DecomposeGenericByUnfoldingPermutation.cpp because placing it anywhere else was too general and caused other valid test cases to fail.
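The placement argument above amounts to a guard that runs inside this one pattern: before any rewriting, every operand is checked, and the first non-shaped operand (such as a plain `i32` scalar) produces a diagnostic instead of reaching code that assumes a shape. The sketch below models only that control flow in plain C++. `TypeKind`, `Operand` and `checkOperandsAreShaped` are hypothetical stand-ins, not MLIR APIs. Note also that in MLIR `ShapedType` is an interface implemented by vector and unranked tensor types as well, so the real check is slightly broader than its comment says; the enum here models only the cases relevant to the test.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for MLIR operand types (NOT the MLIR type system).
enum class TypeKind { RankedTensor, MemRef, Scalar };

// Hypothetical operand record: its position and its printed type.
struct Operand {
  std::size_t number;
  TypeKind kind;
  std::string spelling;  // e.g. "tensor<2x2xi32>" or "i32"
};

// Returns a diagnostic for the first operand that is not shaped, or
// std::nullopt if the decomposition may proceed. Every operand is validated
// before anything is rewritten, mirroring where the patch places its check.
std::optional<std::string> checkOperandsAreShaped(
    const std::vector<Operand>& operands) {
  for (const Operand& op : operands) {
    if (op.kind == TypeKind::Scalar)
      return "Expected operand #" + std::to_string(op.number) +
             " to be memref of any type values or ranked tensor of any "
             "type values, but got '" + op.spelling + "'";
  }
  return std::nullopt;
}
```

With the operands from the test (a tensor, the scalar `%b`, and the init tensor), the guard reports operand #1, matching the message the test expects.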

---
Full diff: https://github.com/llvm/llvm-project/pull/122449.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp (+8) 
- (added) mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..ce1c21504f1dc7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -159,6 +159,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     auto map = op.getMatchingIndexingMap(&opOperand);
     if (!map.isProjectedPermutation(false))
       return failure();
+
+    // If we have any inputs that aren't memref or ranked tensor types, reject the pattern.
+    if (!dyn_cast<ShapedType>(opOperand.get().getType()))
+      return op->emitError("Expected operand #")
+             << opOperand.getOperandNumber()
+             << " to be memref of any type values or ranked tensor of any type "
+                "values, but got "
+             << opOperand.get().getType();
   }
 
   // Decomposing linalg.generic involves creating `tensor.empty`
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
new file mode 100644
index 00000000000000..43fdd17e10078c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
+
+// Fixes issue: 122094. Verify that the following code causes an error to be produced.
+
+func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+
+  %a = arith.constant dense<2> : tensor<2x2xi32>
+  %b = arith.constant 42 : i32
+  %c = tensor.empty() : tensor<2x2xi32>
+  // expected-error @+1 {{Expected operand #1 to be memref of any type values or ranked tensor of any type values, but got 'i32'}}
+  %res = linalg.generic
+    {
+      indexing_maps = [
+        affine_map<(i, j) -> (i, j)>, 
+        affine_map<(i, j) -> ()>,     
+        affine_map<(i, j) -> (i, j)>  
+      ],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%a, %b : tensor<2x2xi32>, i32)
+    outs(%c : tensor<2x2xi32>) {
+  ^bb0(%x: i32, %scalar: i32, %out: i32):
+    %sum = arith.addi %x, %scalar : i32
+    linalg.yield %sum : i32
+  } -> tensor<2x2xi32>
+
+  return %res : tensor<2x2xi32>
+}

``````````

</details>
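As a reading aid for the test input, the plain C++ sketch below spells out what the `linalg.generic` computes: each element of the 2x2 constant (all 2s) is added to the scalar 42, because the map `affine_map<(i, j) -> ()>` reads the same scalar at every iteration point while the other two maps read and write at `(i, j)`. Nested `std::array`s stand in for tensors, and `Tensor2x2` and `addScalarBroadcast` are hypothetical names; this illustrates the op's semantics only, not how MLIR lowers or decomposes it.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for tensor<2x2xi32>.
using Tensor2x2 = std::array<std::array<int32_t, 2>, 2>;

// Loop nest equivalent to the test's linalg.generic:
//   operand 0 (%a)      read at (i, j)  -> map (i, j) -> (i, j)
//   operand 1 (%b)      read with no indices -> map (i, j) -> ()
//   result   (%c/%res)  written at (i, j)
Tensor2x2 addScalarBroadcast(const Tensor2x2& a, int32_t scalar) {
  Tensor2x2 out{};
  for (std::size_t i = 0; i < 2; ++i)      // iterator 0: "parallel"
    for (std::size_t j = 0; j < 2; ++j)    // iterator 1: "parallel"
      out[i][j] = a[i][j] + scalar;        // body: %sum = arith.addi %x, %scalar
  return out;
}
```

Because the scalar operand has no shape, there is nothing for a transpose or broadcast to act on, which is why the decomposition pattern must reject it rather than try to unfold its indexing map.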


https://github.com/llvm/llvm-project/pull/122449

