[Mlir-commits] [mlir] 919e459 - [Linalg] Remove Optional from getStaticLoopRanges interface method.

Hanhan Wang llvmlistbot at llvm.org
Tue May 3 05:13:11 PDT 2022


Author: Hanhan Wang
Date: 2022-05-03T05:12:54-07:00
New Revision: 919e459f1ba3bd3f93b50cac1077d685547250e5

URL: https://github.com/llvm/llvm-project/commit/919e459f1ba3bd3f93b50cac1077d685547250e5
DIFF: https://github.com/llvm/llvm-project/commit/919e459f1ba3bd3f93b50cac1077d685547250e5.diff

LOG: [Linalg] Remove Optional from getStaticLoopRanges interface method.

It is a serious error if the loop ranges can't be inferred. This is also checked in
verifyStructuredOpInterface, so the Optional return type is unnecessary.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D124596

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 439538378dd71..18e65c849e28f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -999,18 +999,17 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*desc=*/[{
         Returns the statically-known loop ranges. Composes
         `getShapesToLoopsMap()` with the result of `getStaticShape`.
-        Returns None if `getShapesToLoopsMap()` fails. Returns
-        ShapeType::kDynamicSize for non-statically-known loop ranges.
+        Returns ShapeType::kDynamicSize for non-statically-known loop ranges.
+        This is expected to be called by a valid Linalg op
       }],
-      /*retTy=*/"Optional<SmallVector<int64_t, 4>>",
+      /*retTy=*/"SmallVector<int64_t, 4>",
       /*methodName=*/"getStaticLoopRanges",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         SmallVector<int64_t> viewSizes = getStaticShape();
         AffineMap invertedMap = getShapesToLoopsMap();
-        if (!invertedMap)
-          return {};
+        assert(invertedMap && "expected a valid Linalg op to call the method");
         return invertedMap.compose(viewSizes);
       }]
     >,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 4c796723c25a7..c916d15c9d86a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -732,23 +732,20 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   }
 
   // Check if given shapes match to inferred shapes.
-  Optional<SmallVector<int64_t, 4>> endLoopRangeValues =
-      linalgOp.getStaticLoopRanges();
-  if (!endLoopRangeValues)
-    return op->emitOpError("unable to find loop range for operation");
-  SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0);
+  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
+  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
 
   // Verify only static cases since we can't get exact dimension sizes and loop
   // ranges for dynamic cases in this stage.
-  if (llvm::none_of(*endLoopRangeValues, ShapedType::isDynamic)) {
-    for (int64_t &range : *endLoopRangeValues)
+  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
+    for (int64_t &range : endLoopRangeValues)
       range -= 1;
     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
       AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
       SmallVector<int64_t, 4> startIndices =
           indexingMap.compose(startLoopRangeValues);
       SmallVector<int64_t, 4> endIndices =
-          indexingMap.compose(*endLoopRangeValues);
+          indexingMap.compose(endLoopRangeValues);
       ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
       for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
         // Ignore dynamic dimension or the case that the dimension size is 0

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index cc0ec0866f842..a24fdbe3f2512 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -518,12 +518,8 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     return failure();
   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
 
-  Optional<SmallVector<int64_t, 4>> originalLoopRange =
-      linalgOp.getStaticLoopRanges();
-  if (!originalLoopRange)
-    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
-  originalLoopExtent.assign(originalLoopRange->begin(),
-                            originalLoopRange->end());
+  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
+  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
 
   reassociation.clear();
   expandedShapeMap.clear();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 8f94de4abd1ff..da3c52039d848 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -73,12 +73,10 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
   op.getReductionDims(dims);
   assert(dims.size() == 1);
   unsigned reductionDim = dims[0];
-  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
-  if (!loopRanges)
-    return b.notifyMatchFailure(op, "Cannot analyze loops");
-  int64_t reductionDimSize = (*loopRanges)[reductionDim];
+  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
+  int64_t reductionDimSize = loopRanges[reductionDim];
   if (reductionDimSize == ShapedType::kDynamicSize ||
-      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges->size())
+      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
   SmallVector<Operation *, 4> combinerOps;


        


More information about the Mlir-commits mailing list