[Mlir-commits] [mlir] cdf8f06 - [mlir] Add a folder for lbs, ubs, steps of scf.forall.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Feb 17 02:05:31 PST 2023
Author: Alexander Belyaev
Date: 2023-02-17T11:01:13+01:00
New Revision: cdf8f064694c37d9f89cfe24203efdc4804a00cc
URL: https://github.com/llvm/llvm-project/commit/cdf8f064694c37d9f89cfe24203efdc4804a00cc
DIFF: https://github.com/llvm/llvm-project/commit/cdf8f064694c37d9f89cfe24203efdc4804a00cc.diff
LOG: [mlir] Add a folder for lbs, ubs, steps of scf.forall.
Differential Revision: https://reviews.llvm.org/D144245
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Utils/Utils.h
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 8b8c000b7ae2e..4bed6e5016f80 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -28,9 +28,16 @@ detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
/// Detects the `values` produced by a ConstantIndexOp and places the new
/// constant in place of the corresponding sentinel value.
+/// TODO(pifon2a): Remove this function and use foldDynamicIndexList.
void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
function_ref<bool(int64_t)> isDynamic);
+/// Replaces each element of `ofrs` that is a Value produced by
+/// arith::ConstantIndexOp with the equivalent index attribute. Returns
+/// `success` if at least one element was replaced and `failure` otherwise.
+LogicalResult foldDynamicIndexList(Builder &b,
+ SmallVectorImpl<OpFoldResult> &ofrs);
+
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
ArrayRef<int64_t> shape);
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index cf9fdc232e1a8..4d8b5adad7b7c 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -23,8 +23,8 @@ detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
return detail::op_matcher<arith::ConstantIndexOp>();
}
-/// Detects the `values` produced by a ConstantIndexOp and places the new
-/// constant in place of the corresponding sentinel value.
+// Detects the `values` produced by a ConstantIndexOp and places the new
+// constant in place of the corresponding sentinel value.
void mlir::canonicalizeSubViewPart(
SmallVectorImpl<OpFoldResult> &values,
llvm::function_ref<bool(int64_t)> isDynamic) {
@@ -38,6 +38,25 @@ void mlir::canonicalizeSubViewPart(
}
}
+// Walks `ofrs` and turns every Value that is defined by an
+// arith::ConstantIndexOp into the matching index attribute (built with `b`).
+// Entries that are already attributes, or Values with any other producer, are
+// left untouched. Reports `success` iff at least one entry was rewritten.
+LogicalResult mlir::foldDynamicIndexList(Builder &b,
+                                         SmallVectorImpl<OpFoldResult> &ofrs) {
+  bool anyFolded = false;
+  for (OpFoldResult &entry : ofrs) {
+    // Attribute entries are already static; there is nothing to fold.
+    Value dynamicValue = entry.dyn_cast<Value>();
+    if (!dynamicValue)
+      continue;
+    auto constantOp = dynamicValue.getDefiningOp<arith::ConstantIndexOp>();
+    if (!constantOp)
+      continue;
+    // Newly static: swap the SSA value for its constant index attribute.
+    entry = b.getIndexAttr(constantOp.value());
+    anyFolded = true;
+  }
+  return success(anyFolded);
+}
+
llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
ArrayRef<int64_t> shape) {
llvm::SmallBitVector dimsToProject(shape.size());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index f1d07c8a5c56f..4415136ed1a43 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1430,11 +1430,45 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
return success();
}
};
+
+/// Canonicalization pattern for scf.forall that moves lower bounds, upper
+/// bounds and steps defined by arith::ConstantIndexOp out of the dynamic
+/// operand lists and into the corresponding static attribute lists.
+class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
+public:
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
+    SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
+    SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
+
+    // Fold all three lists unconditionally. Chaining the calls with `&&`
+    // would short-circuit after the first successful fold and leave the
+    // remaining lists for a later application of this pattern.
+    bool changed = succeeded(foldDynamicIndexList(rewriter, mixedLowerBound));
+    changed |= succeeded(foldDynamicIndexList(rewriter, mixedUpperBound));
+    changed |= succeeded(foldDynamicIndexList(rewriter, mixedStep));
+    if (!changed)
+      return failure();
+
+    // `op` is modified in place (operands and attributes), so the change must
+    // go through the rewriter; otherwise the pattern driver is not notified.
+    rewriter.updateRootInPlace(op, [&]() {
+      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+
+      dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
+                                 staticLowerBound);
+      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+      op.setStaticLowerBound(staticLowerBound);
+
+      dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
+                                 staticUpperBound);
+      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+      op.setStaticUpperBound(staticUpperBound);
+
+      dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
+      op.getDynamicStepMutable().assign(dynamicStep);
+      op.setStaticStep(staticStep);
+    });
+    return success();
+  }
+};
+
} // namespace
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfForallOp>(context);
+  // DimOfForallOp is rooted at tensor::DimOp; ForallOpControlOperandsFolder
+  // folds constant lower bounds, upper bounds and steps into static attributes.
+ results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 4836fae4f757e..c211596db7445 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1497,3 +1497,28 @@ func.func @canonicalize_parallel_insert_slice_indices(
// CHECK: return %[[dim]]
return %dim : index
}
+
+// -----
+
+// CHECK-LABEL: func @forall_fold_control_operands
+// The lower bounds (%c0) and steps (%c1) are index constants, and %dim1 is the
+// size of the static dimension of %arg0 (10). After canonicalization these are
+// expected to become static, so the op prints in the `in (...)` form checked
+// below, with 10 as the second upper bound.
+func.func @forall_fold_control_operands(
+ %arg0 : tensor<?x10xf32>, %arg1: tensor<?x10xf32>) -> tensor<?x10xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+ %dim1 = tensor.dim %arg0, %c1 : tensor<?x10xf32>
+
+  // Control operands are all SSA values here; only %dim0 should stay dynamic.
+ %result = scf.forall (%i, %j) = (%c0, %c0) to (%dim0, %dim1)
+ step (%c1, %c1) shared_outs(%o = %arg1) -> (tensor<?x10xf32>) {
+ %slice = tensor.extract_slice %arg1[%i, %j] [1, 1] [1, 1]
+ : tensor<?x10xf32> to tensor<1x1xf32>
+
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %slice into %o[%i, %j] [1, 1] [1, 1]
+ : tensor<1x1xf32> into tensor<?x10xf32>
+ }
+ }
+
+ return %result : tensor<?x10xf32>
+}
+// CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 7726cadd32cc3..6a7bd40e3bbee 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -658,7 +658,7 @@ func.func @scf_foreach_private_var(%t: tensor<10xf32>) -> f32 {
// CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32>
// CHECK: memref.copy %[[t]], %[[t_copy]]
- // CHECK: scf.forall (%{{.*}}) in (%{{.*}}) {
+ // CHECK: scf.forall (%{{.*}}) in (2) {
// Load from the copy and store into the shared output.
// CHECK: %[[subview:.*]] = memref.subview %[[t]]
More information about the Mlir-commits
mailing list