[Mlir-commits] [mlir] [Mlir] decompose generic by unfolding projected permutation crash fix (PR #122449)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 13 11:27:36 PST 2025


https://github.com/GrumpyPigSkin updated https://github.com/llvm/llvm-project/pull/122449

>From fe3886ed346f45ee2d94672e550cce47477bbace Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 13:11:15 +0000
Subject: [PATCH 1/9] Added check to prevent processing of scalars

---
 ...DecomposeGenericByUnfoldingPermutation.cpp |  8 ++++++
 ...olding-projected-permutation-validate.mlir | 28 +++++++++++++++++++
 2 files changed, 36 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..ce1c21504f1dc7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -159,6 +159,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     auto map = op.getMatchingIndexingMap(&opOperand);
     if (!map.isProjectedPermutation(false))
       return failure();
+
+    // If we have any inputs that aren't memref or ranked tensor types, reject the pattern.
+    if (!dyn_cast<ShapedType>(opOperand.get().getType()))
+      return op->emitError("Expected operand #")
+             << opOperand.getOperandNumber()
+             << " to be memref of any type values or ranked tensor of any type "
+                "values, but got "
+             << opOperand.get().getType();
   }
 
   // Decomposing linalg.generic involves creating `tensor.empty`
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
new file mode 100644
index 00000000000000..43fdd17e10078c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
+
+// Fixes issue: 122094. Verify that the following code causes an error to be produced.
+
+func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+
+  %a = arith.constant dense<2> : tensor<2x2xi32>
+  %b = arith.constant 42 : i32
+  %c = tensor.empty() : tensor<2x2xi32>
+  // expected-error @+1 {{Expected operand #1 to be memref of any type values or ranked tensor of any type values, but got 'i32'}}
+  %res = linalg.generic
+    {
+      indexing_maps = [
+        affine_map<(i, j) -> (i, j)>, 
+        affine_map<(i, j) -> ()>,     
+        affine_map<(i, j) -> (i, j)>  
+      ],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%a, %b : tensor<2x2xi32>, i32)
+    outs(%c : tensor<2x2xi32>) {
+  ^bb0(%x: i32, %scalar: i32, %out: i32):
+    %sum = arith.addi %x, %scalar : i32
+    linalg.yield %sum : i32
+  } -> tensor<2x2xi32>
+
+  return %res : tensor<2x2xi32>
+}

>From 3d446a76e0d331c0548edeb44e07ebd9938c54e5 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 13:18:00 +0000
Subject: [PATCH 2/9] Applied formatting

---
 .../Transforms/DecomposeGenericByUnfoldingPermutation.cpp      | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index ce1c21504f1dc7..19ad156d2f6859 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -160,7 +160,8 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     if (!map.isProjectedPermutation(false))
       return failure();
 
-    // If we have any inputs that aren't memref or ranked tensor types, reject the pattern.
+    // If we have any inputs that aren't memref or ranked tensor types, reject
+    // the pattern.
     if (!dyn_cast<ShapedType>(opOperand.get().getType()))
       return op->emitError("Expected operand #")
              << opOperand.getOperandNumber()

>From 86a144e85774950761c75d42121ca1a9db92d104 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 15:23:51 +0000
Subject: [PATCH 3/9] Used rewriter.notifyMatchFailure instead of emitting an error

---
 .../DecomposeGenericByUnfoldingPermutation.cpp       | 12 +++++++-----
 ...-by-unfolding-projected-permutation-validate.mlir |  3 +--
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 19ad156d2f6859..6f7f2a0fdf6280 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <map>
 #include <optional>
 #include <utility>
@@ -163,11 +164,12 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     // If we have any inputs that aren't memref or ranked tensor types, reject
     // the pattern.
     if (!dyn_cast<ShapedType>(opOperand.get().getType()))
-      return op->emitError("Expected operand #")
-             << opOperand.getOperandNumber()
-             << " to be memref of any type values or ranked tensor of any type "
-                "values, but got "
-             << opOperand.get().getType();
+      return rewriter.notifyMatchFailure(
+          opOperand.get().getLoc(),
+          llvm::formatv("Expected operand #{0} to be memref of any type values "
+                        "or ranked tensor of any type values, but got {1}",
+                        opOperand.getOperandNumber(),
+                        opOperand.get().getType()));
   }
 
   // Decomposing linalg.generic involves creating `tensor.empty`
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
index 43fdd17e10078c..f2542eee3149d8 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
@@ -1,13 +1,12 @@
 // RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
 
-// Fixes issue: 122094. Verify that the following code causes an error to be produced.
+// Fixes issue: 122094. Verify that the following code compiles without issue.
 
 func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
 
   %a = arith.constant dense<2> : tensor<2x2xi32>
   %b = arith.constant 42 : i32
   %c = tensor.empty() : tensor<2x2xi32>
-  // expected-error @+1 {{Expected operand #1 to be memref of any type values or ranked tensor of any type values, but got 'i32'}}
   %res = linalg.generic
     {
       indexing_maps = [

>From 02be6352969a132b6058ed12811e087d93267872 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 16:27:10 +0000
Subject: [PATCH 4/9] Moved test into the existing decompose-generic-by-unfolding-projected-permutation.mlir test file

---
 ...olding-projected-permutation-validate.mlir | 27 --------------
 ...ic-by-unfolding-projected-permutation.mlir | 36 +++++++++++++++++++
 2 files changed, 36 insertions(+), 27 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir

diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
deleted file mode 100644
index f2542eee3149d8..00000000000000
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
-
-// Fixes issue: 122094. Verify that the following code compiles without issue.
-
-func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
-
-  %a = arith.constant dense<2> : tensor<2x2xi32>
-  %b = arith.constant 42 : i32
-  %c = tensor.empty() : tensor<2x2xi32>
-  %res = linalg.generic
-    {
-      indexing_maps = [
-        affine_map<(i, j) -> (i, j)>, 
-        affine_map<(i, j) -> ()>,     
-        affine_map<(i, j) -> (i, j)>  
-      ],
-      iterator_types = ["parallel", "parallel"]
-    }
-    ins(%a, %b : tensor<2x2xi32>, i32)
-    outs(%c : tensor<2x2xi32>) {
-  ^bb0(%x: i32, %scalar: i32, %out: i32):
-    %sum = arith.addi %x, %scalar : i32
-    linalg.yield %sum : i32
-  } -> tensor<2x2xi32>
-
-  return %res : tensor<2x2xi32>
-}
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index 38e406a13ec087..3df77a9430c8d2 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -69,3 +69,39 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y:  tensor<2x32xf32>, %z :
 // CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
 // CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+  %a = arith.constant dense<2> : tensor<2x2xi32>
+  %b = arith.constant 42 : i32
+  %c = tensor.empty() : tensor<2x2xi32>
+  %res = linalg.generic
+    {
+      indexing_maps = [
+        affine_map<(i, j) -> (i, j)>, 
+        affine_map<(i, j) -> ()>,     
+        affine_map<(i, j) -> (i, j)>  
+      ],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%a, %b : tensor<2x2xi32>, i32)
+    outs(%c : tensor<2x2xi32>) {
+  ^bb0(%x: i32, %scalar: i32, %out: i32):
+    %sum = arith.addi %x, %scalar : i32
+    linalg.yield %sum : i32
+  } -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: test_broadcast_scalar_across_single_tensor
+// CHECK-SAME: () -> tensor<2x2xi32> {
+// CHECK:   %[[E0:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%cst, %c42_i32 : tensor<2x2xi32>, i32) outs(%0 : tensor<2x2xi32>) {
+// CHECK:   ^bb0(%in: i32, %in_0: i32, %out: i32):
+// CHECK:     %[[E0:.+]] = arith.addi %in, %in_0 : i32
+// CHECK:     linalg.yield %2 : i32
+// CHECK:   } -> tensor<2x2xi32>
+
+
+
+

>From c620a730df35b69ac43762cbd5fcaa17e6831f52 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 16:28:57 +0000
Subject: [PATCH 5/9] Removed additional blank lines

---
 .../decompose-generic-by-unfolding-projected-permutation.mlir | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index 3df77a9430c8d2..9e3a064757da10 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -101,7 +101,3 @@ func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
 // CHECK:     %[[E0:.+]] = arith.addi %in, %in_0 : i32
 // CHECK:     linalg.yield %2 : i32
 // CHECK:   } -> tensor<2x2xi32>
-
-
-
-

>From 39f8388d574cbc010d8bdd77cdac3f6dad2dece6 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Sat, 11 Jan 2025 12:01:32 +0000
Subject: [PATCH 6/9] Renamed test to @no_decompose_on_scalar and relaxed its CHECK lines

---
 ...ose-generic-by-unfolding-projected-permutation.mlir | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index 9e3a064757da10..6e3b5dec665d45 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -72,7 +72,7 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y:  tensor<2x32xf32>, %z :
 
 // -----
 
-func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+func.func @no_decompose_on_scalar() -> tensor<2x2xi32> {
   %a = arith.constant dense<2> : tensor<2x2xi32>
   %b = arith.constant 42 : i32
   %c = tensor.empty() : tensor<2x2xi32>
@@ -96,8 +96,6 @@ func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
 
 // CHECK-LABEL: test_broadcast_scalar_across_single_tensor
 // CHECK-SAME: () -> tensor<2x2xi32> {
-// CHECK:   %[[E0:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%cst, %c42_i32 : tensor<2x2xi32>, i32) outs(%0 : tensor<2x2xi32>) {
-// CHECK:   ^bb0(%in: i32, %in_0: i32, %out: i32):
-// CHECK:     %[[E0:.+]] = arith.addi %in, %in_0 : i32
-// CHECK:     linalg.yield %2 : i32
-// CHECK:   } -> tensor<2x2xi32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant dense<2> : tensor<2x2xi32>
+// CHECK-DAG: %[[C42:.+]] = arith.constant 42 : i32
+// CHECK:   %[[E0:.+]] = linalg.generic {{.*}} ins(%[[CST]], %[[C42]] : tensor<2x2xi32>, i32) outs(%0 : tensor<2x2xi32>) {

>From 21bb28ce0de5cbbfd0b5f27d3409f39a4d2b116f Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Sat, 11 Jan 2025 12:05:04 +0000
Subject: [PATCH 7/9] Updated CHECK-LABEL for the renamed test and simplified CHECK lines

---
 ...ompose-generic-by-unfolding-projected-permutation.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index 6e3b5dec665d45..d502ee744bbb4b 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -94,8 +94,8 @@ func.func @no_decompose_on_scalar() -> tensor<2x2xi32> {
   return %res : tensor<2x2xi32>
 }
 
-// CHECK-LABEL: test_broadcast_scalar_across_single_tensor
+// CHECK-LABEL: no_decompose_on_scalar
 // CHECK-SAME: () -> tensor<2x2xi32> {
-// CHECK-DAG: %[[CST:.+]] = arith.constant dense<2> : tensor<2x2xi32>
-// CHECK-DAG: %[[C42:.+]] = arith.constant 42 : i32
-// CHECK:   %[[E0:.+]] = linalg.generic {{.*}} ins(%[[CST]], %[[C42]] : tensor<2x2xi32>, i32) outs(%0 : tensor<2x2xi32>) {
+// CHECK: %[[CST:.+]] = arith.constant dense<2> : tensor<2x2xi32>
+// CHECK: %[[C42:.+]] = arith.constant 42 : i32
+// CHECK: linalg.generic {{.*}} ins(%[[CST]], %[[C42]] : tensor<2x2xi32>, i32) outs(%0 : tensor<2x2xi32>) {

>From 714bb8fa93963573ed5e9a7949a214b653a3d161 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Sat, 11 Jan 2025 12:07:43 +0000
Subject: [PATCH 8/9] Used CHECK-DAG for the constant checks

---
 .../decompose-generic-by-unfolding-projected-permutation.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index d502ee744bbb4b..2baac12c7e7979 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -96,6 +96,6 @@ func.func @no_decompose_on_scalar() -> tensor<2x2xi32> {
 
 // CHECK-LABEL: no_decompose_on_scalar
 // CHECK-SAME: () -> tensor<2x2xi32> {
-// CHECK: %[[CST:.+]] = arith.constant dense<2> : tensor<2x2xi32>
-// CHECK: %[[C42:.+]] = arith.constant 42 : i32
+// CHECK-DAG: %[[CST:.+]] = arith.constant dense<2> : tensor<2x2xi32>
+// CHECK-DAG: %[[C42:.+]] = arith.constant 42 : i32
 // CHECK: linalg.generic {{.*}} ins(%[[CST]], %[[C42]] : tensor<2x2xi32>, i32) outs(%0 : tensor<2x2xi32>) {

>From 981d5a2b81cc9fe0db90410b5071cb8362ebf4c9 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Mon, 13 Jan 2025 19:27:17 +0000
Subject: [PATCH 9/9] Restricted check to ranked tensors only, and added comment to test

---
 .../Transforms/DecomposeGenericByUnfoldingPermutation.cpp | 8 ++++----
 ...ompose-generic-by-unfolding-projected-permutation.mlir | 2 ++
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 6f7f2a0fdf6280..6df2710bf3ce2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -161,13 +161,13 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     if (!map.isProjectedPermutation(false))
       return failure();
 
-    // If we have any inputs that aren't memref or ranked tensor types, reject
+    // If we have any inputs that aren't ranked tensor types, reject
     // the pattern.
-    if (!dyn_cast<ShapedType>(opOperand.get().getType()))
+    if (!dyn_cast<RankedTensorType>(opOperand.get().getType()))
       return rewriter.notifyMatchFailure(
           opOperand.get().getLoc(),
-          llvm::formatv("Expected operand #{0} to be memref of any type values "
-                        "or ranked tensor of any type values, but got {1}",
+          llvm::formatv("Expected operand #{0} to be "
+                        "ranked tensor of any type values, but got {1}",
                         opOperand.getOperandNumber(),
                         opOperand.get().getType()));
   }
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index 2baac12c7e7979..b4de3a7a65a4f2 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -72,6 +72,8 @@ func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y:  tensor<2x32xf32>, %z :
 
 // -----
 
+// Scalar (non-tensor) operands are currently unsupported: the linalg.generic is left undecomposed.
+
 func.func @no_decompose_on_scalar() -> tensor<2x2xi32> {
   %a = arith.constant dense<2> : tensor<2x2xi32>
   %b = arith.constant 42 : i32



More information about the Mlir-commits mailing list