[Mlir-commits] [mlir] 894641e - Revert "[mlir][Arithmetic] Add `arith.delinearize_index` operation"

Christopher Bate llvmlistbot at llvm.org
Mon Jul 25 10:53:28 PDT 2022


Author: Christopher Bate
Date: 2022-07-25T11:52:05-06:00
New Revision: 894641e974e55164481fe3e7c90ef3cf5af89cf9

URL: https://github.com/llvm/llvm-project/commit/894641e974e55164481fe3e7c90ef3cf5af89cf9
DIFF: https://github.com/llvm/llvm-project/commit/894641e974e55164481fe3e7c90ef3cf5af89cf9.diff

LOG: Revert "[mlir][Arithmetic] Add `arith.delinearize_index` operation"

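This reverts commit 535b507ba58e8b5f604d53ffc961be1456d229a7, which added the `arith.delinearize_index` operation. The revert removes the op definition and its builder/verifier, the `getDivMod`/`delinearizeIndex` utilities, the expansion pattern in `ExpandOps.cpp`, the corresponding CMake link dependencies, and the associated tests.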

Added:
    (none)

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
    mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
    mlir/test/Dialect/Arithmetic/expand-ops.mlir
    mlir/test/Dialect/Arithmetic/invalid.mlir
    mlir/test/Dialect/Arithmetic/ops.mlir

Removed:
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 2bdff675bf9a1..75710d60c6d45 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1219,47 +1219,4 @@ def SelectOp : Arith_Op<"select", [
   let hasCustomAssemblyFormat = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// DelinearizeIndexOp
-//===----------------------------------------------------------------------===//
-
-def DelinearizeIndexOp : Op<Arithmetic_Dialect, "delinearize_index",
-    [NoSideEffect]> {
-  let summary = "delinearize an index";
-  let description = [{
-    The `arith.delinearize_index` operation takes a single index value and
-    calculates the multi-index according to the given basis.
-
-    Example:
-
-    ```
-    %indices:3 = arith.delinearize_index %linear_index (%1, %2, %3) : index, index, index
-    ```
-
-    In the above example, `%indices:3` conceptually holds the following:
-
-    ```
-    %v1 = arith.muli %1, %2 : index
-    %indices#0 = floorDiv(%linear_index , %v1)
-    %indices#1 = floorDiv(remander(%linear_index , %v1), %3)
-    %indices#2 = remainder(remainder(%linear_idnex, %v1), %3)
-    ```
-  }];
-
-  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
-  let results = (outs Variadic<Index>:$multi_index);
-
-  let assemblyFormat = [{
-    $linear_index `(` $basis `)` attr-dict `:` type($multi_index)
-  }];
-
-  let builders = [
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
-  ];
-
-  let hasVerifier = 1;
-}
-
-
-
 #endif // ARITHMETIC_OPS

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
index a59e108bd5263..924de08e81af2 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -108,22 +108,6 @@ struct ArithBuilder {
   OpBuilder &b;
   Location loc;
 };
-
-/// Holds the result of (div a, b)  and (mod a, b)
-struct DivModValue {
-  Value quotient;
-  Value remainder;
-};
-
-/// Create IR to calculate (div a, b)  and (mod a, b)
-DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
-
-/// Generate the IR to delinearize `linearIndex` given the `basis` and return
-/// the multi-index.
-FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
-                                               Value linearIndex,
-                                               ArrayRef<Value> dimSizes);
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index f2a88e2e2416d..4d0d50e4d6350 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -10,7 +10,6 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/CommonFolders.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -2041,35 +2040,6 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-// DelinearizeIndexOp
-//===----------------------------------------------------------------------===//
-
-void arith::DelinearizeIndexOp::build(OpBuilder &builder,
-                                      OperationState &result,
-                                      Value linear_index,
-                                      ArrayRef<OpFoldResult> basis) {
-  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
-  result.addOperands(linear_index);
-  SmallVector<Value> basisValues =
-      llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value {
-        Optional<int64_t> staticDim = getConstantIntValue(ofr);
-        if (staticDim.has_value())
-          return builder.create<arith::ConstantIndexOp>(result.location,
-                                                        *staticDim);
-        return ofr.dyn_cast<Value>();
-      }));
-  result.addOperands(basisValues);
-}
-
-LogicalResult arith::DelinearizeIndexOp::verify() {
-  if (getBasis().empty())
-    return emitOpError("basis should not be empty");
-  if (getNumResults() != getBasis().size())
-    return emitOpError("should return an index for each basis element");
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
index 5ca47d98cc5b3..e23504bee3302 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -15,7 +15,6 @@ add_mlir_dialect_library(MLIRArithmeticDialect
 
   LINK_LIBS PUBLIC
   MLIRDialect
-  MLIRDialectUtils
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
index eead11d11557d..f140715e603ee 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -13,7 +13,6 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
   LINK_LIBS PUBLIC
   MLIRAnalysis
   MLIRArithmeticDialect
-  MLIRArithmeticUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRInferIntRangeInterface

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index 7c48b6f639f81..afe7aab99af30 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -9,7 +9,6 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -189,23 +188,6 @@ struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
   }
 };
 
-/// Lowers `arith.delinearize_index` into a sequence of division and remainder
-/// operations.
-struct LowerDelinearizeIndexOps
-    : public OpRewritePattern<arith::DelinearizeIndexOp> {
-  using OpRewritePattern<arith::DelinearizeIndexOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::DelinearizeIndexOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         llvm::to_vector(op.getBasis()));
-    if (failed(multiIndex))
-      return failure();
-    rewriter.replaceOp(op, *multiIndex);
-    return success();
-  }
-};
-
 struct ArithmeticExpandOpsPass
     : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
   void runOnOperation() override {
@@ -225,8 +207,7 @@ struct ArithmeticExpandOpsPass
       arith::MaxUIOp,
       arith::MinFOp,
       arith::MinSIOp,
-      arith::MinUIOp,
-      arith::DelinearizeIndexOp
+      arith::MinUIOp
     >();
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
@@ -249,8 +230,7 @@ void mlir::arith::populateArithmeticExpandOpsPatterns(
     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
-    MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
-    LowerDelinearizeIndexOps
+    MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
    >(patterns.getContext());
   // clang-format on
 }

diff  --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
index 40588869a5880..b568891df66a1 100644
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
@@ -116,47 +115,3 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
 }
-
-DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) {
-  DivModValue result;
-  result.quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
-  result.remainder = b.create<arith::RemUIOp>(loc, lhs, rhs);
-  return result;
-}
-
-/// Create IR that computes the product of all elements in the set.
-static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
-                                               ArrayRef<Value> set) {
-  if (set.empty())
-    return failure();
-  OpFoldResult result = set[0];
-  for (unsigned i = 1; i < set.size(); i++)
-    result = b.createOrFold<arith::MulIOp>(
-        loc, getValueOrCreateConstantIndexOp(b, loc, result), set[i]);
-  return result;
-}
-
-FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc,
-                                                     Value linearIndex,
-                                                     ArrayRef<Value> dimSizes) {
-  unsigned numDims = dimSizes.size();
-
-  SmallVector<Value> divisors;
-  for (unsigned i = 1; i < numDims; i++) {
-    ArrayRef<Value> slice(dimSizes.begin() + i, dimSizes.end());
-    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
-    if (failed(prod))
-      return failure();
-    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
-  }
-
-  SmallVector<Value> results;
-  Value residual = linearIndex;
-  for (Value divisor : divisors) {
-    DivModValue divMod = getDivMod(b, loc, residual, divisor);
-    results.push_back(divMod.quotient);
-    residual = divMod.remainder;
-  }
-  results.push_back(residual);
-  return results;
-}

diff  --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index 1da12d369669e..e0990fdec5656 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -230,24 +230,3 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
 }
 // CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
-
-// -----
-
-// CHECK-LABEL: @static_basis
-//  CHECK-SAME:    (%[[IDX:.+]]: index)
-//       CHECK:   arith.constant
-//       CHECK:   arith.constant
-//   CHECK-DAG:   %[[c224:.+]] = arith.constant 224 : index
-//   CHECK-DAG:   %[[c50176:.+]] = arith.constant 50176 : index
-//       CHECK:   %[[N:.+]] = arith.divui %[[IDX]], %[[c50176]] : index
-//       CHECK:   %[[RES:.+]] = arith.remui %[[IDX]], %[[c50176]] : index
-//       CHECK:   %[[P:.+]] = arith.divui %[[RES]], %[[c224]] : index
-//       CHECK:   %[[Q:.+]] = arith.remui %[[RES]], %[[c224]] : index
-//       CHECK:   return %[[N]], %[[P]], %[[Q]]
-func.func @static_basis(%linear_index: index) -> (index, index, index) {
-  %b0 = arith.constant 16 : index
-  %b1 = arith.constant 224 : index
-  %b2 = arith.constant 224 : index
-  %1:3 = arith.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index
-  return %1#0, %1#1, %1#2 : index, index, index
-}

diff  --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 201b75784cb4b..19c427b5e744f 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -721,19 +721,3 @@ func.func @func() {
 
   %x = arith.constant 1 : i32
 }
-
-// -----
-
-func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
-  // expected-error at +1 {{'arith.delinearize_index' op should return an index for each basis element}}
-  %1 = arith.delinearize_index %idx (%basis0, %basis1) : index
-  return
-}
-
-// -----
-
-func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
-  // expected-error at +1 {{'arith.delinearize_index' op basis should not be empty}}
-  arith.delinearize_index %idx () : index
-  return
-}

diff  --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index fe241ba1d1736..61f9a2d8e5125 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -958,9 +958,3 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
   %min_unsigned = arith.minui %i1, %i2 : i32
   return
 }
-
-// CHECK-LABEL: func @delinearize
-func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) -> (index, index) {
-  %1:2 = arith.delinearize_index %idx (%basis0, %basis1) : index, index
-  return %1#0, %1#1 : index, index
-}


        


More information about the Mlir-commits mailing list