[Mlir-commits] [mlir] Add getters for multi dim loop variables in LoopLikeOpInterface (PR #94516)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 5 11:30:59 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/94516

>From 5020e498440b0016adef7e99806aa55c4837b441 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 13:08:22 -0500
Subject: [PATCH] Add getters for multi dim loop variables in
 LoopLikeOpInterface

---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  4 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 37 ++--------
 .../mlir/Interfaces/LoopLikeInterface.td      | 65 +++++++++++------
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 20 +++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 70 +++++++------------
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        |  8 +++
 6 files changed, 97 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 3640055ea8da8..bb2c29b5733b8 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
     [AttrSizedOperandSegments, AutomaticAllocationScope,
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-     ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValuesMutable",
+     ["getInductionVars", "getMixedLowerBound", "getMixedUpperBound",
+      "getMixedStep", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b063aa772bab..3b28ca8b21d0f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
-        "getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-        "getSingleUpperBound", "getYieldedValuesMutable",
+        "getInductionVars", "getMixedLowerBound", "getMixedStep",
+        "getMixedUpperBound", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
-           "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
+          ["getInitsMutable", "getRegionIterArgs", "getInductionVars",
+           "getMixedLowerBound", "getMixedUpperBound", "getMixedStep",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,24 +510,6 @@ def ForallOp : SCF_Op<"forall", [
   ];
 
   let extraClassDeclaration = [{
-    // Get lower bounds as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedLowerBound() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
-    }
-
-    // Get upper bounds as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedUpperBound() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
-    }
-
-    // Get steps as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedStep() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticStep(), getDynamicStep(), b);
-    }
-
     /// Get lower bounds as values.
     SmallVector<Value> getLowerBound(OpBuilder &b) {
       return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
@@ -584,10 +566,6 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
-    ::mlir::ValueRange getInductionVars() {
-      return getBody()->getArguments().take_front(getRank());
-    }
-
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
@@ -765,8 +743,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
 def ParallelOp : SCF_Op<"parallel",
     [AutomaticAllocationScope,
      AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar",
-          "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
+          "getMixedLowerBound", "getMixedUpperBound", "getMixedStep"]>,
      RecursiveMemoryEffects,
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -846,9 +824,6 @@ def ParallelOp : SCF_Op<"parallel",
   ];
 
   let extraClassDeclaration = [{
-    ValueRange getInductionVars() {
-      return getBody()->getArguments();
-    }
     unsigned getNumLoops() { return getStep().size(); }
     unsigned getNumReductions() { return getInitVals().size(); }
   }];
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index f0dc6e60eba58..813779c852027 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -93,51 +93,47 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        If there is a single induction variable return it, otherwise return
-        std::nullopt.
+        Return all induction variables (a view into op-owned storage).
       }],
-      /*retTy=*/"::std::optional<::mlir::Value>",
-      /*methodName=*/"getSingleInductionVar",
+      /*retTy=*/"::mlir::ValueRange",
+      /*methodName=*/"getInductionVars",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single lower bound value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all lower bounds, or an empty vector if they are not available.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleLowerBound",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedLowerBound",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single step value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all steps, or an empty vector if they are not available.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleStep",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedStep",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single upper bound value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all upper bounds, or an empty vector if they are not available.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleUpperBound",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedUpperBound",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
@@ -235,6 +231,37 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
   }];
 
   let extraSharedClassDeclaration = [{
+    // Query through the op rather than `this`: inside the trait, `this->` would
+    // bind to the trait's default implementations instead of the op overrides.
+    /// If there is a single induction variable return it, otherwise return
+    /// std::nullopt.
+    ::std::optional<::mlir::Value> getSingleInductionVar() {
+      if (auto ivs = $_op.getInductionVars(); ivs.size() == 1)
+        return ivs[0];
+      return std::nullopt;
+    }
+    /// Return the single lower bound value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
+      if (auto lbs = $_op.getMixedLowerBound(); lbs.size() == 1)
+        return lbs[0];
+      return std::nullopt;
+    }
+    /// Return the single step value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleStep() {
+      if (auto steps = $_op.getMixedStep(); steps.size() == 1)
+        return steps[0];
+      return std::nullopt;
+    }
+    /// Return the single upper bound value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
+      if (auto ubs = $_op.getMixedUpperBound(); ubs.size() == 1)
+        return ubs[0];
+      return std::nullopt;
+    }
+
     /// Append the specified additional "init" operands: replace this loop with
     /// a new loop that has the additional init operands. The loop body of this
     /// loop is moved over to the new loop.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2e31487bd55a0..746a9c919560c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,27 +2454,29 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> AffineForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+// Return a view of the body block's first argument (the induction variable);
+// a ValueRange over a braced temporary would dangle after returning.
+ValueRange AffineForOp::getInductionVars() {
+  return getBody()->getArguments().take_front();
+}
 
-std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedLowerBound() {
   if (!hasConstantLowerBound())
-    return std::nullopt;
+    return {};
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
+  return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
 }
 
-std::optional<OpFoldResult> AffineForOp::getSingleStep() {
+SmallVector<OpFoldResult> AffineForOp::getMixedStep() {
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
+  return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
 }
 
-std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedUpperBound() {
   if (!hasConstantUpperBound())
-    return std::nullopt;
+    return {};
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
+  return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
 }
 
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 107fd0690f193..e275ff1849c10 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,20 +378,22 @@ LogicalResult ForOp::verifyRegions() {
   return success();
 }
 
-std::optional<Value> ForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+// Return a view of the body block's first argument (the induction variable);
+// a ValueRange over a braced temporary would dangle after returning.
+ValueRange ForOp::getInductionVars() {
+  return getBody()->getArguments().take_front();
+}
 
-std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
-  return OpFoldResult(getLowerBound());
+SmallVector<OpFoldResult> ForOp::getMixedLowerBound() {
+  return {OpFoldResult(getLowerBound())};
 }
 
-std::optional<OpFoldResult> ForOp::getSingleStep() {
-  return OpFoldResult(getStep());
+SmallVector<OpFoldResult> ForOp::getMixedStep() {
+  return {OpFoldResult(getStep())};
 }
 
-std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
-  return OpFoldResult(getUpperBound());
+SmallVector<OpFoldResult> ForOp::getMixedUpperBound() {
+  return {OpFoldResult(getUpperBound())};
 }
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
@@ -1428,28 +1430,26 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
-std::optional<Value> ForallOp::getSingleInductionVar() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getInductionVar(0);
+ValueRange ForallOp::getInductionVars() {
+  return getBody()->getArguments().take_front(getRank());
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedLowerBound()[0];
+// Get lower bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedLowerBound() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedUpperBound()[0];
+// Get upper bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedUpperBound() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleStep() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedStep()[0];
+// Get steps as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedStep() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticStep(), getDynamicStep(), b);
 }
 
 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
@@ -3008,29 +3008,17 @@ void ParallelOp::print(OpAsmPrinter &p) {
 
 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> ParallelOp::getSingleInductionVar() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getBody()->getArgument(0);
-}
+ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
 
-std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getLowerBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedLowerBound() {
+  return getLowerBound();
 }
 
-std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getUpperBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedUpperBound() {
+  return getUpperBound();
 }
 
-std::optional<OpFoldResult> ParallelOp::getSingleStep() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getStep()[0];
-}
+SmallVector<OpFoldResult> ParallelOp::getMixedStep() { return getStep(); }
 
 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 6bc0fd6113b9b..d8cdb213070da 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -36,6 +36,10 @@ class SCFLoopLikeTest : public ::testing::Test {
     std::optional<OpFoldResult> maybeIndVar =
         loopLikeOp.getSingleInductionVar();
     EXPECT_TRUE(maybeIndVar.has_value());
+    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 1u);
   }
 
   void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -48,6 +52,10 @@ class SCFLoopLikeTest : public ::testing::Test {
     std::optional<OpFoldResult> maybeIndVar =
         loopLikeOp.getSingleInductionVar();
     EXPECT_FALSE(maybeIndVar.has_value());
+    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 2u);
   }
 
   MLIRContext context;



More information about the Mlir-commits mailing list