[Mlir-commits] [mlir] cdf8f06 - [mlir] Add a folder for lbs, ubs, steps of scf.forall.

Alexander Belyaev llvmlistbot at llvm.org
Fri Feb 17 02:05:31 PST 2023


Author: Alexander Belyaev
Date: 2023-02-17T11:01:13+01:00
New Revision: cdf8f064694c37d9f89cfe24203efdc4804a00cc

URL: https://github.com/llvm/llvm-project/commit/cdf8f064694c37d9f89cfe24203efdc4804a00cc
DIFF: https://github.com/llvm/llvm-project/commit/cdf8f064694c37d9f89cfe24203efdc4804a00cc.diff

LOG: [mlir] Add a folder for lbs, ubs, steps of scf.forall.

Differential Revision: https://reviews.llvm.org/D144245

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/Utils/Utils.h
    mlir/lib/Dialect/Arith/Utils/Utils.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 8b8c000b7ae2e..4bed6e5016f80 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -28,9 +28,16 @@ detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
 
 /// Detects the `values` produced by a ConstantIndexOp and places the new
 /// constant in place of the corresponding sentinel value.
+/// TODO(pifon2a): Remove this function and use foldDynamicIndexList.
 void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
                              function_ref<bool(int64_t)> isDynamic);

+/// Replaces every Value in `ofrs` that is produced by an
+/// arith::ConstantIndexOp with the equivalent index attribute built with `b`.
+/// Returns `success` if at least one element was folded, `failure` otherwise.
+LogicalResult foldDynamicIndexList(Builder &b,
+                                   SmallVectorImpl<OpFoldResult> &ofrs);
+
 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
                                             ArrayRef<int64_t> shape);
 

diff  --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index cf9fdc232e1a8..4d8b5adad7b7c 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -23,8 +23,8 @@ detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
   return detail::op_matcher<arith::ConstantIndexOp>();
 }
 
-/// Detects the `values` produced by a ConstantIndexOp and places the new
-/// constant in place of the corresponding sentinel value.
+// Detects the `values` produced by a ConstantIndexOp and places the new
+// constant in place of the corresponding sentinel value.
 void mlir::canonicalizeSubViewPart(
     SmallVectorImpl<OpFoldResult> &values,
     llvm::function_ref<bool(int64_t)> isDynamic) {
@@ -38,6 +38,25 @@ void mlir::canonicalizeSubViewPart(
   }
 }
 
+// Scans `ofrs` and replaces every Value defined by an arith::ConstantIndexOp
+// with the equivalent index attribute created through `b`. Yields `success`
+// if at least one entry was replaced, `failure` if `ofrs` is unchanged.
+LogicalResult mlir::foldDynamicIndexList(Builder &b,
+                                         SmallVectorImpl<OpFoldResult> &ofrs) {
+  bool anyFolded = false;
+  for (OpFoldResult &entry : ofrs) {
+    Value dynamicValue = entry.dyn_cast<Value>();
+    if (!dynamicValue)
+      continue;
+    auto constant = dynamicValue.getDefiningOp<arith::ConstantIndexOp>();
+    if (!constant)
+      continue;
+    entry = b.getIndexAttr(constant.value());
+    anyFolded = true;
+  }
+  return success(anyFolded);
+}
+
 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
                                                   ArrayRef<int64_t> shape) {
   llvm::SmallBitVector dimsToProject(shape.size());

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index f1d07c8a5c56f..4415136ed1a43 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1430,11 +1430,57 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
+
+/// Canonicalization pattern that folds index operands of scf.forall that are
+/// produced by arith.constant (lower bounds, upper bounds and steps) into the
+/// op's static attributes.
+class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
+public:
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
+    SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
+    SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
+    // Fold all three lists unconditionally. Chaining the calls with `&&`
+    // short-circuits after the first successful fold and leaves the other
+    // lists for a later application of this pattern.
+    bool lowerBoundChanged =
+        succeeded(foldDynamicIndexList(rewriter, mixedLowerBound));
+    bool upperBoundChanged =
+        succeeded(foldDynamicIndexList(rewriter, mixedUpperBound));
+    bool stepChanged = succeeded(foldDynamicIndexList(rewriter, mixedStep));
+    if (!lowerBoundChanged && !upperBoundChanged && !stepChanged)
+      return failure();
+
+    SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+    SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+    dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
+                               staticLowerBound);
+    dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
+                               staticUpperBound);
+    dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
+
+    // `op` is modified in place, so the change must go through the rewriter;
+    // otherwise the pattern driver is not notified about the modification.
+    rewriter.updateRootInPlace(op, [&]() {
+      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+      op.setStaticLowerBound(staticLowerBound);
+      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+      op.setStaticUpperBound(staticUpperBound);
+      op.getDynamicStepMutable().assign(dynamicStep);
+      op.setStaticStep(staticStep);
+    });
+    return success();
+  }
+};
+
 } // namespace

 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp>(context);
+  results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
 }

 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 4836fae4f757e..c211596db7445 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1497,3 +1497,28 @@ func.func @canonicalize_parallel_insert_slice_indices(
   // CHECK: return %[[dim]]
   return %dim : index
 }
+
+// -----
+
+// CHECK-LABEL: func @forall_fold_control_operands
+func.func @forall_fold_control_operands(
+    %arg0 : tensor<?x10xf32>, %arg1: tensor<?x10xf32>) -> tensor<?x10xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x10xf32>
+  // Dimension 1 is static (10), so the upper bound %dim1 should fold to 10.
+  %result = scf.forall (%i, %j) = (%c0, %c0) to (%dim0, %dim1)
+      step (%c1, %c1) shared_outs(%o = %arg1) -> (tensor<?x10xf32>) {
+    %slice = tensor.extract_slice %arg1[%i, %j] [1, 1] [1, 1]
+      : tensor<?x10xf32> to tensor<1x1xf32>
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %o[%i, %j] [1, 1] [1, 1]
+        : tensor<1x1xf32> into tensor<?x10xf32>
+    }
+  }
+
+  return %result : tensor<?x10xf32>
+}
+// CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 7726cadd32cc3..6a7bd40e3bbee 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -658,7 +658,7 @@ func.func @scf_foreach_private_var(%t: tensor<10xf32>) -> f32 {
   // CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32>
   // CHECK: memref.copy %[[t]], %[[t_copy]]
 
-  // CHECK: scf.forall (%{{.*}}) in (%{{.*}}) {
+  // CHECK: scf.forall (%{{.*}}) in (2) {
 
   // Load from the copy and store into the shared output.
   // CHECK:   %[[subview:.*]] = memref.subview %[[t]]


        


More information about the Mlir-commits mailing list