[Mlir-commits] [mlir] 0b33f0d - [mlir][linalg] expose convolution dimension classifier

Alex Zinenko llvmlistbot at llvm.org
Tue Feb 14 02:12:08 PST 2023


Author: Alex Zinenko
Date: 2023-02-14T10:12:01Z
New Revision: 0b33f0d80b832614af68557011786a83a0eef18a

URL: https://github.com/llvm/llvm-project/commit/0b33f0d80b832614af68557011786a83a0eef18a
DIFF: https://github.com/llvm/llvm-project/commit/0b33f0d80b832614af68557011786a83a0eef18a.diff

LOG: [mlir][linalg] expose convolution dimension classifier

Expose, through functions in the `linalg::detail` namespace, the
classification of Linalg op dimensions into the different kinds of
convolution dimensions (batch, image, channel, etc.). This is useful for
identifying which dimensions to target with transformations.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D143584
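
For context, a minimal usage sketch of the newly exposed entry points
(illustrative only, not part of the commit; the function name and the debug
printing are made up for the example). Because this revision only
forward-declares `MatchConvolutionResult` in the header, success is detected
here through the empty diagnostic message:

  // Sketch: classify the loops of `op` if it matches the convolution
  // predicates. Assumes `op` is obtained elsewhere.
  #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
  #include "llvm/Support/raw_ostream.h"

  using namespace mlir;

  static void classifyConvolutionLoops(Operation *op) {
    linalg::detail::ConvolutionDimensions dims;
    auto res = linalg::detail::isConvolutionInterfaceImpl(op, &dims);
    // The enumerators of MatchConvolutionResult live in LinalgInterfaces.cpp,
    // so treat an empty message as success.
    if (!linalg::detail::getMatchConvolutionMessage(res).empty())
      return;
    // `dims` now holds loop positions grouped by kind, e.g. to pick loops
    // for tiling or interchange.
    for (unsigned pos : dims.outputImage)
      llvm::errs() << "output image loop: " << pos << "\n";
  }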

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index d8d59b5ce3a81..cb93e8a7bc104 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -42,6 +42,31 @@ bool isaContractionOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
+/// Result of matching a Linalg generic op against the predicates that
+/// characterize it as a convolution.
+enum class MatchConvolutionResult;
+
+/// Positions of a Linalg op's loops that correspond to different kinds of
+/// convolution dimensions.
+struct ConvolutionDimensions {
+  SmallVector<unsigned, 2> batch;
+  SmallVector<unsigned, 2> outputImage;
+  SmallVector<unsigned, 2> outputChannel;
+  SmallVector<unsigned, 2> filterLoop;
+  SmallVector<unsigned, 2> inputChannel;
+  SmallVector<unsigned, 2> depth;
+};
+
+/// Checks whether `op` conforms to ConvolutionOpInterface and populates
+/// `dimensions` with indexes of the different kinds of dimensions when present.
+MatchConvolutionResult
+isConvolutionInterfaceImpl(Operation *op,
+                           ConvolutionDimensions *dimensions = nullptr);
+
+/// Returns the error message corresponding to the convolution checking return
+/// code.
+StringRef getMatchConvolutionMessage(MatchConvolutionResult res);
+
 /// Verify that `op` conforms to ContractionOpInterface.
 LogicalResult verifyContractionInterface(Operation *op);
 

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index e5e0bdd255dc5..a5c6dc627299e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -268,6 +268,7 @@ static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
   return preservedDims;
 }
 
+namespace mlir::linalg::detail {
 enum class MatchConvolutionResult {
   Success = 0,
   NotLinalgOp,
@@ -278,8 +279,11 @@ enum class MatchConvolutionResult {
   OutputDimsNotParallel,
   NonOutputDimNotReduction
 };
+} // namespace mlir::linalg::detail
 
-static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
+mlir::linalg::detail::MatchConvolutionResult
+mlir::linalg::detail::isConvolutionInterfaceImpl(
+    Operation *op, ConvolutionDimensions *dimensions) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
   if (!linalgOp)
     return MatchConvolutionResult::NotLinalgOp;
@@ -307,7 +311,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
   llvm::SmallDenseSet<unsigned> outputDims =
       getPreservedDims(indexingMaps.back());
   llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
-  // Make sure all loops are charecterized as one of:
+  // Make sure all loops are characterized as one of:
   // - Batch loop : present in output, as non-convolved in input, not present in
   //   filter.
   // - Output image dimension : present in output, convolved dims in input, not
@@ -329,6 +333,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->batch.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.convolvedDims.count(outputDim) &&
@@ -337,6 +343,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->outputImage.push_back(outputDim);
       continue;
     }
     if (!inputExprWalker.convolvedDims.count(outputDim) &&
@@ -346,6 +354,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->outputChannel.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -354,6 +364,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
+      if (dimensions)
+        dimensions->depth.push_back(outputDim);
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
@@ -363,7 +375,10 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (outputDims.count(filterDim) &&
         !inputExprWalker.unConvolvedDims.count(filterDim) &&
         !inputExprWalker.convolvedDims.count(filterDim)) {
-      // Output channel dimension. THis is already seen, continue;
+      // Output channel dimension. This is already seen, continue;
+      assert((!dimensions ||
+              llvm::is_contained(dimensions->outputChannel, filterDim)) &&
+             "expected output channel to have been found from output dims");
       continue;
     }
     if (inputExprWalker.convolvedDims.count(filterDim) &&
@@ -374,6 +389,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
+      if (dimensions)
+        dimensions->filterLoop.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
@@ -384,11 +401,16 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
+      if (dimensions)
+        dimensions->inputChannel.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
         outputDims.count(filterDim)) {
       // Depthwise loop. Already seen.
+      assert(
+          (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) &&
+          "expected depthwise dimension to have been found from output dims");
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
@@ -397,32 +419,45 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
   if (allLoopDims.size() != linalgOp.getNumLoops())
     return MatchConvolutionResult::NonConvolutionLoop;
 
+  if (dimensions) {
+    assert(dimensions->batch.size() + dimensions->outputImage.size() +
+                   dimensions->outputChannel.size() +
+                   dimensions->filterLoop.size() +
+                   dimensions->inputChannel.size() + dimensions->depth.size() ==
+               linalgOp.getNumLoops() &&
+           "expected all loops to be classified");
+  }
+
   return MatchConvolutionResult::Success;
 }
 
-LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
-  auto res = isConvolutionInterfaceImpl(op);
-  if (res == MatchConvolutionResult::NotLinalgOp)
-    return op->emitError("expected a LinalgOp");
-  if (res == MatchConvolutionResult::WrongNumOperands)
-    return op->emitError("expected op with 2 inputs and 1 output");
-  if (res == MatchConvolutionResult::WrongInputIndexingMap)
-    return op->emitError("unexpected input index map for convolutions");
-  if (res == MatchConvolutionResult::NotProjectedPermutations) {
-    return op->emitError(
-        "expected output/filter indexing maps to be projected permutations");
-  }
-  if (res == MatchConvolutionResult::NonConvolutionLoop) {
-    return op->emitError("unexpected loop dimension for convolution op");
-  }
-  if (res == MatchConvolutionResult::OutputDimsNotParallel) {
-    return op->emitError(
-        "expected all iterators used to access outputs to be parallel");
-  }
-  if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
-    return op->emitError(
-        "expected all iterators not used to access outputs to be reduction");
+StringRef
+mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
+  switch (res) {
+  case MatchConvolutionResult::NotLinalgOp:
+    return "expected a LinalgOp";
+  case MatchConvolutionResult::WrongNumOperands:
+    return "expected op with 2 inputs and 1 output";
+  case MatchConvolutionResult::WrongInputIndexingMap:
+    return "unexpected input index map for convolutions";
+  case MatchConvolutionResult::NotProjectedPermutations:
+    return "expected output/filter indexing maps to be projected permutations";
+  case MatchConvolutionResult::NonConvolutionLoop:
+    return "unexpected loop dimension for convolution op";
+  case MatchConvolutionResult::OutputDimsNotParallel:
+    return "expected all iterators used to access outputs to be parallel";
+  case MatchConvolutionResult::NonOutputDimNotReduction:
+    return "expected all iterators not used to access outputs to be reduction";
+  case MatchConvolutionResult::Success:
+    return "";
   }
+  llvm_unreachable("unhandled MatchConvolutionResult case");
+}
+
+LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
+  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
+  if (res != MatchConvolutionResult::Success)
+    return op->emitError(getMatchConvolutionMessage(res));
   return success();
 }
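
For reference, a hand-worked illustration (not part of the commit) of the
classification one would expect for a plain 2-D NHWC/HWCF convolution,
assuming the usual loop order (n, oh, ow, f, kh, kw, c) = (d0 ... d6):

  // Hypothetical expectation for linalg.conv_2d_nhwc_hwcf:
  //   dims.batch         == {0}      // n
  //   dims.outputImage   == {1, 2}   // oh, ow
  //   dims.outputChannel == {3}      // f
  //   dims.filterLoop    == {4, 5}   // kh, kw
  //   dims.inputChannel  == {6}      // c
  //   dims.depth         == {}       // populated only for depthwise variants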
 


More information about the Mlir-commits mailing list