[Mlir-commits] [mlir] [mlir][linalg] Add Check for Reduction Operation in Contraction Body (PR #123134)

Ayokunle Amodu llvmlistbot at llvm.org
Thu Jan 16 07:08:41 PST 2025


https://github.com/ayokunle321 updated https://github.com/llvm/llvm-project/pull/123134

>From 80200160d4c4288f604046440e6dd36a924f0fdc Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Wed, 15 Jan 2025 15:32:35 -0700
Subject: [PATCH 1/6] added check for reduction op in contraction body

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index caf9cdb3a3eb4f..14f8f9e8fdd3b4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -281,6 +281,12 @@ bool mlir::linalg::detail::isContractionBody(
 
   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
   Operation *reductionOp = yielded.getDefiningOp();
+
+  if (!reductionOp){
+    errs << "expected reduction op in body";
+    return false;
+  }
+
   if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
     errs << "expected reduction op to be binary";
     return false;

>From 4ed59274eebea0d0601b37d5b138d4ed57371617 Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Wed, 15 Jan 2025 16:01:19 -0700
Subject: [PATCH 2/6] fix code style

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 14f8f9e8fdd3b4..91165ddeb88870 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -282,7 +282,7 @@ bool mlir::linalg::detail::isContractionBody(
   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
   Operation *reductionOp = yielded.getDefiningOp();
 
-  if (!reductionOp){
+  if (!reductionOp) {
     errs << "expected reduction op in body";
     return false;
   }

>From d72f50420e5118c1f26dbfff9c4eac4a2e872f0a Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Thu, 16 Jan 2025 07:39:53 -0700
Subject: [PATCH 3/6] add test case to specialize-generic-ops-fail.mlir

---
 .../Linalg/specialize-generic-ops-fail.mlir   | 36 +++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 542a7ed4a198b8..9993046f1325f1 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -14,3 +14,39 @@ func.func @transpose_and_broadcast(%arg0: tensor<7x8xf32>, %arg1: tensor<8x7x9xf
   } -> tensor<8x7x9xf32>
   return %0 : tensor<8x7x9xf32>
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Verifies that the pass crashes when trying to specialize a linalg.generic op with 
+// a reduction iterator when there is no valid reduction operation in the contraction body.
+// CHECK-LABEL: @specialize_reduction
+func.func private @specialize_reduction(%arg0: tensor<1x31x8xi32>) -> tensor<31x31xi32> {
+  %c-2351_i32 = arith.constant -2351 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %0 = tensor.empty() : tensor<31x8xi32>
+  %1 = linalg.generic 
+        {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
+        outs(%0 : tensor<31x8xi32>) {
+    ^bb0(%out: i32):
+      linalg.yield %c-2351_i32 : i32
+  } -> tensor<31x8xi32>
+  %2 = tensor.empty() : tensor<31x31xi32>
+  %3 = linalg.generic 
+        {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
+        outs(%2 : tensor<31x31xi32>) {
+    ^bb0(%out: i32):
+      linalg.yield %c0_i32 : i32
+  } -> tensor<31x31xi32>
+  %4 = linalg.generic 
+        {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%1, %1 : tensor<31x8xi32>, tensor<31x8xi32>) 
+        outs(%3 : tensor<31x31xi32>) {
+    ^bb0(%in: i32, %in_0: i32, %out: i32):
+      linalg.yield %out : i32
+  } -> tensor<31x31xi32>
+  return %4 : tensor<31x31xi32>
+}
\ No newline at end of file

>From 9472c46c18ae32a905c6f53c10cf2d2e73bfd996 Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Thu, 16 Jan 2025 07:51:38 -0700
Subject: [PATCH 4/6] edit test description, need to minimize test

---
 mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 9993046f1325f1..abe7a33c3a9e86 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -21,8 +21,8 @@ func.func @transpose_and_broadcast(%arg0: tensor<7x8xf32>, %arg1: tensor<8x7x9xf
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-// Verifies that the pass crashes when trying to specialize a linalg.generic op with 
-// a reduction iterator when there is no valid reduction operation in the contraction body.
+// This tests checks that the pass does not crash when trying to specialize a 
+// contraction-like generic op with no reduction operation in its body.
 // CHECK-LABEL: @specialize_reduction
 func.func private @specialize_reduction(%arg0: tensor<1x31x8xi32>) -> tensor<31x31xi32> {
   %c-2351_i32 = arith.constant -2351 : i32

>From 57ea829fd813ae1747fb8143180baf7efb65e220 Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Thu, 16 Jan 2025 08:02:57 -0700
Subject: [PATCH 5/6] minimize test case

---
 .../Linalg/specialize-generic-ops-fail.mlir   | 42 ++++++-------------
 1 file changed, 13 insertions(+), 29 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index abe7a33c3a9e86..683b6965293bc4 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -17,36 +17,20 @@ func.func @transpose_and_broadcast(%arg0: tensor<7x8xf32>, %arg1: tensor<8x7x9xf
 
 // -----
 
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 // This tests checks that the pass does not crash when trying to specialize a 
 // contraction-like generic op with no reduction operation in its body.
-// CHECK-LABEL: @specialize_reduction
-func.func private @specialize_reduction(%arg0: tensor<1x31x8xi32>) -> tensor<31x31xi32> {
-  %c-2351_i32 = arith.constant -2351 : i32
-  %c0_i32 = arith.constant 0 : i32
-  %0 = tensor.empty() : tensor<31x8xi32>
+// CHECK-LABEL: @test_fake_contraction
+// CHECK: linalg.generic
+func.func @test_fake_contraction(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> {
+  %0 = tensor.empty() : tensor<4x4xi32>
   %1 = linalg.generic 
-        {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
-        outs(%0 : tensor<31x8xi32>) {
-    ^bb0(%out: i32):
-      linalg.yield %c-2351_i32 : i32
-  } -> tensor<31x8xi32>
-  %2 = tensor.empty() : tensor<31x31xi32>
-  %3 = linalg.generic 
-        {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
-        outs(%2 : tensor<31x31xi32>) {
-    ^bb0(%out: i32):
-      linalg.yield %c0_i32 : i32
-  } -> tensor<31x31xi32>
-  %4 = linalg.generic 
-        {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]}
-        ins(%1, %1 : tensor<31x8xi32>, tensor<31x8xi32>) 
-        outs(%3 : tensor<31x31xi32>) {
-    ^bb0(%in: i32, %in_0: i32, %out: i32):
-      linalg.yield %out : i32
-  } -> tensor<31x31xi32>
-  return %4 : tensor<31x31xi32>
+        {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} 
+        ins(%arg0, %arg0 : tensor<4x4xi32>, tensor<4x4xi32>) outs(%0 : tensor<4x4xi32>) {
+        ^bb0(%in0: i32, %in1: i32, %out: i32):
+           linalg.yield %out : i32
+  } -> tensor<4x4xi32>
+  return %1 : tensor<4x4xi32>
 }
\ No newline at end of file

>From e3686d952147900c3ca5c77773dca6bac217020f Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Thu, 16 Jan 2025 08:08:22 -0700
Subject: [PATCH 6/6] added new line to test case

---
 mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 683b6965293bc4..3c5649fb63f627 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -33,4 +33,4 @@ func.func @test_fake_contraction(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> {
            linalg.yield %out : i32
   } -> tensor<4x4xi32>
   return %1 : tensor<4x4xi32>
-}
\ No newline at end of file
+}



More information about the Mlir-commits mailing list