[Mlir-commits] [mlir] a703d15 - [mlir][Index][NFC] Migrate index dialect to the new fold API

Markus Böck llvmlistbot at llvm.org
Wed Jan 11 12:47:31 PST 2023


Author: Markus Böck
Date: 2023-01-11T21:47:25+01:00
New Revision: a703d15519efac35518fb38e756b689bc3766781

URL: https://github.com/llvm/llvm-project/commit/a703d15519efac35518fb38e756b689bc3766781
DIFF: https://github.com/llvm/llvm-project/commit/a703d15519efac35518fb38e756b689bc3766781.diff

LOG: [mlir][Index][NFC] Migrate index dialect to the new fold API

See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Similar to the patch for the arith dialect, the index dialect's fold implementations make heavy use of generic fold functions, so the change is comparatively mechanical and mostly consists of changing the function signatures.

Differential Revision: https://reviews.llvm.org/D141502

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
    mlir/lib/Dialect/Index/IR/IndexOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td b/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
index be0fea79ee392..7e1130c86add4 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
@@ -83,6 +83,7 @@ def IndexDialect : Dialect {
 
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // INDEX_DIALECT

diff  --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index dee6025ffd978..598bb9a78ebf1 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -115,36 +115,40 @@ static OpFoldResult foldBinaryOpChecked(
 // AddOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // DivSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold division by zero.
         if (rhs.isZero())
           return std::nullopt;
@@ -156,9 +160,10 @@ OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
 // DivUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold division by zero.
         if (rhs.isZero())
           return std::nullopt;
@@ -193,18 +198,19 @@ static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
   return (n + x).sdiv(m) + 1;
 }
 
-OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, calculateCeilDivS);
+OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
 }
 
 //===----------------------------------------------------------------------===//
 // CeilDivUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
   // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
   return foldBinaryOpChecked(
-      operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &n, const APInt &m) -> Optional<APInt> {
         // Don't fold division by zero.
         if (m.isZero())
           return std::nullopt;
@@ -242,56 +248,58 @@ static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
   return -1 - (x - n).sdiv(m);
 }
 
-OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, calculateFloorDivS);
+OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
 }
 
 //===----------------------------------------------------------------------===//
 // RemSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.srem(rhs);
-  });
+OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
 }
 
 //===----------------------------------------------------------------------===//
 // RemUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.urem(rhs);
-  });
+OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
 }
 
 //===----------------------------------------------------------------------===//
 // MaxSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.sgt(rhs) ? lhs : rhs;
-  });
+OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.sgt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
 // MaxUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.ugt(rhs) ? lhs : rhs;
-  });
+OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.ugt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
 // MinSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
+OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
     return lhs.slt(rhs) ? lhs : rhs;
   });
 }
@@ -300,8 +308,8 @@ OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
 // MinUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
+OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
     return lhs.ult(rhs) ? lhs : rhs;
   });
 }
@@ -310,9 +318,10 @@ OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
 // ShlOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // We cannot fold if the RHS is greater than or equal to 32 because
         // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
         // already treated as unsigned.
@@ -326,9 +335,10 @@ OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
 // ShrSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold if RHS is greater than or equal to 32.
         if (rhs.uge(32))
           return {};
@@ -340,9 +350,10 @@ OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
 // ShrUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold if RHS is greater than or equal to 32.
         if (rhs.uge(32))
           return {};
@@ -354,27 +365,30 @@ OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
 // AndOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // OrOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -425,10 +439,9 @@ bool compareIndices(const APInt &lhs, const APInt &rhs,
   llvm_unreachable("unhandled IndexCmpPredicate predicate");
 }
 
-OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "compare expected 2 operands");
-  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
-  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
+OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
+  auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
+  auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
   if (!lhs || !rhs)
     return {};
 
@@ -453,9 +466,7 @@ void ConstantOp::getAsmResultNames(
   setNameFn(getResult(), specialName.str());
 }
 
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
-  return getValueAttr();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
   build(b, state, b.getIndexType(), b.getIndexAttr(value));
@@ -465,7 +476,7 @@ void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
 // BoolConstantOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }
 


        


More information about the Mlir-commits mailing list