[Mlir-commits] [mlir] 535b507 - [mlir][Arithmetic] Add `arith.delinearize_index` operation

Christopher Bate llvmlistbot at llvm.org
Fri Jul 22 10:22:04 PDT 2022


Author: Christopher Bate
Date: 2022-07-22T11:20:41-06:00
New Revision: 535b507ba58e8b5f604d53ffc961be1456d229a7

URL: https://github.com/llvm/llvm-project/commit/535b507ba58e8b5f604d53ffc961be1456d229a7
DIFF: https://github.com/llvm/llvm-project/commit/535b507ba58e8b5f604d53ffc961be1456d229a7.diff

LOG: [mlir][Arithmetic] Add `arith.delinearize_index` operation

This change adds a new DelinearizeIndexOp to the `arith` dialect. The
operation accepts a value of `index` type as well as a basis (a list of
index values) representing how the index should be decomposed into a
multi-index. The decomposition obeys a canonical semantic that treats
the final basis element as "fastest varying" and the first basis element
as "slowest varying". A naive lowering of the operation using a sequence
of `arith.divui` and `arith.remui` operations is also given.

Differential Revision: https://reviews.llvm.org/D129697

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
    mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
    mlir/test/Dialect/Arithmetic/expand-ops.mlir
    mlir/test/Dialect/Arithmetic/invalid.mlir
    mlir/test/Dialect/Arithmetic/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 75710d60c6d45..2bdff675bf9a1 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1219,4 +1219,47 @@ def SelectOp : Arith_Op<"select", [
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+def DelinearizeIndexOp : Op<Arithmetic_Dialect, "delinearize_index",
+    [NoSideEffect]> {
+  let summary = "delinearize an index";
+  let description = [{
+    The `arith.delinearize_index` operation takes a single index value and
+    calculates the multi-index according to the given basis.
+
+    Example:
+
+    ```
+    %indices:3 = arith.delinearize_index %linear_index (%1, %2, %3) : index, index, index
+    ```
+
+    The last basis element varies fastest, so `%indices:3` conceptually holds:
+
+    ```
+    %v1 = arith.muli %2, %3 : index
+    %indices#0 = floorDiv(%linear_index, %v1)
+    %indices#1 = floorDiv(remainder(%linear_index, %v1), %3)
+    %indices#2 = remainder(remainder(%linear_index, %v1), %3)
+    ```
+  }];
+
+  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let results = (outs Variadic<Index>:$multi_index);
+
+  let assemblyFormat = [{
+    $linear_index `(` $basis `)` attr-dict `:` type($multi_index)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+  ];
+
+  let hasVerifier = 1;
+}
+
+
+
 #endif // ARITHMETIC_OPS

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
index 924de08e81af2..a59e108bd5263 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -108,6 +108,22 @@ struct ArithBuilder {
   OpBuilder &b;
   Location loc;
 };
+
+/// Holds the result of unsigned (div a, b) and (mod a, b).
+struct DivModValue {
+  Value quotient;
+  Value remainder;
+};
+
+/// Create IR computing unsigned (div lhs, rhs) and (mod lhs, rhs).
+DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
+
+/// Generate the IR to delinearize `linearIndex` over the basis `dimSizes`
+/// (last element varies fastest) and return the multi-index.
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+                                               Value linearIndex,
+                                               ArrayRef<Value> dimSizes);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 939c969d95cae..efec55c125578 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -2040,6 +2041,35 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+/// Entries of `basis` with a known constant value (per `getConstantIntValue`)
+/// are rematerialized as `arith.constant` index ops; others must be Values.
+void arith::DelinearizeIndexOp::build(OpBuilder &builder,
+                                      OperationState &result, Value linearIndex,
+                                      ArrayRef<OpFoldResult> basis) {
+  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
+  result.addOperands(linearIndex);
+  for (OpFoldResult ofr : basis) {
+    if (Optional<int64_t> staticDim = getConstantIntValue(ofr))
+      result.addOperands(builder.create<arith::ConstantIndexOp>(
+          result.location, *staticDim).getResult());
+    else
+      result.addOperands(ofr.get<Value>());
+  }
+}
+
+/// Requires a non-empty basis and exactly one result per basis element.
+LogicalResult arith::DelinearizeIndexOp::verify() {
+  if (getBasis().empty())
+    return emitOpError("basis should not be empty");
+  if (getNumResults() != getBasis().size())
+    return emitOpError("should return an index for each basis element");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
index e23504bee3302..5ca47d98cc5b3 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRArithmeticDialect
 
   LINK_LIBS PUBLIC
   MLIRDialect
+  MLIRDialectUtils
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
index f140715e603ee..eead11d11557d 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
   LINK_LIBS PUBLIC
   MLIRAnalysis
   MLIRArithmeticDialect
+  MLIRArithmeticUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRInferIntRangeInterface

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index afe7aab99af30..7c48b6f639f81 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -9,6 +9,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -188,6 +189,23 @@ struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Expands `arith.delinearize_index` through the `delinearizeIndex` utility,
+/// which emits a chain of unsigned division and remainder operations.
+struct LowerDelinearizeIndexOps
+    : public OpRewritePattern<arith::DelinearizeIndexOp> {
+  using OpRewritePattern<arith::DelinearizeIndexOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::DelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> basis = llvm::to_vector(op.getBasis());
+    FailureOr<SmallVector<Value>> lowered =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), basis);
+    if (failed(lowered))
+      return failure();
+    rewriter.replaceOp(op, *lowered);
+    return success();
+  }
+};
+
 struct ArithmeticExpandOpsPass
     : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
   void runOnOperation() override {
@@ -207,7 +225,8 @@ struct ArithmeticExpandOpsPass
       arith::MaxUIOp,
       arith::MinFOp,
       arith::MinSIOp,
-      arith::MinUIOp
+      arith::MinUIOp,
+      arith::DelinearizeIndexOp
     >();
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
@@ -230,7 +249,8 @@ void mlir::arith::populateArithmeticExpandOpsPatterns(
     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
-    MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
+    MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
+    LowerDelinearizeIndexOps
    >(patterns.getContext());
   // clang-format on
 }

diff  --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
index b568891df66a1..40588869a5880 100644
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
@@ -115,3 +116,47 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
 }
+
+DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) {
+  // Unsigned div/rem; callers presumably pass non-negative indices (verify).
+  Value quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
+  Value remainder = b.create<arith::RemUIOp>(loc, lhs, rhs);
+  return DivModValue{quotient, remainder};
+}
+
+/// Return the divisors used to delinearize over `dimSizes`: entry `i` is
+/// `dimSizes[i + 1] * ... * dimSizes[n - 1]`. They are built back-to-front so
+/// each product reuses the next one instead of being recomputed from scratch.
+static SmallVector<Value> getSuffixProducts(OpBuilder &b, Location loc,
+                                            ArrayRef<Value> dimSizes) {
+  SmallVector<Value> products;
+  if (dimSizes.size() < 2)
+    return products;
+  products.resize(dimSizes.size() - 1);
+  products.back() = dimSizes.back();
+  for (int64_t i = static_cast<int64_t>(products.size()) - 2; i >= 0; --i)
+    products[i] =
+        b.createOrFold<arith::MulIOp>(loc, dimSizes[i + 1], products[i + 1]);
+  return products;
+}
+
+/// Delinearize `linearIndex` over the basis `dimSizes` (last element varies
+/// fastest) using unsigned division/remainder. Returns one value per basis
+/// element; fails on an empty basis, which has no multi-index.
+FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc,
+                                                     Value linearIndex,
+                                                     ArrayRef<Value> dimSizes) {
+  if (dimSizes.empty())
+    return failure();
+
+  SmallVector<Value> results;
+  results.reserve(dimSizes.size());
+  Value residual = linearIndex;
+  for (Value divisor : getSuffixProducts(b, loc, dimSizes)) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}

diff  --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index e0990fdec5656..1da12d369669e 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -230,3 +230,39 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
 }
 // CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
+
+// -----
+
+// CHECK-LABEL: @static_basis
+//  CHECK-SAME:    (%[[IDX:.+]]: index)
+//       CHECK:   arith.constant
+//       CHECK:   arith.constant
+//   CHECK-DAG:   %[[c224:.+]] = arith.constant 224 : index
+//   CHECK-DAG:   %[[c50176:.+]] = arith.constant 50176 : index
+//       CHECK:   %[[N:.+]] = arith.divui %[[IDX]], %[[c50176]] : index
+//       CHECK:   %[[RES:.+]] = arith.remui %[[IDX]], %[[c50176]] : index
+//       CHECK:   %[[P:.+]] = arith.divui %[[RES]], %[[c224]] : index
+//       CHECK:   %[[Q:.+]] = arith.remui %[[RES]], %[[c224]] : index
+//       CHECK:   return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+  %b0 = arith.constant 16 : index
+  %b1 = arith.constant 224 : index
+  %b2 = arith.constant 224 : index
+  %1:3 = arith.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_basis
+//  CHECK-SAME:    (%[[IDX:.+]]: index, %[[B0:.+]]: index, %[[B1:.+]]: index, %[[B2:.+]]: index)
+//       CHECK:   %[[DIV0:.+]] = arith.muli %[[B1]], %[[B2]] : index
+//       CHECK:   %[[N:.+]] = arith.divui %[[IDX]], %[[DIV0]] : index
+//       CHECK:   %[[RES:.+]] = arith.remui %[[IDX]], %[[DIV0]] : index
+//       CHECK:   %[[P:.+]] = arith.divui %[[RES]], %[[B2]] : index
+//       CHECK:   %[[Q:.+]] = arith.remui %[[RES]], %[[B2]] : index
+//       CHECK:   return %[[N]], %[[P]], %[[Q]]
+func.func @dynamic_basis(%linear_index: index, %b0: index, %b1: index, %b2: index) -> (index, index, index) {
+  %1:3 = arith.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}

diff  --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 19c427b5e744f..201b75784cb4b 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -721,3 +721,19 @@ func.func @func() {
 
   %x = arith.constant 1 : i32
 }
+
+// -----
+
+func.func @delinearize(%idx: index, %basis0: index, %basis1: index) {
+  // expected-error@+1 {{'arith.delinearize_index' op should return an index for each basis element}}
+  %1 = arith.delinearize_index %idx (%basis0, %basis1) : index
+  return
+}
+
+// -----
+
+func.func @delinearize(%idx: index) {
+  // expected-error@+1 {{'arith.delinearize_index' op basis should not be empty}}
+  arith.delinearize_index %idx () : index
+  return
+}

diff  --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index 359879f28f2ad..2a8df86122335 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -952,3 +952,10 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
   %min_unsigned = arith.minui %i1, %i2 : i32
   return
 }
+
+// CHECK-LABEL: func @delinearize
+func.func @delinearize(%idx: index, %basis0: index, %basis1: index) -> (index, index) {
+  // CHECK: arith.delinearize_index %{{.*}}(%{{.*}}, %{{.*}}) : index, index
+  %1:2 = arith.delinearize_index %idx (%basis0, %basis1) : index, index
+  return %1#0, %1#1 : index, index
+}


        


More information about the Mlir-commits mailing list