[Mlir-commits] [mlir] 8ce6f7d - [mlir][Linalg] NFC - Fail gracefully instead of crashing in SplitReduction

Nicolas Vasilache llvmlistbot at llvm.org
Thu Feb 9 13:05:35 PST 2023


Author: Nicolas Vasilache
Date: 2023-02-09T12:59:57-08:00
New Revision: 8ce6f7dd9340abedb7065b9593b986e4c2ff4c02

URL: https://github.com/llvm/llvm-project/commit/8ce6f7dd9340abedb7065b9593b986e4c2ff4c02
DIFF: https://github.com/llvm/llvm-project/commit/8ce6f7dd9340abedb7065b9593b986e4c2ff4c02.diff

LOG: [mlir][Linalg] NFC - Fail gracefully instead of crashing in SplitReduction

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index f597273204f24..26dee47b8926a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -42,15 +42,16 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
 
   SmallVector<unsigned> dims;
   op.getReductionDims(dims);
-  assert(dims.size() == 1);
+
+  if (dims.size() != 1)
+    return b.notifyMatchFailure(op, "needs a single reduction dimension");
   unsigned reductionDim = dims[0];
   if (control.innerParallel) {
     insertSplitDimension = reductionDim + 1;
   }
   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
-  if (reductionDimSize == ShapedType::kDynamic ||
-      reductionDimSize % ratio != 0)
+  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
   if (op.getNumDpsInits() != 1)
@@ -85,19 +86,22 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
         if (control.innerParallel) {
           newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
           newShape.push_back(ratio); // parallel (insert)
-          exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
+          exprs.push_back(
+              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
           exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
         } else {
           newShape.push_back(ratio); // parallel (insert)
           newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
           exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
-          exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
+          exprs.push_back(
+              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
         }
         reassociation.push_back({index++, index++});
         continue;
       }
       newShape.push_back(op.getShape(operand)[idx]);
-      exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+      exprs.push_back(
+          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
       reassociation.push_back({index++});
     }
     newMaps.push_back(


        


More information about the Mlir-commits mailing list