[Mlir-commits] [mlir] [mlir][linalg] Add Check for Reduction Operation in Contraction Body (PR #123134)

Ayokunle Amodu llvmlistbot at llvm.org
Wed Jan 15 15:01:55 PST 2025


https://github.com/ayokunle321 updated https://github.com/llvm/llvm-project/pull/123134

>From 80200160d4c4288f604046440e6dd36a924f0fdc Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Wed, 15 Jan 2025 15:32:35 -0700
Subject: [PATCH 1/2] added check for reduction op in contraction body

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index caf9cdb3a3eb4f..14f8f9e8fdd3b4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -281,6 +281,12 @@ bool mlir::linalg::detail::isContractionBody(
 
   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
   Operation *reductionOp = yielded.getDefiningOp();
+
+  if (!reductionOp){
+    errs << "expected reduction op in body";
+    return false;
+  }
+
   if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
     errs << "expected reduction op to be binary";
     return false;

>From 4ed59274eebea0d0601b37d5b138d4ed57371617 Mon Sep 17 00:00:00 2001
From: Ayokunle Amodu <121697771+ayokunle321 at users.noreply.github.com>
Date: Wed, 15 Jan 2025 16:01:19 -0700
Subject: [PATCH 2/2] fix code style

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 14f8f9e8fdd3b4..91165ddeb88870 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -282,7 +282,7 @@ bool mlir::linalg::detail::isContractionBody(
   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
   Operation *reductionOp = yielded.getDefiningOp();
 
-  if (!reductionOp){
+  if (!reductionOp) {
     errs << "expected reduction op in body";
     return false;
   }
