[Mlir-commits] [mlir] [Mlir] decompose generic by unfolding projected permutation crash fix (PR #122449)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 10 07:24:11 PST 2025
https://github.com/GrumpyPigSkin updated https://github.com/llvm/llvm-project/pull/122449
>From fe3886ed346f45ee2d94672e550cce47477bbace Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 13:11:15 +0000
Subject: [PATCH 1/3] Added check to prevent processing of scalars
---
...DecomposeGenericByUnfoldingPermutation.cpp | 8 ++++++
...olding-projected-permutation-validate.mlir | 28 +++++++++++++++++++
2 files changed, 36 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..ce1c21504f1dc7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -159,6 +159,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
auto map = op.getMatchingIndexingMap(&opOperand);
if (!map.isProjectedPermutation(false))
return failure();
+
+ // If we have any inputs that aren't memref or ranked tensor types, reject the pattern.
+ if (!dyn_cast<ShapedType>(opOperand.get().getType()))
+ return op->emitError("Expected operand #")
+ << opOperand.getOperandNumber()
+ << " to be memref of any type values or ranked tensor of any type "
+ "values, but got "
+ << opOperand.get().getType();
}
// Decomposing linalg.generic involves creating `tensor.empty`
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
new file mode 100644
index 00000000000000..43fdd17e10078c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
+
+// Fixes issue: 122094. Verify that the following code causes an error to be produced.
+
+func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
+
+ %a = arith.constant dense<2> : tensor<2x2xi32>
+ %b = arith.constant 42 : i32
+ %c = tensor.empty() : tensor<2x2xi32>
+ // expected-error @+1 {{Expected operand #1 to be memref of any type values or ranked tensor of any type values, but got 'i32'}}
+ %res = linalg.generic
+ {
+ indexing_maps = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> ()>,
+ affine_map<(i, j) -> (i, j)>
+ ],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%a, %b : tensor<2x2xi32>, i32)
+ outs(%c : tensor<2x2xi32>) {
+ ^bb0(%x: i32, %scalar: i32, %out: i32):
+ %sum = arith.addi %x, %scalar : i32
+ linalg.yield %sum : i32
+ } -> tensor<2x2xi32>
+
+ return %res : tensor<2x2xi32>
+}
>From 3d446a76e0d331c0548edeb44e07ebd9938c54e5 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 13:18:00 +0000
Subject: [PATCH 2/3] Applied formatting
---
.../Transforms/DecomposeGenericByUnfoldingPermutation.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index ce1c21504f1dc7..19ad156d2f6859 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -160,7 +160,8 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
if (!map.isProjectedPermutation(false))
return failure();
- // If we have any inputs that aren't memref or ranked tensor types, reject the pattern.
+ // If we have any inputs that aren't memref or ranked tensor types, reject
+ // the pattern.
if (!dyn_cast<ShapedType>(opOperand.get().getType()))
return op->emitError("Expected operand #")
<< opOperand.getOperandNumber()
>From 86a144e85774950761c75d42121ca1a9db92d104 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Fri, 10 Jan 2025 15:23:51 +0000
Subject: [PATCH 3/3] Used rewriter.notifyMatchFailure instead
---
.../DecomposeGenericByUnfoldingPermutation.cpp | 12 +++++++-----
...-by-unfolding-projected-permutation-validate.mlir | 3 +--
2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 19ad156d2f6859..6f7f2a0fdf6280 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "llvm/Support/FormatVariadic.h"
#include <map>
#include <optional>
#include <utility>
@@ -163,11 +164,12 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
// If we have any inputs that aren't memref or ranked tensor types, reject
// the pattern.
if (!dyn_cast<ShapedType>(opOperand.get().getType()))
- return op->emitError("Expected operand #")
- << opOperand.getOperandNumber()
- << " to be memref of any type values or ranked tensor of any type "
- "values, but got "
- << opOperand.get().getType();
+ return rewriter.notifyMatchFailure(
+ opOperand.get().getLoc(),
+ llvm::formatv("Expected operand #{0} to be memref of any type values "
+ "or ranked tensor of any type values, but got {1}",
+ opOperand.getOperandNumber(),
+ opOperand.get().getType()));
}
// Decomposing linalg.generic involves creating `tensor.empty`
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
index 43fdd17e10078c..f2542eee3149d8 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation-validate.mlir
@@ -1,13 +1,12 @@
// RUN: mlir-opt %s -linalg-specialize-generic-ops -verify-diagnostics
-// Fixes issue: 122094. Verify that the following code causes an error to be produced.
+// Fixes issue: 122094. Verify that the following code compiles without issue.
func.func @test_broadcast_scalar_across_single_tensor() -> tensor<2x2xi32> {
%a = arith.constant dense<2> : tensor<2x2xi32>
%b = arith.constant 42 : i32
%c = tensor.empty() : tensor<2x2xi32>
- // expected-error @+1 {{Expected operand #1 to be memref of any type values or ranked tensor of any type values, but got 'i32'}}
%res = linalg.generic
{
indexing_maps = [
More information about the Mlir-commits
mailing list