[Mlir-commits] [mlir] 535b507 - [mlir][Arithmetic] Add `arith.delinearize_index` operation
Christopher Bate
llvmlistbot at llvm.org
Fri Jul 22 10:22:04 PDT 2022
Author: Christopher Bate
Date: 2022-07-22T11:20:41-06:00
New Revision: 535b507ba58e8b5f604d53ffc961be1456d229a7
URL: https://github.com/llvm/llvm-project/commit/535b507ba58e8b5f604d53ffc961be1456d229a7
DIFF: https://github.com/llvm/llvm-project/commit/535b507ba58e8b5f604d53ffc961be1456d229a7.diff
LOG: [mlir][Arithmetic] Add `arith.delinearize_index` operation
This change adds a new DelinearizeIndexOp to the `arith` dialect. The
operation accepts an `index` type as well as a basis (array of index
values) representing how the index should be decomposed into a
multi-index. The decomposition obeys a canonical semantic that treats
the final basis element as "fastest varying" and the first basis element
as "slowest varying". A naive lowering of the operation using a sequence
of `arith.divui` and `arith.remui` operations is also given.
Differential Revision: https://reviews.llvm.org/D129697
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
mlir/test/Dialect/Arithmetic/expand-ops.mlir
mlir/test/Dialect/Arithmetic/invalid.mlir
mlir/test/Dialect/Arithmetic/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 75710d60c6d45..2bdff675bf9a1 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1219,4 +1219,47 @@ def SelectOp : Arith_Op<"select", [
let hasCustomAssemblyFormat = 1;
}
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+def DelinearizeIndexOp : Op<Arithmetic_Dialect, "delinearize_index",
+    [NoSideEffect]> {
+  let summary = "delinearize an index";
+  let description = [{
+    The `arith.delinearize_index` operation takes a single index value and
+    calculates the multi-index according to the given basis. The last basis
+    element is the fastest varying dimension and the first basis element is
+    the slowest varying one.
+
+    Example:
+
+    ```
+    %indices:3 = arith.delinearize_index %linear_index (%1, %2, %3) : index, index, index
+    ```
+
+    In the above example, `%indices:3` conceptually holds the following:
+
+    ```
+    %v1 = arith.muli %2, %3 : index
+    %indices#0 = floorDiv(%linear_index, %v1)
+    %indices#1 = floorDiv(remainder(%linear_index, %v1), %3)
+    %indices#2 = remainder(remainder(%linear_index, %v1), %3)
+    ```
+  }];
+
+  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let results = (outs Variadic<Index>:$multi_index);
+
+  let assemblyFormat = [{
+    $linear_index `(` $basis `)` attr-dict `:` type($multi_index)
+  }];
+
+  // Builds the op from a mix of static (attribute) and dynamic (Value) basis
+  // elements; one index result is created per basis element.
+  let builders = [
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+  ];
+
+  // Checks that the basis is non-empty and matches the number of results.
+  let hasVerifier = 1;
+}
+
+
+
#endif // ARITHMETIC_OPS
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
index 924de08e81af2..a59e108bd5263 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -108,6 +108,22 @@ struct ArithBuilder {
OpBuilder &b;
Location loc;
};
+
+/// Holds the quotient and remainder values produced by `getDivMod`.
+struct DivModValue {
+ Value quotient;
+ Value remainder;
+};
+
+/// Create IR to calculate (div a, b) and (mod a, b). The operands are treated
+/// as unsigned: the IR uses `arith.divui` and `arith.remui`.
+DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
+
+/// Generate the IR to delinearize `linearIndex` given the basis `dimSizes`
+/// (last element fastest varying, first element slowest varying) and return
+/// the multi-index. Returns failure if the multi-index cannot be computed.
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+ Value linearIndex,
+ ArrayRef<Value> dimSizes);
+
} // namespace mlir
#endif // MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 939c969d95cae..efec55c125578 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
@@ -2040,6 +2041,35 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a delinearize_index op from a mixed static/dynamic basis. One
+/// `index` result is created per basis element. Dynamic (Value) basis
+/// entries are used directly as operands. Static (attribute) entries are
+/// materialized as `arith.constant` index ops at the op's location.
+void arith::DelinearizeIndexOp::build(OpBuilder &builder,
+                                      OperationState &result,
+                                      Value linear_index,
+                                      ArrayRef<OpFoldResult> basis) {
+  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
+  result.addOperands(linear_index);
+  SmallVector<Value> basisValues =
+      llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value {
+        // Forward SSA values untouched: re-materializing a Value that is
+        // already defined by a constant would only create a duplicate op.
+        if (auto value = ofr.dyn_cast<Value>())
+          return value;
+        Optional<int64_t> staticDim = getConstantIntValue(ofr);
+        // A non-integer attribute cannot describe a basis size; previously
+        // this silently produced a null operand.
+        assert(staticDim &&
+               "static basis elements must be integer attributes");
+        return builder.create<arith::ConstantIndexOp>(result.location,
+                                                      *staticDim);
+      }));
+  result.addOperands(basisValues);
+}
+
+/// Verifies that the op has a non-empty basis and produces exactly one index
+/// result per basis element.
+LogicalResult arith::DelinearizeIndexOp::verify() {
+  size_t basisSize = getBasis().size();
+  if (basisSize == 0)
+    return emitOpError("basis should not be empty");
+  if (basisSize != getNumResults())
+    return emitOpError("should return an index for each basis element");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
index e23504bee3302..5ca47d98cc5b3 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRArithmeticDialect
LINK_LIBS PUBLIC
MLIRDialect
+ MLIRDialectUtils
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
index f140715e603ee..eead11d11557d 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRArithmeticDialect
+ MLIRArithmeticUtils
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRInferIntRangeInterface
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index afe7aab99af30..7c48b6f639f81 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -9,6 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -188,6 +189,23 @@ struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
}
};
+/// Lowers `arith.delinearize_index` into a sequence of division and remainder
+/// operations.
+struct LowerDelinearizeIndexOps
+ : public OpRewritePattern<arith::DelinearizeIndexOp> {
+ using OpRewritePattern<arith::DelinearizeIndexOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::DelinearizeIndexOp op,
+ PatternRewriter &rewriter) const override {
+ FailureOr<SmallVector<Value>> multiIndex =
+ delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+ llvm::to_vector(op.getBasis()));
+ if (failed(multiIndex))
+ return failure();
+ rewriter.replaceOp(op, *multiIndex);
+ return success();
+ }
+};
+
struct ArithmeticExpandOpsPass
: public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
void runOnOperation() override {
@@ -207,7 +225,8 @@ struct ArithmeticExpandOpsPass
arith::MaxUIOp,
arith::MinFOp,
arith::MinSIOp,
- arith::MinUIOp
+ arith::MinUIOp,
+ arith::DelinearizeIndexOp
>();
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
@@ -230,7 +249,8 @@ void mlir::arith::populateArithmeticExpandOpsPatterns(
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
- MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
+ MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
+ LowerDelinearizeIndexOps
>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
index b568891df66a1..40588869a5880 100644
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
@@ -115,3 +116,47 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
}
+
+/// Emits an `arith.divui` followed by an `arith.remui` of `lhs` by `rhs` and
+/// returns both results.
+DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) {
+  Value quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
+  Value remainder = b.create<arith::RemUIOp>(loc, lhs, rhs);
+  return DivModValue{quotient, remainder};
+}
+
+/// Create IR that computes the product of all elements in the set.
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+ ArrayRef<Value> set) {
+ if (set.empty())
+ return failure();
+ OpFoldResult result = set[0];
+ for (unsigned i = 1; i < set.size(); i++)
+ result = b.createOrFold<arith::MulIOp>(
+ loc, getValueOrCreateConstantIndexOp(b, loc, result), set[i]);
+ return result;
+}
+
+/// Emits the unsigned div/rem chain that splits `linearIndex` into one index
+/// per element of `dimSizes` (last element fastest varying). Result `i` is
+/// the quotient of the running remainder by the product of
+/// `dimSizes[i+1..]`, and the last result is the final remainder.
+/// `dimSizes[0]` is never read, so the first result is not reduced modulo it.
+/// Returns failure when `dimSizes` is empty, since there is no multi-index to
+/// produce.
+FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc,
+                                                     Value linearIndex,
+                                                     ArrayRef<Value> dimSizes) {
+  unsigned numDims = dimSizes.size();
+  // Without this check an empty basis would yield a single result
+  // (`linearIndex` itself), breaking the one-result-per-dimension contract.
+  if (numDims == 0)
+    return failure();
+
+  // divisors[i - 1] is the product of dimSizes[i..numDims-1], i.e. the
+  // stride of dimension i - 1 in the linearized index.
+  SmallVector<Value> divisors;
+  divisors.reserve(numDims - 1);
+  for (unsigned i = 1; i < numDims; ++i) {
+    FailureOr<OpFoldResult> prod =
+        getIndexProduct(b, loc, dimSizes.drop_front(i));
+    if (failed(prod))
+      return failure();
+    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+  }
+
+  // Peel off the slowest-varying index first; the remainder carries the
+  // faster-varying part forward.
+  SmallVector<Value> results;
+  results.reserve(numDims);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}
diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index e0990fdec5656..1da12d369669e 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -230,3 +230,24 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
}
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
+
+// -----
+
+// Static basis (16, 224, 224): the strides are 224 * 224 = 50176 (folded to
+// a constant) and 224. The first basis element (16) is not used by the
+// expansion.
+// CHECK-LABEL: @static_basis
+// CHECK-SAME: (%[[IDX:.+]]: index)
+// CHECK: arith.constant
+// CHECK: arith.constant
+// CHECK-DAG: %[[c224:.+]] = arith.constant 224 : index
+// CHECK-DAG: %[[c50176:.+]] = arith.constant 50176 : index
+// CHECK: %[[N:.+]] = arith.divui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[RES:.+]] = arith.remui %[[IDX]], %[[c50176]] : index
+// CHECK: %[[P:.+]] = arith.divui %[[RES]], %[[c224]] : index
+// CHECK: %[[Q:.+]] = arith.remui %[[RES]], %[[c224]] : index
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+ %b0 = arith.constant 16 : index
+ %b1 = arith.constant 224 : index
+ %b2 = arith.constant 224 : index
+ %1:3 = arith.delinearize_index %linear_index (%b0, %b1, %b2) : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 19c427b5e744f..201b75784cb4b 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -721,3 +721,19 @@ func.func @func() {
%x = arith.constant 1 : i32
}
+
+// -----
+
+// A two-element basis with a single result must be rejected by the verifier.
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
+  // expected-error@+1 {{'arith.delinearize_index' op should return an index for each basis element}}
+  %1 = arith.delinearize_index %idx (%basis0, %basis1) : index
+  return
+}
+
+// -----
+
+// An empty basis must be rejected by the verifier.
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
+  // expected-error@+1 {{'arith.delinearize_index' op basis should not be empty}}
+  arith.delinearize_index %idx () : index
+  return
+}
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index 359879f28f2ad..2a8df86122335 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -952,3 +952,9 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%min_unsigned = arith.minui %i1, %i2 : i32
return
}
+
+// Round-trip: the op must parse and print back with one result per basis
+// element.
+// CHECK-LABEL: func @delinearize
+// CHECK: arith.delinearize_index %{{.*}} : index, index
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) -> (index, index) {
+  %1:2 = arith.delinearize_index %idx (%basis0, %basis1) : index, index
+  return %1#0, %1#1 : index, index
+}
More information about the Mlir-commits
mailing list