[Mlir-commits] [mlir] 20ae22b - [mlir][Affine] Add affine.delinearize_index operation
Christopher Bate
llvmlistbot at llvm.org
Fri Aug 19 08:24:23 PDT 2022
Author: Christopher Bate
Date: 2022-08-19T09:24:14-06:00
New Revision: 20ae22ba33e47ab4c9cbbde074290f9b0d8d85e2
URL: https://github.com/llvm/llvm-project/commit/20ae22ba33e47ab4c9cbbde074290f9b0d8d85e2
DIFF: https://github.com/llvm/llvm-project/commit/20ae22ba33e47ab4c9cbbde074290f9b0d8d85e2.diff
LOG: [mlir][Affine] Add affine.delinearize_index operation
This change adds a new AffineDelinearizeIndexOp to the affine dialect.
The operation accepts an index value as well as a basis (array of index
values) representing how the index should be decomposed into a
multi-index. The decomposition obeys a canonical semantic that treats
the final basis element as "fastest varying" and the first basis element
as "slowest varying". A naive lowering of the operation into a sequence
of AffineApplyOps is also provided (the new `-affine-expand-index-ops` pass).
RFC was discussed on discourse here: https://discourse.llvm.org/t/rfc-tensor-extracting-slices-from-tensor-collapse-shape/64034
Reviewed By: bondhugula, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D131997
Added:
mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/Affine/Passes.h
mlir/include/mlir/Dialect/Affine/Passes.td
mlir/include/mlir/Dialect/Affine/Utils.h
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Affine/IR/CMakeLists.txt
mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/test/Dialect/Affine/invalid.mlir
mlir/test/Dialect/Affine/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index e46e1bd08dee9..0c7a832414074 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1061,4 +1061,47 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// AffineDelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
+    [NoSideEffect]> {
+  let summary = "delinearize an index";
+  let description = [{
+    The `affine.delinearize_index` operation takes a single index value and
+    calculates the multi-index according to the given basis.
+
+    The basis is interpreted so that the last basis element is the
+    fastest-varying dimension and the first basis element is the
+    slowest-varying one. The basis must not be empty, and the operation
+    returns exactly one index per basis element.
+
+    Example:
+
+    ```
+    %indices:3 = affine.delinearize_index %linear_index into (%c16, %c224, %c224) : index, index, index
+    ```
+
+    In the above example, `%indices:3` conceptually holds the following:
+
+    ```
+    #map0 = affine_map<()[s0] -> (s0 floordiv 50176)>
+    #map1 = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)>
+    #map2 = affine_map<()[s0] -> (s0 mod 224)>
+    %indices_0 = affine.apply #map0()[%linear_index]
+    %indices_1 = affine.apply #map1()[%linear_index]
+    %indices_2 = affine.apply #map2()[%linear_index]
+    ```
+
+    Here 50176 = 224 * 224 is the product of every basis element after the
+    first one. The first basis element does not appear in these maps.
+  }];
+
+  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let results = (outs Variadic<Index>:$multi_index);
+
+  let assemblyFormat = [{
+    $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index)
+  }];
+
+  // Builds the op from a basis that may mix integer attributes and SSA
+  // values; integer attributes are materialized as index constants.
+  let builders = [
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+  ];
+
+  // Checks that the basis is non-empty and matches the number of results.
+  let hasVerifier = 1;
+}
+
#endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 2e18a6fb7f3a1..bab315ecffde4 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -110,6 +110,14 @@ createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize);
/// Overload relying on pass options for initialization.
std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass();
+/// Populate patterns that expand affine index operations into more fundamental
+/// operations (not necessarily restricted to Affine dialect).
+void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
+
+/// Creates a pass to expand affine index operations into more fundamental
+/// operations (not necessarily restricted to Affine dialect).
+std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index d50c22569d56e..1f31bccf5f861 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -397,4 +397,9 @@ def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"
let constructor = "mlir::createSimplifyAffineStructuresPass()";
}
+// Expands affine index operations (currently `affine.delinearize_index`)
+// into more fundamental operations. No anchor op is given, so the pass can
+// run on any operation.
+def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
+ let summary = "Lower affine operations operating on indices into more fundamental operations";
+ let constructor = "mlir::createAffineExpandIndexOpsPass()";
+}
+
#endif // MLIR_DIALECT_AFFINE_PASSES
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 345f955e2061c..006c61ced2125 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -304,6 +304,21 @@ Optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
AffineMap affineMap,
ValueRange operands);
+/// Holds the result of (div a, b) and (mod a, b), as produced by `getDivMod`.
+struct DivModValue {
+ /// The affine `floordiv` of the two operands.
+ Value quotient;
+ /// The affine `mod` of the two operands.
+ Value remainder;
+};
+
+/// Create IR to calculate (div lhs, rhs) and (mod lhs, rhs), using the affine
+/// `floordiv` and `mod` operators.
+DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
+
+/// Generate the IR to delinearize `linearIndex` given the `basis` and return
+/// the multi-index. The last basis element is treated as the fastest-varying
+/// dimension. Returns failure if the multi-index cannot be computed.
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+ Value linearIndex,
+ ArrayRef<Value> basis);
+
} // namespace mlir
#endif // MLIR_DIALECT_AFFINE_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 766f41dc0771d..ae1a4a320a14d 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4036,6 +4036,34 @@ LogicalResult AffineVectorStoreOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// DelinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a delinearize_index op producing one index result per basis
+/// element. Each basis element may be an integer attribute or an SSA value;
+/// integer attributes are materialized as `arith.constant` index ops through
+/// `builder`.
+void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+                                     Value linearIndex,
+                                     ArrayRef<OpFoldResult> basis) {
+  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
+  result.addOperands(linearIndex);
+  SmallVector<Value> basisValues =
+      llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value {
+        Optional<int64_t> staticDim = getConstantIntValue(ofr);
+        if (staticDim.has_value())
+          return builder.create<arith::ConstantIndexOp>(result.location,
+                                                        *staticDim);
+        // Anything that is not an integer constant must already be a Value;
+        // otherwise (e.g. a non-integer attribute) dyn_cast yields a null
+        // Value that would silently become a null operand.
+        Value dynamicDim = ofr.dyn_cast<Value>();
+        assert(dynamicDim &&
+               "expected basis element to be an integer attribute or a Value");
+        return dynamicDim;
+      }));
+  result.addOperands(basisValues);
+}
+
+/// Verifies that the basis is non-empty and that the op yields exactly one
+/// result per basis element.
+LogicalResult AffineDelinearizeIndexOp::verify() {
+  auto basis = getBasis();
+  if (basis.empty())
+    return emitOpError("basis should not be empty");
+  if (basis.size() != getNumResults())
+    return emitOpError("should return an index for each basis element");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 5616e80d79fb0..e98c935b3f36e 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRAffineDialect
LINK_LIBS PUBLIC
MLIRArithmeticDialect
+ MLIRDialectUtils
MLIRIR
MLIRLoopLikeInterface
MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
new file mode 100644
index 0000000000000..c162aa2f2d058
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -0,0 +1,63 @@
+//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to expand affine index ops into one or more
+// simpler, more fundamental operations.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites `affine.delinearize_index` into the floordiv/mod IR emitted by
+/// `delinearizeIndex`, replacing each op result with the matching index.
+struct LowerDelinearizeIndexOps
+    : public OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> basis = llvm::to_vector(op.getBasis());
+    FailureOr<SmallVector<Value>> expanded =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), basis);
+    if (succeeded(expanded)) {
+      rewriter.replaceOp(op, *expanded);
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// Greedily applies the affine index-op expansion patterns to the operation
+/// the pass runs on, signalling failure if the driver reports failure.
+class ExpandAffineIndexOpsPass
+    : public AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
+public:
+  ExpandAffineIndexOpsPass() = default;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAffineExpandIndexOpsPatterns(patterns);
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Registers the pattern that lowers `affine.delinearize_index`.
+void mlir::populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<LowerDelinearizeIndexOps>(ctx);
+}
+
+/// Returns a new instance of the affine-expand-index-ops pass.
+std::unique_ptr<Pass> mlir::createAffineExpandIndexOpsPass() {
+  return std::make_unique<ExpandAffineIndexOpsPass>();
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 1a2b2dbb17b80..4601a11bf2894 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRAffineTransforms
AffineDataCopyGeneration.cpp
+ AffineExpandIndexOps.cpp
AffineLoopInvariantCodeMotion.cpp
AffineLoopNormalize.cpp
AffineParallelize.cpp
diff --git a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
index 3be71bd357982..fb26df43b688e 100644
--- a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRAffineUtils
MLIRAffineDialect
MLIRAffineAnalysis
MLIRAnalysis
+ MLIRArithmeticUtils
MLIRMemRefDialect
MLIRTransformUtils
)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index ae949b62a5279..66a0e3640aba6 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -1821,3 +1822,52 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
return newMemRefType;
}
+
+/// Emits composed `affine.apply` ops computing `lhs floordiv rhs` and
+/// `lhs mod rhs`, with both operands bound as affine dimensions.
+DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) {
+  AffineExpr dividend, divisor;
+  bindDims(b.getContext(), dividend, divisor);
+  Value quotient =
+      makeComposedAffineApply(b, loc, dividend.floorDiv(divisor), {lhs, rhs});
+  Value remainder =
+      makeComposedAffineApply(b, loc, dividend % divisor, {lhs, rhs});
+  return DivModValue{quotient, remainder};
+}
+
+/// Builds the product of every value in `set` with composed, folded affine
+/// applies (constants fold to attributes). Fails when `set` is empty.
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+                                               ArrayRef<Value> set) {
+  if (set.empty())
+    return failure();
+  AffineExpr lhs, rhs;
+  bindSymbols(b.getContext(), lhs, rhs);
+  AffineExpr mulExpr = lhs * rhs;
+  OpFoldResult product = set.front();
+  for (Value factor : set.drop_front())
+    product = makeComposedFoldedAffineApply(b, loc, mulExpr, {product, factor});
+  return product;
+}
+
+/// Computes the multi-index of `linearIndex` with respect to `basis`, treating
+/// the last basis element as the fastest-varying dimension.
+///
+/// Result 0 is `linearIndex floordiv prod(basis[1:])`. Each later result
+/// divides the previous remainder by the product of the basis elements after
+/// it, and the final result is the last remainder. The first basis element
+/// is therefore never read.
+///
+/// Returns failure for an empty basis, which has no multi-index.
+FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc,
+                                                     Value linearIndex,
+                                                     ArrayRef<Value> basis) {
+  unsigned numDims = basis.size();
+  // One index must be produced per basis element; an empty basis would
+  // otherwise yield the linear index itself as a single, unmatched result.
+  if (numDims == 0)
+    return failure();
+
+  // divisors[i] is the product of basis[i+1:], i.e. the stride of dim i.
+  SmallVector<Value> divisors;
+  divisors.reserve(numDims - 1);
+  for (unsigned i = 1; i < numDims; ++i) {
+    FailureOr<OpFoldResult> prod =
+        getIndexProduct(b, loc, basis.drop_front(i));
+    if (failed(prod))
+      return failure();
+    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+  }
+
+  SmallVector<Value> results;
+  results.reserve(numDims);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  // The remainder left after the last division is the fastest-varying index.
+  results.push_back(residual);
+  return results;
+}
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
new file mode 100644
index 0000000000000..70b7f397ad4fe
--- /dev/null
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -affine-expand-index-ops -split-input-file | FileCheck %s
+
+// Constant basis (16, 224, 224): the trailing-basis products fold into the
+// expanded maps (50176 = 224 * 224), and the leading element 16 does not
+// appear in them.
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)>
+
+// CHECK-LABEL: @static_basis
+// CHECK-SAME: (%[[IDX:.+]]: index)
+// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]]
+// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]]
+// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]]
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+ %b0 = arith.constant 16 : index
+ %b1 = arith.constant 224 : index
+ %b2 = arith.constant 224 : index
+ %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Dynamic basis taken from the memref dims: the trailing products stay
+// symbolic (s0 * s1), and only dims 1 and 2 are expected as map operands.
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)>
+
+// CHECK-LABEL: @dynamic_basis
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
+// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
+// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %b0 = memref.dim %src, %c0 : memref<?x?x?xf32>
+ %b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
+ %b2 = memref.dim %src, %c2 : memref<?x?x?xf32>
+ %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 03f4f124ddefa..866aa4062f3a6 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -485,3 +485,19 @@ func.func @missing_for_min(%arg0: index, %arg1: index, %arg2: memref<100xf32>) {
}
return
}
+
+// -----
+
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
+  // Two basis elements but only one result type.
+  // expected-error@+1 {{'affine.delinearize_index' op should return an index for each basis element}}
+  %1 = affine.delinearize_index %idx into (%basis0, %basis1) : index
+  return
+}
+
+// -----
+
+func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
+  // An empty basis is rejected before the result count is checked.
+  // expected-error@+1 {{'affine.delinearize_index' op basis should not be empty}}
+  affine.delinearize_index %idx into () : index
+  return
+}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index ad6f3651c1b21..df10163d59822 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -260,3 +260,12 @@ func.func @affine_for_multiple_yield(%buffer: memref<1024xf32>) -> (f32, f32) {
// CHECK-NEXT: %[[res2:.*]] = arith.addf %{{.*}}, %[[iter_arg2]] : f32
// CHECK-NEXT: affine.yield %[[res1]], %[[res2]] : f32, f32
// CHECK-NEXT: }
+
+// -----
+
+// Round-trips the custom assembly format of affine.delinearize_index.
+// CHECK-LABEL: func @delinearize
+func.func @delinearize(%linear_idx: index, %basis0: index, %basis1 :index) -> (index, index) {
+ // The printed form keeps the `into (...)` basis list and one type per result.
+ // CHECK: affine.delinearize_index %{{.+}} into (%{{.+}}, %{{.+}}) : index, index
+ %1:2 = affine.delinearize_index %linear_idx into (%basis0, %basis1) : index, index
+ return %1#0, %1#1 : index, index
+}
More information about the Mlir-commits
mailing list