[Mlir-commits] [mlir] 3e07b0b - [MLIR] Fix lowering of affine operations with return values
Uday Bondhugula
llvmlistbot at llvm.org
Tue Dec 22 08:16:21 PST 2020
Author: Prateek Gupta
Date: 2020-12-22T21:44:31+05:30
New Revision: 3e07b0b9d3363fb767cbbaa2593fa91ac393fb7e
URL: https://github.com/llvm/llvm-project/commit/3e07b0b9d3363fb767cbbaa2593fa91ac393fb7e
DIFF: https://github.com/llvm/llvm-project/commit/3e07b0b9d3363fb767cbbaa2593fa91ac393fb7e.diff
LOG: [MLIR] Fix lowering of affine operations with return values
This commit addresses the issue of lowering affine.for and
affine.parallel having return values. Relevant test cases are also
added.
Signed-off-by: Prateek Gupta <prateek at polymagelabs.com>
Differential Revision: https://reviews.llvm.org/D93090
Added:
Modified:
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/test/Conversion/AffineToStandard/lower-affine.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 58f44b6ed207..8721e6b96ed7 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -334,7 +334,13 @@ class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
LogicalResult matchAndRewrite(AffineYieldOp op,
PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
+ if (isa<scf::ParallelOp>(op.getParentOp())) {
+ // scf.parallel does not yield any values via its terminator scf.yield but
+ // models reductions differently using additional ops in its region.
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
+ return success();
+ }
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.operands());
return success();
}
};
@@ -349,14 +355,55 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
Value lowerBound = lowerAffineLowerBound(op, rewriter);
Value upperBound = lowerAffineUpperBound(op, rewriter);
Value step = rewriter.create<ConstantIndexOp>(loc, op.getStep());
- auto f = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
- rewriter.eraseBlock(f.getBody());
- rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
- rewriter.eraseOp(op);
+ auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
+ step, op.getIterOperands());
+ rewriter.eraseBlock(scfForOp.getBody());
+ rewriter.inlineRegionBefore(op.region(), scfForOp.region(),
+ scfForOp.region().end());
+ rewriter.replaceOp(op, scfForOp.results());
return success();
}
};
+/// Returns the identity value associated with an AtomicRMWKind op.
+static Value getIdentityValue(AtomicRMWKind op, OpBuilder &builder,
+ Location loc) {
+ switch (op) {
+ case AtomicRMWKind::addf:
+ return builder.create<ConstantOp>(loc, builder.getF32FloatAttr(0));
+ case AtomicRMWKind::addi:
+ return builder.create<ConstantOp>(loc, builder.getI32IntegerAttr(0));
+ case AtomicRMWKind::mulf:
+ return builder.create<ConstantOp>(loc, builder.getF32FloatAttr(1));
+ case AtomicRMWKind::muli:
+ return builder.create<ConstantOp>(loc, builder.getI32IntegerAttr(1));
+ // TODO: Add remaining reduction operations.
+ default:
+ emitOptionalError(loc, "Reduction operation type not supported");
+ }
+ return nullptr;
+}
+
+/// Return the value obtained by applying the reduction operation kind
+/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
+static Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
+ Value lhs, Value rhs) {
+ switch (op) {
+ case AtomicRMWKind::addf:
+ return builder.create<AddFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::addi:
+ return builder.create<AddIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::mulf:
+ return builder.create<MulFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::muli:
+ return builder.create<MulIOp>(loc, lhs, rhs);
+ // TODO: Add remaining reduction operations.
+ default:
+ emitOptionalError(loc, "Reduction operation type not supported");
+ }
+ return nullptr;
+}
+
/// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
/// operation.
class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
@@ -369,12 +416,13 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
SmallVector<Value, 8> steps;
SmallVector<Value, 8> upperBoundTuple;
SmallVector<Value, 8> lowerBoundTuple;
+ SmallVector<Value, 8> identityVals;
// Finding lower and upper bound by expanding the map expression.
// Checking if expandAffineMap is not giving NULL.
- Optional<SmallVector<Value, 8>> upperBound = expandAffineMap(
- rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands());
Optional<SmallVector<Value, 8>> lowerBound = expandAffineMap(
rewriter, loc, op.lowerBoundsMap(), op.getLowerBoundsOperands());
+ Optional<SmallVector<Value, 8>> upperBound = expandAffineMap(
+ rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands());
if (!lowerBound || !upperBound)
return failure();
upperBoundTuple = *upperBound;
@@ -383,13 +431,62 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
for (Attribute step : op.steps())
steps.push_back(rewriter.create<ConstantIndexOp>(
loc, step.cast<IntegerAttr>().getInt()));
- // Creating empty scf.parallel op body with appropriate bounds.
- auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
- upperBoundTuple, steps);
- rewriter.eraseBlock(parallelOp.getBody());
- rewriter.inlineRegionBefore(op.region(), parallelOp.region(),
- parallelOp.region().end());
- rewriter.eraseOp(op);
+ // Get the terminator op.
+ Operation *affineParOpTerminator = op.getBody()->getTerminator();
+ scf::ParallelOp parOp;
+ if (op.results().empty()) {
+ // Case with no reduction operations/return values.
+ parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
+ upperBoundTuple, steps,
+ /*bodyBuilderFn=*/nullptr);
+ rewriter.eraseBlock(parOp.getBody());
+ rewriter.inlineRegionBefore(op.region(), parOp.region(),
+ parOp.region().end());
+ rewriter.replaceOp(op, parOp.results());
+ return success();
+ }
+ // Case with affine.parallel with reduction operations/return values.
+ // scf.parallel handles the reduction operation differently unlike
+ // affine.parallel.
+ ArrayRef<Attribute> reductions = op.reductions().getValue();
+ for (Attribute reduction : reductions) {
+ // For each of the reduction operations get the identity values for
+ // initialization of the result values.
+ Optional<AtomicRMWKind> reductionOp = symbolizeAtomicRMWKind(
+ static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
+ assert(reductionOp.hasValue() &&
+ "Reduction operation cannot be of None Type");
+ AtomicRMWKind reductionOpValue = reductionOp.getValue();
+ identityVals.push_back(getIdentityValue(reductionOpValue, rewriter, loc));
+ }
+ parOp = rewriter.create<scf::ParallelOp>(
+ loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
+ /*bodyBuilderFn=*/nullptr);
+
+ // Copy the body of the affine.parallel op.
+ rewriter.eraseBlock(parOp.getBody());
+ rewriter.inlineRegionBefore(op.region(), parOp.region(),
+ parOp.region().end());
+ assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
+ "Unequal number of reductions and operands.");
+ for (unsigned i = 0, end = reductions.size(); i < end; i++) {
+ // For each of the reduction operations get the respective mlir::Value.
+ Optional<AtomicRMWKind> reductionOp =
+ symbolizeAtomicRMWKind(reductions[i].cast<IntegerAttr>().getInt());
+ assert(reductionOp.hasValue() &&
+ "Reduction Operation cannot be of None Type");
+ AtomicRMWKind reductionOpValue = reductionOp.getValue();
+ rewriter.setInsertionPoint(&parOp.getBody()->back());
+ auto reduceOp = rewriter.create<scf::ReduceOp>(
+ loc, affineParOpTerminator->getOperand(i));
+ rewriter.setInsertionPointToEnd(&reduceOp.reductionOperator().front());
+ Value reductionResult =
+ getReductionOp(reductionOpValue, rewriter, loc,
+ reduceOp.reductionOperator().front().getArgument(0),
+ reduceOp.reductionOperator().front().getArgument(1));
+ rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
+ }
+ rewriter.replaceOp(op, parOp.results());
return success();
}
};
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index f89d913cb64a..38d269913e51 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -26,6 +26,30 @@ func @simple_loop() {
/////////////////////////////////////////////////////////////////////
+func @for_with_yield(%buffer: memref<1024xf32>) -> (f32) {
+ %sum_0 = constant 0.0 : f32
+ %sum = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32) {
+ %t = affine.load %buffer[%i] : memref<1024xf32>
+ %sum_next = addf %sum_iter, %t : f32
+ affine.yield %sum_next : f32
+ }
+ return %sum : f32
+}
+
+// CHECK-LABEL: func @for_with_yield
+// CHECK: %[[INIT_SUM:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[LOWER:.*]] = constant 0 : index
+// CHECK-NEXT: %[[UPPER:.*]] = constant 10 : index
+// CHECK-NEXT: %[[STEP:.*]] = constant 2 : index
+// CHECK-NEXT: %[[SUM:.*]] = scf.for %[[IV:.*]] = %[[LOWER]] to %[[UPPER]] step %[[STEP]] iter_args(%[[SUM_ITER:.*]] = %[[INIT_SUM]]) -> (f32) {
+// CHECK-NEXT: load
+// CHECK-NEXT: %[[SUM_NEXT:.*]] = addf
+// CHECK-NEXT: scf.yield %[[SUM_NEXT]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[SUM]] : f32
+
+/////////////////////////////////////////////////////////////////////
+
func private @pre(index) -> ()
func private @body2(index, index) -> ()
func private @post(index) -> ()
@@ -674,3 +698,104 @@ func @affine_parallel_tiled(%o: memref<100x100xf32>, %a: memref<100x100xf32>, %b
// CHECK: %[[A4:.*]] = load %[[ARG2]][%[[arg8]], %[[arg7]]] : memref<100x100xf32>
// CHECK: mulf %[[A3]], %[[A4]] : f32
// CHECK: scf.yield
+
+/////////////////////////////////////////////////////////////////////
+
+func @affine_parallel_simple(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) -> (memref<3x3xf32>) {
+ %O = alloc() : memref<3x3xf32>
+ affine.parallel (%kx, %ky) = (0, 0) to (2, 2) {
+ %1 = affine.load %arg0[%kx, %ky] : memref<3x3xf32>
+ %2 = affine.load %arg1[%kx, %ky] : memref<3x3xf32>
+ %3 = mulf %1, %2 : f32
+ affine.store %3, %O[%kx, %ky] : memref<3x3xf32>
+ }
+ return %O : memref<3x3xf32>
+}
+// CHECK-LABEL: func @affine_parallel_simple
+// CHECK: %[[LOWER_1:.*]] = constant 0 : index
+// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index
+// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index
+// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index
+// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index
+// CHECK-NEXT: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step (%[[STEP_1]], %[[STEP_2]]) {
+// CHECK-NEXT: %[[VAL_1:.*]] = load
+// CHECK-NEXT: %[[VAL_2:.*]] = load
+// CHECK-NEXT: %[[PRODUCT:.*]] = mulf
+// CHECK-NEXT: store
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+/////////////////////////////////////////////////////////////////////
+
+func @affine_parallel_simple_dynamic_bounds(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ %c_0 = constant 0 : index
+ %output_dim = dim %arg0, %c_0 : memref<?x?xf32>
+ affine.parallel (%kx, %ky) = (%c_0, %c_0) to (%output_dim, %output_dim) {
+ %1 = affine.load %arg0[%kx, %ky] : memref<?x?xf32>
+ %2 = affine.load %arg1[%kx, %ky] : memref<?x?xf32>
+ %3 = mulf %1, %2 : f32
+ affine.store %3, %arg2[%kx, %ky] : memref<?x?xf32>
+ }
+ return
+}
+// CHECK-LABEL: func @affine_parallel_simple_dynamic_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xf32>, %[[ARG_1:.*]]: memref<?x?xf32>, %[[ARG_2:.*]]: memref<?x?xf32>
+// CHECK: %[[DIM_INDEX:.*]] = constant 0 : index
+// CHECK-NEXT: %[[UPPER:.*]] = dim %[[ARG_0]], %[[DIM_INDEX]] : memref<?x?xf32>
+// CHECK-NEXT: %[[LOWER_1:.*]] = constant 0 : index
+// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index
+// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index
+// CHECK-NEXT: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER]], %[[UPPER]]) step (%[[STEP_1]], %[[STEP_2]]) {
+// CHECK-NEXT: %[[VAL_1:.*]] = load
+// CHECK-NEXT: %[[VAL_2:.*]] = load
+// CHECK-NEXT: %[[PRODUCT:.*]] = mulf
+// CHECK-NEXT: store
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+/////////////////////////////////////////////////////////////////////
+
+func @affine_parallel_with_reductions(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) -> (f32, f32) {
+ %0:2 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addf", "mulf") -> (f32, f32) {
+ %1 = affine.load %arg0[%kx, %ky] : memref<3x3xf32>
+ %2 = affine.load %arg1[%kx, %ky] : memref<3x3xf32>
+ %3 = mulf %1, %2 : f32
+ %4 = addf %1, %2 : f32
+ affine.yield %3, %4 : f32, f32
+ }
+ return %0#0, %0#1 : f32, f32
+}
+// CHECK-LABEL: func @affine_parallel_with_reductions
+// CHECK: %[[LOWER_1:.*]] = constant 0 : index
+// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index
+// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index
+// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index
+// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index
+// CHECK-NEXT: %[[INIT_1:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[INIT_2:.*]] = constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step (%[[STEP_1]], %[[STEP_2]]) init (%[[INIT_1]], %[[INIT_2]]) -> (f32, f32) {
+// CHECK-NEXT: %[[VAL_1:.*]] = load
+// CHECK-NEXT: %[[VAL_2:.*]] = load
+// CHECK-NEXT: %[[PRODUCT:.*]] = mulf
+// CHECK-NEXT: %[[SUM:.*]] = addf
+// CHECK-NEXT: scf.reduce(%[[PRODUCT]]) : f32 {
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK-NEXT: %[[RES:.*]] = addf
+// CHECK-NEXT: scf.reduce.return %[[RES]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.reduce(%[[SUM]]) : f32 {
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK-NEXT: %[[RES:.*]] = mulf
+// CHECK-NEXT: scf.reduce.return %[[RES]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
More information about the Mlir-commits mailing list