[Mlir-commits] [mlir] 310deca - [mlir] Add loop bounds to scf.foreach_thread.

Alexander Belyaev llvmlistbot at llvm.org
Thu Feb 16 23:58:03 PST 2023


Author: Alexander Belyaev
Date: 2023-02-17T08:57:52+01:00
New Revision: 310deca248c84a9d8e529654437797327391fdc1

URL: https://github.com/llvm/llvm-project/commit/310deca248c84a9d8e529654437797327391fdc1
DIFF: https://github.com/llvm/llvm-project/commit/310deca248c84a9d8e529654437797327391fdc1.diff

LOG: [mlir] Add loop bounds to scf.foreach_thread.

https://discourse.llvm.org/t/rfc-parallel-loops-on-tensors-in-mlir/68332

Differential Revision: https://reviews.llvm.org/D144072

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
    mlir/test/Dialect/SCF/ops.mlir
    mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 24c37272d641e..6d627635bfe8e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -353,16 +353,16 @@ def ForOp : SCF_Op<"for",
 
 def ForeachThreadOp : SCF_Op<"foreach_thread", [
        AttrSizedOperandSegments,
-       SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
-       RecursiveMemoryEffects,
        AutomaticAllocationScope,
-      ]> {
+       RecursiveMemoryEffects,
+       SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
+     ]> {
   let summary = "evaluate a block multiple times in parallel";
   let description = [{
     `scf.foreach_thread` is a target-independent multi-dimensional parallel
     region application operation. It has exactly one block that represents the
-    parallel body and it takes index operands that indicate how many parallel
-    instances of that function are created.
+    parallel body and it takes index operands that specify lower bounds, upper
+    bounds and steps.
 
     The op also takes a variadic number of tensor operands (`shared_outs`).
     The future buffers corresponding to these tensors are shared among all
@@ -404,7 +404,12 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     When the parallel function body has side effects, their order is unspecified
     across threads.
 
-    Example:
+    `scf.foreach_thread` can be printed in two different ways depending on
+    whether the loop is normalized or not. The loop is 'normalized' when all
+    lower bounds are equal to zero and steps are equal to one. In that case,
+    `lowerBound` and `step` operands will be omitted during printing.
+
+    Normalized loop example:
 
     ```mlir
     //
@@ -442,6 +447,38 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     //
     ```
 
+    Loop with loop bounds example:
+
+    ```mlir
+    //
+    // Sequential context.
+    //
+    %pointwise = scf.foreach_thread (%i, %j) = (0, 0) to (%dim1, %dim2)
+      step (%tileSize1, %tileSize2) shared_outs(%o = %out)
+      -> (tensor<?x?xT>) {
+      //
+      // Parallel context.
+      //
+      %sA = tensor.extract_slice %A[%i, %j][%tileSize1, %tileSize2][1, 1]
+        : tensor<?x?xT> to tensor<?x?xT>
+      %sB = tensor.extract_slice %B[%i, %j][%tileSize1, %tileSize2][1, 1]
+        : tensor<?x?xT> to tensor<?x?xT>
+      %sC = tensor.extract_slice %o[%i, %j][%tileSize1, %tileSize2][1, 1]
+        : tensor<?x?xT> to tensor<?x?xT>
+
+      %add = map {"arith.addf"} ins(%sA, %sB) outs(%sC)
+
+      scf.foreach_thread.perform_concurrently {
+        scf.foreach_thread.parallel_insert_slice %add into
+          %o[%i, %j][%tileSize1, %tileSize2][1, 1]
+          : tensor<?x?xT> into tensor<?x?xT>
+      }
+    }
+    // Implicit synchronization point.
+    // Sequential context.
+    //
+    ```
+
     Example with mapping attribute:
 
     ```mlir
@@ -481,9 +518,15 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     }
     ```
   }];
-  let arguments = (ins Variadic<Index>:$num_threads,
-                       Variadic<AnyRankedTensor>:$outputs,
-                       OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+  let arguments = (ins
+    Variadic<Index>:$dynamicLowerBound,
+    Variadic<Index>:$dynamicUpperBound,
+    Variadic<Index>:$dynamicStep,
+    DenseI64ArrayAttr:$staticLowerBound,
+    DenseI64ArrayAttr:$staticUpperBound,
+    DenseI64ArrayAttr:$staticStep,
+    Variadic<AnyRankedTensor>:$outputs,
+    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -495,58 +538,114 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   // The default builder does not add the proper body BBargs, roll our own.
   let skipDefaultBuilders = 1;
   let builders = [
-    // Bodyless builder, outputs must be specified.
-    OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-                   "std::optional<ArrayAttr>":$mapping)>,
-    // Builder that takes a bodyBuilder lambda.
-    OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-                   "ArrayRef<Attribute>":$mapping,
-                   "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+    // Builder that takes loop bounds.
+    OpBuilder<(ins "ArrayRef<OpFoldResult>":$lbs,
+       "ArrayRef<OpFoldResult>":$ubs, "ArrayRef<OpFoldResult>":$steps,
+       "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
+       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
+            "nullptr"> :$bodyBuilderFn)>,
+
+    // Builder for normalized loop that takes only upper bounds.
+    OpBuilder<(ins "ArrayRef<OpFoldResult>":$ubs,
+       "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
+       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
+            "nullptr"> :$bodyBuilderFn)>,
   ];
+
   let extraClassDeclaration = [{
-    int64_t getRank() { return getNumThreads().size(); }
+    // Get lower bounds as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedLowerBound() {
+      Builder b(getOperation()->getContext());
+      return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
+    }
+
+    // Get upper bounds as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedUpperBound() {
+      Builder b(getOperation()->getContext());
+      return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
+    }
+
+    // Get steps as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedStep() {
+      Builder b(getOperation()->getContext());
+      return getMixedValues(getStaticStep(), getDynamicStep(), b);
+    }
+
+    /// Get lower bounds as values.
+    SmallVector<Value> getLowerBound(OpBuilder &b) {
+      return getAsValues(b, getLoc(), getMixedLowerBound());
+    }
+
+    /// Get upper bounds as values.
+    SmallVector<Value> getUpperBound(OpBuilder &b) {
+      return getAsValues(b, getLoc(), getMixedUpperBound());
+    }
+
+    /// Get steps as values.
+    SmallVector<Value> getStep(OpBuilder &b) {
+      return getAsValues(b, getLoc(), getMixedStep());
+    }
+
+    int64_t getRank() { return getStaticLowerBound().size(); }
+
+    /// Number of operands controlling the loop: lbs, ubs, steps
+    unsigned getNumControlOperands() { return 3 * getRank(); }
+
+    /// Number of dynamic operands controlling the loop: lbs, ubs, steps
+    unsigned getNumDynamicControlOperands() {
+      return getODSOperandIndexAndLength(3).first;
+    }
 
     OpResult getTiedOpResult(OpOperand *opOperand) {
-      assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
+      assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
+             "invalid operand");
       return getOperation()->getOpResult(
-          opOperand->getOperandNumber() - getRank());
+          opOperand->getOperandNumber() - getNumDynamicControlOperands());
     }
 
     /// Return the num_threads operand that is tied to the given thread id
     /// block argument.
     OpOperand *getTiedOpOperand(BlockArgument bbArg) {
       assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");
-      return &getOperation()->getOpOperand(bbArg.getArgNumber());
+
+      return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
+                                           bbArg.getArgNumber() - getRank());
     }
 
     /// Return the shared_outs operand that is tied to the given OpResult.
     OpOperand *getTiedOpOperand(OpResult opResult) {
       assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
-      return &getOperation()->getOpOperand(
-          opResult.getResultNumber() + getRank());
+      return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
+                                           opResult.getResultNumber());
     }
 
     BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
-      assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
-      return getBody()->getArgument(opOperand->getOperandNumber());
+      assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
+             "invalid operand");
+
+      return getBody()->getArgument(opOperand->getOperandNumber() -
+                                    getNumDynamicControlOperands() + getRank());
     }
 
     ArrayRef<BlockArgument> getOutputBlockArguments() {
       return getBody()->getArguments().drop_front(getRank());
     }
 
-    ::mlir::ValueRange getThreadIndices() {
+    ::mlir::ValueRange getInductionVars() {
       return getBody()->getArguments().take_front(getRank());
     }
 
-    ::mlir::Value getThreadIndex(int64_t idx) {
-      return getThreadIndices()[idx];
+    ::mlir::Value getInductionVar(int64_t idx) {
+      return getInductionVars()[idx];
     }
 
     ::mlir::Block::BlockArgListType getRegionOutArgs() {
       return getBody()->getArguments().drop_front(getRank());
     }
 
+    /// Checks if the lbs are zeros and steps are ones.
+    bool isNormalized();
+
     /// Helper to sort `values` according to matching `keys`.
     /// Take a custom `compare` binary comparator which returns true if the first
     /// element is smaller than the second (i.e. compatible with std::sort).
@@ -559,7 +658,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
     // well-formed. We override it here to ensure that we do the right thing.
-    static void ensureTerminator(Region &region, OpBuilder &builder, Location loc);
+    static void ensureTerminator(Region & region, OpBuilder & builder,
+                                 Location loc);
 
     PerformConcurrentlyOp getTerminator();
   }];

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index f950933b23c7a..5843ecd061df6 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -48,8 +48,10 @@ namespace mlir {
 /// in `integers` is `dynVal` or (2) the next value otherwise. This allows
 /// idiomatic printing of mixed value and integer attributes in a list. E.g.
 /// `[%arg0, 7, 42, %arg42]`.
-void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
-                           OperandRange values, ArrayRef<int64_t> integers);
+void printDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
 
 /// Pasrer hook for custom directive in assemblyFormat.
 ///
@@ -64,10 +66,11 @@ void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
 /// E.g. after parsing "[%arg0, 7, 42, %arg42]":
 ///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
 ///   2. `ssa` is filled with "[%arg0, %arg1]".
-ParseResult
-parseDynamicIndexList(OpAsmParser &parser,
-                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                      DenseI64ArrayAttr &integers);
+ParseResult parseDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
 
 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 90f04f3713cf3..e971b764c6f05 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -180,16 +180,19 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   // transform op does not apply to individual ForeachThreadOp.
   Location loc = foreachThreadOp->getLoc();
 
+  if (!foreachThreadOp.isNormalized())
+    return transformOp.emitSilenceableError()
+           << "unsupported non-normalized loops";
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
            << "only bufferized scf.foreach_thread lowers to "
               "gpu.block_id";
-  if (foreachThreadOp.getNumThreads().size() > 3)
+  if (foreachThreadOp.getRank() > 3)
     return transformOp.emitSilenceableError()
            << "scf.foreach_thread with rank > 3 does not lower to "
               "gpu.block_id";
-  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
-        return !v.getDefiningOp<arith::ConstantIndexOp>();
+  if (llvm::any_of(foreachThreadOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
       })) {
     return transformOp.emitSilenceableError()
            << "unsupported dynamic griddim size";
@@ -198,8 +201,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
 
   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
-  SmallVector<Value> numBlocks =
-      llvm::to_vector(foreachThreadOp.getNumThreads());
+  SmallVector<Value> numBlocks = foreachThreadOp.getUpperBound(rewriter);
   // Ensure we have 3 block sizes, one for each id.
   Value one;
   for (auto attr : mappingAttributes) {
@@ -227,7 +229,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   blockIdGenerator(rewriter, foreachThreadOp, blockOps);
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
-       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
+       llvm::zip(foreachThreadOp.getInductionVars(), blockMapping)) {
     bvm.map(blockIdx,
             blockOps[static_cast<int64_t>(
                 blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
@@ -243,7 +245,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
                                       sourceBlock.getOperations());
 
   // Step 5. RAUW thread indices to thread ops.
-  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+  for (Value loopIndex : foreachThreadOp.getInductionVars()) {
     Value blockIdx = bvm.lookup(loopIndex);
     rewriter.replaceAllUsesWith(loopIndex, blockIdx);
   }
@@ -381,14 +383,16 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     return emitDefiniteFailure(foreachThreadOp, message);
   };
   Location loc = foreachThreadOp->getLoc();
+  if (!foreachThreadOp.isNormalized())
+    return failureHelper("unsupported non-normalized loops");
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
-  if (foreachThreadOp.getNumThreads().size() > 3)
+  if (foreachThreadOp.getRank() > 3)
     return failureHelper(
         "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
-  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
-        return !v.getDefiningOp<arith::ConstantIndexOp>();
+  if (llvm::any_of(foreachThreadOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
       })) {
     return failureHelper("unsupported dynamic blockdim size");
   }
@@ -399,8 +403,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 
   // Step 1. Complete the threadMapping to a full mapping (with 1s) if
   // necessary.
-  SmallVector<Value> numThreads =
-      llvm::to_vector(foreachThreadOp.getNumThreads());
+  SmallVector<Value> numThreads = foreachThreadOp.getUpperBound(rewriter);
   // Ensure we have 3 block sizes, one for each id.
   Value one;
   for (auto attr : threadMappingAttributes) {
@@ -437,7 +440,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
   }
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
-       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
+       llvm::zip(foreachThreadOp.getInductionVars(), threadMapping)) {
     bvm.map(blockIdx,
             threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
                                  .getMappingId()]);
@@ -484,7 +487,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
                                       sourceBlock.getOperations());
 
   // Step 6. RAUW thread indices to thread ops.
-  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+  for (Value loopIndex : foreachThreadOp.getInductionVars()) {
     Value threadIdx = bvm.lookup(loopIndex);
     rewriter.replaceAllUsesWith(loopIndex, threadIdx);
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7303bd738906c..10a1451fb7cc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -253,7 +253,7 @@ static void calculateTileOffsetsAndSizes(
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(foreachThreadOp.getBody(0));
 
-  ValueRange threadIds = foreachThreadOp.getThreadIndices();
+  ValueRange threadIds = foreachThreadOp.getInductionVars();
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
         return !isConstantIntValue(ofr, 0);
@@ -360,7 +360,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
+      loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
 
   // 2. Fill out the ForeachThreadOp body.
   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
@@ -681,8 +681,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
 
   // 2. Create the ForeachThreadOp with an empty region.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, (*identityTensor)->getResults(),
-      ValueRange(materializedNonZeroNumThreads), mapping);
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads),
+      (*identityTensor)->getResults(), mapping);
 
   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
   // be nested under `foreachThreadOp`.
@@ -712,7 +712,7 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
                                            b.getIndexAttr(0));
       SmallVector<OpFoldResult> sizes = tiledSizes;
       sizes[reductionDim] = b.getIndexAttr(1);
-      outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+      outOffsets[reductionDim] = foreachThreadOp.getInductionVars().front();
       // TODO: use SubsetExtractOpInterface once it is available.
       tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
           loc, initOperand->get().getType().cast<RankedTensorType>(),
@@ -746,7 +746,7 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
       if (failed(maybeTiled))
         return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
 
-      SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+      SmallVector<Value> ids = foreachThreadOp.getInductionVars();
       mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
                             materializedNonZeroNumThreads);
       assert(maybeTiled->loops.size() == 1 &&
@@ -774,7 +774,7 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
     int64_t sizeIdx = 0;
     for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
       if (i == reductionDim) {
-        resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+        resultOffsetsRank.push_back(foreachThreadOp.getInductionVars().front());
         resultSizesRank.push_back(b.getIndexAttr(1));
         continue;
       }

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6a4da00bbad36..9032f533e5fbe 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1110,6 +1110,7 @@ Speculation::Speculatability ForOp::getSpeculatability() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ForeachThreadOp::verify() {
+  unsigned numLoops = getRank();
   // Check number of outputs.
   if (getNumResults() != getOutputs().size())
     return emitOpError("produces ")
@@ -1118,18 +1119,18 @@ LogicalResult ForeachThreadOp::verify() {
 
   // Check that the body defines block arguments for thread indices and outputs.
   auto *body = getBody();
-  if (body->getNumArguments() != getRank() + getOutputs().size())
-    return emitOpError("region expects ") << getRank() << " arguments";
-  for (int64_t i = 0; i < getRank(); ++i)
+  if (body->getNumArguments() != numLoops + getOutputs().size())
+    return emitOpError("region expects ") << numLoops << " arguments";
+  for (int64_t i = 0; i < numLoops; ++i)
     if (!body->getArgument(i).getType().isIndex())
       return emitOpError("expects ")
              << i << "-th block argument to be an index";
   for (unsigned i = 0; i < getOutputs().size(); ++i)
-    if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
+    if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
       return emitOpError("type mismatch between ")
              << i << "-th output and corresponding block argument";
   if (getMapping().has_value() && !getMapping()->empty()) {
-    if (static_cast<int64_t>(getMapping()->size()) != getRank())
+    if (static_cast<int64_t>(getMapping()->size()) != numLoops)
       return emitOpError() << "mapping attribute size must match op rank";
     for (auto map : getMapping()->getValue()) {
       if (!isa<DeviceMappingAttrInterface>(map))
@@ -1138,15 +1139,41 @@ LogicalResult ForeachThreadOp::verify() {
     }
   }
 
+  // Verify mixed static/dynamic control variables.
+  Operation *op = getOperation();
+  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
+                                            getStaticLowerBound(),
+                                            getDynamicLowerBound())))
+    return failure();
+  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
+                                            getStaticUpperBound(),
+                                            getDynamicUpperBound())))
+    return failure();
+  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
+                                            getStaticStep(), getDynamicStep())))
+    return failure();
+
   return success();
 }
 
 void ForeachThreadOp::print(OpAsmPrinter &p) {
-  p << " (";
-  llvm::interleaveComma(getThreadIndices(), p);
-  p << ") in (";
-  llvm::interleaveComma(getNumThreads(), p);
-  p << ")";
+  Operation *op = getOperation();
+  p << " (" << getInductionVars();
+  if (isNormalized()) {
+    p << ") in ";
+    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
+                          OpAsmParser::Delimiter::Paren);
+  } else {
+    p << ") = ";
+    printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
+                          OpAsmParser::Delimiter::Paren);
+    p << " to ";
+    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
+                          OpAsmParser::Delimiter::Paren);
+    p << " step ";
+    printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
+                          OpAsmParser::Delimiter::Paren);
+  }
   printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
   p << " ";
   if (!getRegionOutArgs().empty())
@@ -1154,28 +1181,60 @@ void ForeachThreadOp::print(OpAsmPrinter &p) {
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/getNumResults() > 0);
-  p.printOptionalAttrDict(getOperation()->getAttrs(),
-                          {"operand_segment_sizes"});
+  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
+                                           getStaticLowerBoundAttrName(),
+                                           getStaticUpperBoundAttrName(),
+                                           getStaticStepAttrName()});
 }
 
 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
-  auto &builder = parser.getBuilder();
+  OpBuilder b(parser.getContext());
+  auto indexType = b.getIndexType();
+
   // Parse an opening `(` followed by thread index variables followed by `)`
   // TODO: when we can refer to such "induction variable"-like handles from the
   // declarative assembly format, we can implement the parser as a custom hook.
-  SmallVector<OpAsmParser::Argument, 4> threadIndices;
-  if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
+  SmallVector<OpAsmParser::Argument, 4> ivs;
+  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
     return failure();
 
-  // Parse `in` threadNums.
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
-  if (parser.parseKeyword("in") ||
-      parser.parseOperandList(threadNums, threadIndices.size(),
+  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
+      dynamicSteps;
+  if (succeeded(parser.parseOptionalKeyword("in"))) {
+    // Parse upper bounds.
+    if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
                               OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(threadNums, builder.getIndexType(),
-                             result.operands))
-    return failure();
+        parser.resolveOperands(dynamicUbs, indexType, result.operands))
+      return failure();
+
+    unsigned numLoops = ivs.size();
+    staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
+    staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
+  } else {
+    // Parse lower bounds.
+    if (parser.parseEqual() ||
+        parseDynamicIndexList(parser, dynamicLbs, staticLbs,
+                              OpAsmParser::Delimiter::Paren) ||
+
+        parser.resolveOperands(dynamicLbs, indexType, result.operands))
+      return failure();
+
+    // Parse upper bounds.
+    if (parser.parseKeyword("to") ||
+        parseDynamicIndexList(parser, dynamicUbs, staticUbs,
+                              OpAsmParser::Delimiter::Paren) ||
+        parser.resolveOperands(dynamicUbs, indexType, result.operands))
+      return failure();
+
+    // Parse step values.
+    if (parser.parseKeyword("step") ||
+        parseDynamicIndexList(parser, dynamicSteps, staticSteps,
+                              OpAsmParser::Delimiter::Paren) ||
+        parser.resolveOperands(dynamicSteps, indexType, result.operands))
+      return failure();
+  }
 
   // Parse out operands and results.
   SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
@@ -1195,9 +1254,9 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
   // Parse region.
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
   std::unique_ptr<Region> region = std::make_unique<Region>();
-  for (auto &idx : threadIndices) {
-    idx.type = builder.getIndexType();
-    regionArgs.push_back(idx);
+  for (auto &iv : ivs) {
+    iv.type = b.getIndexType();
+    regionArgs.push_back(iv);
   }
   for (const auto &it : llvm::enumerate(regionOutArgs)) {
     auto &out = it.value();
@@ -1208,92 +1267,111 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
     return failure();
 
   // Ensure terminator and move region.
-  OpBuilder b(builder.getContext());
   ForeachThreadOp::ensureTerminator(*region, b, result.location);
   result.addRegion(std::move(region));
 
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
+
+  result.addAttribute("staticLowerBound", staticLbs);
+  result.addAttribute("staticUpperBound", staticUbs);
+  result.addAttribute("staticStep", staticSteps);
   result.addAttribute("operand_segment_sizes",
                       parser.getBuilder().getDenseI32ArrayAttr(
-                          {static_cast<int32_t>(threadNums.size()),
+                          {static_cast<int32_t>(dynamicLbs.size()),
+                           static_cast<int32_t>(dynamicUbs.size()),
+                           static_cast<int32_t>(dynamicSteps.size()),
                            static_cast<int32_t>(outOperands.size())}));
   return success();
 }
 
-// Bodyless builder, outputs must be specified.
-void ForeachThreadOp::build(mlir::OpBuilder &builder,
-                            mlir::OperationState &result, ValueRange outputs,
-                            ValueRange numThreads,
-                            std::optional<ArrayAttr> mapping) {
-  result.addOperands(numThreads);
+// Builder that takes explicit lower bounds, upper bounds and steps.
+void ForeachThreadOp::build(
+    mlir::OpBuilder &b, mlir::OperationState &result,
+    ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+    ArrayRef<OpFoldResult> steps, ValueRange outputs,
+    std::optional<ArrayAttr> mapping,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
+  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
+  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
+  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
+  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
+
+  result.addOperands(dynamicLbs);
+  result.addOperands(dynamicUbs);
+  result.addOperands(dynamicSteps);
   result.addOperands(outputs);
+  result.addTypes(TypeRange(outputs));
+
+  result.addAttribute(getStaticLowerBoundAttrName(result.name),
+                      b.getDenseI64ArrayAttr(staticLbs));
+  result.addAttribute(getStaticUpperBoundAttrName(result.name),
+                      b.getDenseI64ArrayAttr(staticUbs));
+  result.addAttribute(getStaticStepAttrName(result.name),
+                      b.getDenseI64ArrayAttr(staticSteps));
+  result.addAttribute(
+      "operand_segment_sizes",
+      b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
+                              static_cast<int32_t>(dynamicUbs.size()),
+                              static_cast<int32_t>(dynamicSteps.size()),
+                              static_cast<int32_t>(outputs.size())}));
   if (mapping.has_value()) {
     result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
                         mapping.value());
   }
 
-  result.addAttribute(
-      "operand_segment_sizes",
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
-                                    static_cast<int32_t>(outputs.size())}));
-  result.addTypes(TypeRange(outputs));
-
   Region *bodyRegion = result.addRegion();
-  OpBuilder::InsertionGuard g(builder);
-  // createBlock sets the IP inside the block.
-  // Generally we would guard against that but the default ensureTerminator impl
-  // expects it ..
-  builder.createBlock(bodyRegion);
+  OpBuilder::InsertionGuard g(b);
+  b.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
-  // Add block arguments for indices and outputs.
-  bodyBlock.addArguments(
-      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
-      SmallVector<Location>(numThreads.size(), result.location));
-  bodyBlock.addArguments(
-      TypeRange(outputs),
-      SmallVector<Location>(outputs.size(), result.location));
-  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
-}
 
-// Builder that takes a bodyBuilder lambda.
-void ForeachThreadOp::build(
-    mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
-    ValueRange numThreads, ArrayRef<Attribute> mapping,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
-  result.addOperands(numThreads);
-  result.addOperands(outputs);
-  result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
-                      builder.getArrayAttr(mapping));
-  result.addAttribute(
-      "operand_segment_sizes",
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
-                                    static_cast<int32_t>(outputs.size())}));
-  result.addTypes(TypeRange(outputs));
-
-  Region *bodyRegion = result.addRegion();
-  OpBuilder::InsertionGuard g(builder);
-  builder.createBlock(bodyRegion);
-  Block &bodyBlock = bodyRegion->front();
   // Add block arguments for indices and outputs.
   bodyBlock.addArguments(
-      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
-      SmallVector<Location>(numThreads.size(), result.location));
+      SmallVector<Type>(lbs.size(), b.getIndexType()),
+      SmallVector<Location>(staticLbs.size(), result.location));
   bodyBlock.addArguments(
       TypeRange(outputs),
       SmallVector<Location>(outputs.size(), result.location));
 
-  builder.setInsertionPointToStart(&bodyBlock);
-  bodyBuilder(builder, result.location, bodyBlock.getArguments());
+  b.setInsertionPointToStart(&bodyBlock);
+  if (!bodyBuilderFn) {
+    ForeachThreadOp::ensureTerminator(*bodyRegion, b, result.location);
+    return;
+  }
+  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
 #ifndef NDEBUG
   auto terminator =
       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
   assert(terminator &&
-         "expected bodyBuilder to create PerformConcurrentlyOp terminator");
+         "expected bodyBuilderFn to create PerformConcurrentlyOp terminator");
 #endif // NDEBUG
 }
 
+// Builder that takes only upper bounds; lbs default to 0 and steps to 1.
+void ForeachThreadOp::build(
+    mlir::OpBuilder &b, mlir::OperationState &result,
+    ArrayRef<OpFoldResult> ubs, ValueRange outputs,
+    std::optional<ArrayAttr> mapping,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+  unsigned numLoops = ubs.size();
+  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
+  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
+}
+
+// Checks if all lbs are constant 0 and all steps constant 1 (dynamic fails).
+bool ForeachThreadOp::isNormalized() {
+  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+    return llvm::all_of(results, [&](OpFoldResult ofr) {
+      auto intValue = getConstantIntValue(ofr);
+      return intValue.has_value() && intValue == val;
+    });
+  };
+  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
+}
+
 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
 // unaware of the fact that our terminator also needs a region to be
 // well-formed. We override it here to ensure that we do the right thing.

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e844f462389a9..d5c227967b36a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1004,15 +1004,14 @@ struct YieldOpInterface
 
 /// Return `true` if the given loop may have 0 iterations.
 bool mayHaveZeroIterations(scf::ForeachThreadOp foreachThreadOp) {
-  int64_t p = 1;
-  for (Value v : foreachThreadOp.getNumThreads()) {
-    if (std::optional<int64_t> c = getConstantIntValue(v)) {
-      p *= *c;
-    } else {
+  for (auto [lb, ub] : llvm::zip(foreachThreadOp.getMixedLowerBound(),
+                                 foreachThreadOp.getMixedUpperBound())) {
+    std::optional<int64_t> lbConst = getConstantIntValue(lb);
+    std::optional<int64_t> ubConst = getConstantIntValue(ub);
+    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
       return true;
-    }
   }
-  return p == 0;
+  return false;
 }
 
 /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
@@ -1087,8 +1086,9 @@ struct ForeachThreadOpInterface
     rewriter.setInsertionPoint(foreachThreadOp);
     ForeachThreadOp newForeachThreadOp;
     newForeachThreadOp = rewriter.create<ForeachThreadOp>(
-        foreachThreadOp.getLoc(), /*outputs=*/ValueRange(),
-        foreachThreadOp.getNumThreads(), foreachThreadOp.getMapping());
+        foreachThreadOp.getLoc(), foreachThreadOp.getMixedLowerBound(),
+        foreachThreadOp.getMixedUpperBound(), foreachThreadOp.getMixedStep(),
+        /*outputs=*/ValueRange(), foreachThreadOp.getMapping());
 
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
@@ -1127,10 +1127,28 @@ struct ForeachThreadOpInterface
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
-    // This op is not repetitive if it has just a single thread.
-    return !llvm::all_of(foreachThreadOp.getNumThreads(), [](Value v) {
-      return getConstantIntValue(v) == static_cast<int64_t>(1);
-    });
+
+    // This op is repetitive if any dimension has a second iteration, i.e.
+    // lb + step < ub. Non-constant lb/ub/step conservatively count as such.
+    for (auto [lb, ub, step] : llvm::zip(foreachThreadOp.getMixedLowerBound(),
+                                         foreachThreadOp.getMixedUpperBound(),
+                                         foreachThreadOp.getMixedStep())) {
+      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
+      if (!lbConstant)
+        return true;
+
+      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
+      if (!ubConstant)
+        return true;
+
+      std::optional<int64_t> stepConstant = getConstantIntValue(step);
+      if (!stepConstant)
+        return true;
+
+      if (*lbConstant + *stepConstant < *ubConstant)
+        return true;
+    }
+    return false;
   }
 };
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index e0f236b633f82..79a688ad9d4f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -180,10 +180,10 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
       if (scf::ForeachThreadOp foreachThreadOp =
               scf::getForeachThreadOpThreadIndexOwner(iv)) {
         for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
-          if (foreachThreadOp.getThreadIndices()[idx] == iv) {
-            lb = OpBuilder(iv.getContext()).getIndexAttr(0);
-            ub = foreachThreadOp.getNumThreads()[idx];
-            step = OpBuilder(iv.getContext()).getIndexAttr(1);
+          if (foreachThreadOp.getInductionVar(idx) == iv) {
+            lb = foreachThreadOp.getMixedLowerBound()[idx];
+            ub = foreachThreadOp.getMixedUpperBound()[idx];
+            step = foreachThreadOp.getMixedStep()[idx];
             return success();
           }
         }

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 9d30f2797c0e8..0b1ecc9115fd7 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -69,12 +69,45 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
   return success();
 }
 
+static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
+  switch (delimiter) {
+  case AsmParser::Delimiter::Paren:
+    return '(';
+  case AsmParser::Delimiter::LessGreater:
+    return '<';
+  case AsmParser::Delimiter::Square:
+    return '[';
+  case AsmParser::Delimiter::Braces:
+    return '{';
+  default:
+    llvm_unreachable("unsupported delimiter");
+  }
+}
+
+static char getRightDelimiter(AsmParser::Delimiter delimiter) {
+  switch (delimiter) {
+  case AsmParser::Delimiter::Paren:
+    return ')';
+  case AsmParser::Delimiter::LessGreater:
+    return '>';
+  case AsmParser::Delimiter::Square:
+    return ']';
+  case AsmParser::Delimiter::Braces:
+    return '}';
+  default:
+    llvm_unreachable("unsupported delimiter");
+  }
+}
+
 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
-                                 ArrayRef<int64_t> integers) {
-  printer << '[';
+                                 ArrayRef<int64_t> integers,
+                                 AsmParser::Delimiter delimiter) {
+  char leftDelimiter = getLeftDelimiter(delimiter);
+  char rightDelimiter = getRightDelimiter(delimiter);
+  printer << leftDelimiter;
   if (integers.empty()) {
-    printer << "]";
+    printer << rightDelimiter;
     return;
   }
   unsigned idx = 0;
@@ -84,13 +117,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
     else
       printer << integer;
   });
-  printer << ']';
+  printer << rightDelimiter;
 }
 
 ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    DenseI64ArrayAttr &integers) {
+    DenseI64ArrayAttr &integers, AsmParser::Delimiter delimiter) {
 
   SmallVector<int64_t, 4> integerVals;
   auto parseIntegerOrValue = [&]() {
@@ -107,8 +140,7 @@ ParseResult mlir::parseDynamicIndexList(
     }
     return success();
   };
-  if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
-                                     parseIntegerOrValue,
+  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
                                      " in dynamic index list"))
     return parser.emitError(parser.getNameLoc())
            << "expected SSA value or integer";

diff  --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index 586d09c51ebac..7b04d6c5e152b 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -13,9 +13,7 @@ module {
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
   func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  //  CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-  //  CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
-  //      CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
+  //      CHECK: scf.foreach_thread ({{.*}}) in (10, 20) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
   //      CHECK:   %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:   %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:   %[[tC:.*]] = tensor.extract_slice %[[C_BLK]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
@@ -95,9 +93,7 @@ transform.sequence failures(propagate) {
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor
 func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
-  //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
-  //  CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (10, 21) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
   //      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
   //  CHECK-NOT:   affine.min
@@ -175,9 +171,7 @@ transform.sequence failures(propagate) {
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor
 func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
-  //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
-  //  CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
   //  CHECK-NOT:   affine.max
   //  CHECK-NOT:   affine.min
@@ -225,8 +219,7 @@ module {
 // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
 
 // CHECK-LABEL: extract_source(
-//       CHECK:  %[[C2:.*]] = arith.constant 2 : index
-//       CHECK:  scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
+//       CHECK:  scf.foreach_thread (%[[ARG:.*]]) in (2) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
 //       CHECK:    %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
 //       CHECK:    scf.foreach_thread.perform_concurrently {
 //       CHECK:      tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
@@ -289,8 +282,7 @@ transform.sequence failures(propagate) {
   func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>,
                                          %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
                                          -> (tensor<100xf32>, tensor<100xf32>) {
-//  CHECK-DAG: %[[c0:.+]] = arith.constant 7 :
-//      CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
+//      CHECK: scf.foreach_thread (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
 //      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
 //      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
 //  CHECK-NOT:   affine.min
@@ -345,8 +337,7 @@ transform.sequence failures(propagate) {
   func.func @tile_output_multi_1d2d_static(%IN1: tensor<100xf32>, %IN2: tensor<100x300xf32>, %IN3: tensor<300xf32>,
                      %OUT1: tensor<300x100xf32>, %OUT2: tensor<300xf32>)
                      -> (tensor<300x100xf32>, tensor<300xf32>) {
-//  CHECK-DAG: %[[c0:.+]] = arith.constant 4 :
-//      CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
+//      CHECK: scf.foreach_thread (%[[IV0:.+]]) in (4) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
 //      CHECK:   %[[LB:.+]] = affine.apply #[[$map0]](%[[IV0]])
 //      CHECK:   %[[tIN1:.+]] = tensor.extract_slice %[[IN2]][0, %[[LB]]] [100, 75]
 //      CHECK:   %[[tIN2:.+]] = tensor.extract_slice %[[IN3]][%[[LB]]] [75]

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4525527bd7777..0c4d9f054ae39 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -122,13 +122,12 @@ transform.sequence failures(propagate) {
 // CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
 //     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
-//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
 // CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
 // CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -175,7 +174,6 @@ transform.sequence failures(propagate) {
 // CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
@@ -183,7 +181,7 @@ transform.sequence failures(propagate) {
 // CHECK-DAG:   %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
 //     CHECK:   %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
-//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
+//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
 // CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
 // CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
@@ -235,13 +233,12 @@ transform.sequence failures(propagate) {
 // CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
 //     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
-//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 //     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
 //     CHECK:     %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 //     CHECK:     %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]

diff  --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 18413e8e2a2fb..6d79f2cd01740 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -311,8 +311,8 @@ func.func @execute_region() -> i64 {
   return %res : i64
 }
 
-// CHECK-LABEL: func.func @simple_example
-func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+// CHECK-LABEL: func.func @normalized_foreach_thread
+func.func @normalized_foreach_thread(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %c1 = arith.constant 1 : index
   %num_threads = arith.constant 100 : index
 
@@ -333,8 +333,32 @@ func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   return
 }
 
-// CHECK-LABEL: func.func @elide_terminator
-func.func @elide_terminator() -> () {
+// CHECK-LABEL: func.func @explicit_loop_bounds_foreach_thread
+func.func @explicit_loop_bounds_foreach_thread(%in: tensor<100xf32>,
+    %out: tensor<100xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  //      CHECK:    scf.foreach_thread
+  // CHECK-NEXT:  tensor.extract_slice
+  // CHECK-NEXT:  scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:  tensor.parallel_insert_slice
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  return
+  %result = scf.foreach_thread (%thread_idx) =  (%c0) to (%num_threads) step (%c1) shared_outs(%o = %out) -> tensor<100xf32> {
+      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+      scf.foreach_thread.perform_concurrently {
+        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
+          tensor<1xf32> into tensor<100xf32>
+      }
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @normalized_foreach_thread_elide_terminator
+func.func @normalized_foreach_thread_elide_terminator() -> () {
   %num_threads = arith.constant 100 : index
 
   //      CHECK:    scf.foreach_thread
@@ -345,6 +369,23 @@ func.func @elide_terminator() -> () {
     }
   } {mapping = [#gpu.thread<x>]}
   return
+
+}
+
+// CHECK-LABEL: func.func @explicit_loop_bounds_foreach_thread_elide_terminator
+func.func @explicit_loop_bounds_foreach_thread_elide_terminator() -> () {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  //      CHECK:    scf.foreach_thread
+  // CHECK-NEXT:  } {mapping = [#gpu.thread<x>]}
+  // CHECK-NEXT:  return
+  scf.foreach_thread (%thread_idx) = (%c0) to (%num_threads) step (%c1) {
+    scf.foreach_thread.perform_concurrently {
+    }
+  } {mapping = [#gpu.thread<x>]}
+  return
 }
 
 // CHECK-LABEL: @switch

diff  --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
index 541a54a911289..f366d331dff58 100644
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -24,12 +24,11 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3
 //     CHECK: return %[[tile]]
 
 //     FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
-// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index
 // FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
 // FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
 // FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
 // FOREACH-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
-//     FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]])
+//     FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (20) shared_outs(%[[dest:.+]] = %[[init]])
 //     FOREACH:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
 //     FOREACH:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
 //     FOREACH:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 25039a5c41f92..6dcf6cf97339d 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -247,9 +247,10 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
                                 tensor::ExtractSliceFromCollapseHelper &helper,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
-        loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
-        /*mapping=*/ArrayRef<Attribute>{},
+    auto foreachThreadOp = rewriter.create<scf::ForeachThreadOp>(
+        loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
+        /*outputs=*/dest,
+        /*mapping=*/std::nullopt,
         [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
           unsigned numThreadIdRegionArgs =
               helper.getIterationSpaceSizes().size();
@@ -267,7 +268,7 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
           nestedBuilder.create<tensor::ParallelInsertSliceOp>(
               loc, tile, outputArgs[0], insertParams);
         });
-    rewriter.replaceOp(op, foreachOp->getResult(0));
+    rewriter.replaceOp(op, foreachThreadOp->getResult(0));
     return success();
   }
 };


        


More information about the Mlir-commits mailing list