[Mlir-commits] [mlir] c59465e - [mlir][Transform] Add support for mapping to GPU warps and to linear ids

Nicolas Vasilache llvmlistbot at llvm.org
Mon Mar 20 01:05:42 PDT 2023


Author: Nicolas Vasilache
Date: 2023-03-20T01:05:32-07:00
New Revision: c59465e1203dd78d06e15f7ddf62141807dbd5a7

URL: https://github.com/llvm/llvm-project/commit/c59465e1203dd78d06e15f7ddf62141807dbd5a7
DIFF: https://github.com/llvm/llvm-project/commit/c59465e1203dd78d06e15f7ddf62141807dbd5a7.diff

LOG: [mlir][Transform] Add support for mapping to GPU warps and to linear ids

This revision refactors the implementation of mapping to threads to additionally allow warps and linear ids to be specified.

`warp_dims` is currently specified along with `block_dims` as a transform attribute.

Linear ids, on the other hand, use the flattened block_dims to predicate on the first (linearized) k threads.
An additional GPULinearIdMappingAttr is added to the GPU dialect to allow specifying loops mapped to this new scheme.

Various implementation and transform op semantics cleanups are also applied.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D146130

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/test/Dialect/GPU/transform-gpu-failing.mlir
    mlir/test/Dialect/GPU/transform-gpu.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
index 3b261acdee83a..699390c2f2959 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
@@ -64,6 +64,41 @@ def GPUWarpMappingAttr : GPU_Attr<"GPUWarpMapping", "warp", [
   }];
 }
 
+def LinearIdEnum : I64EnumAttr<"LinearId", "linear ids for loop mapping", [
+    DimX, DimY, DimZ]> {
+  let cppNamespace = "::mlir::gpu";
+}
+
+def GPULinearIdMapping : GPU_Attr<"GPULinearIdMapping", "linear", [
+  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
+  let parameters = (ins
+    EnumParameter<LinearIdEnum>:$linear_id
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute to allow re-interpreting the linear mapping for threads in GPU
+    devices.
+
+    Threads (aka work items) are grouped into a thread block where block may be
+    described by a 1-, 2-, or 3-dimensional rectangular basis.
+    The linear thread id is obtained by linearizing the 1-, 2- or 3-dimensional
+    index. For instance, if the basis is denoted as (BX, BY, BZ) and the thread
+    id is denoted by (tx, ty, tz), the linear thread id is:
+      `linear_id = tx + ty * BX + tz * BX * BY`.
+    The linear thread id is fixed for the duration of a GPU kernel.
+    
+    This linear id mapping attribute indicates a different linearization relation
+    is applied locally to a loop nest. 
+    
+    For instance, if the new basis is denoted as (LBX, LBY, LBZ) the thread id
+    in the new basis is:
+      `(linear_id mod LBX, (linear_id / LBX) mod LBY, linear_id / (LBX * LBY))`.
+    This reinterpretation is only fixed for the duration of a loop nest.
+    
+    It can be consumed by lowering to generate GPU code.
+  }];
+}
+
 def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
     DimX, DimY, DimZ]> {
   let cppNamespace = "::mlir::gpu";

diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 579922a3a9c03..57d74d856cba7 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -33,33 +33,94 @@ class DialectRegistry;
 namespace transform {
 namespace gpu {
 
+/// Helper type for functions that generate ids for the mapping of a
+/// scf.forall.
+struct IdBuilderResult {
+  // Ops used to replace the forall induction variables.
+  SmallVector<Value> mappingIdOps;
+  // Actual mapping sizes used to predicate the forall body when they are
+  // smaller than the available mapping sizes.
+  SmallVector<int64_t> predicateMappingSizes;
+  // Ops used to predicate the forall body when predicateMappingSizes is smaller
+  // than the available mapping sizes.
+  SmallVector<Value> predicateIdOps;
+};
+
+/// Common gpu id builder type, allows the configuration of lowering for various
+/// mapping schemes. Takes:
+///   - A rewriter with insertion point set before the forall op to rewrite.
+///   - The loc of the forall op to rewrite.
+///   - A list of positive integers carrying the mapping sizes for the current
+///     forall op to rewrite.
+using GpuIdBuilderFnType =
+    std::function<IdBuilderResult(RewriterBase &, Location, ArrayRef<int64_t>)>;
+
+/// Helper struct for configuring the rewrite of mapped scf.forall ops to
+/// various gpu id configurations.
+struct GpuIdBuilder {
+  GpuIdBuilder(ArrayRef<OpFoldResult> blockDims, ArrayRef<int64_t> mappingSizes)
+      : blockDimsOfr(blockDims), availableMappingSizes(mappingSizes),
+        mappingAttributes(), idBuilder() {}
+
+  /// List of OpFoldResult carrying the multi-dimensional number of
+  /// threads available in the current kernel (i.e. the current blockDims in
+  /// CUDA parlance).
+  ArrayRef<OpFoldResult> blockDimsOfr;
+
+  /// A list of positive integers carrying the number of available mapping
+  /// resources that can trigger predication.
+  ArrayRef<int64_t> availableMappingSizes;
+
+  /// The mapping attributes targeted by this generator.
+  SmallVector<DeviceMappingAttrInterface> mappingAttributes;
+
+  /// The constructor that builds the concrete IR for mapping ids.
+  GpuIdBuilderFnType idBuilder;
+};
+
 /// Map the top level `scf.forall` op to GPU Thread Blocks.
 /// Mapping is one-to-one and the induction variables of `scf.forall` are
-/// rewritten to gpu.block_id according to the thread_dim_apping attribute.
+/// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
+///
 /// Dynamic, `scf.forall` trip counts are currently not supported.
 /// Dynamic block dim sizes are currently not supported.
-DiagnosedSilenceableFailure mapForallToBlocksImpl(
-    RewriterBase &rewriter, TransformOpInterface transformOp,
-    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        blockIdGenerator);
-
-/// Search `scf.forall` ops nested under `target` and map each such op to GPU
-/// threads. Mapping is one-to-one and the induction variables of `scf.forall`
-/// are rewritten to gpu.thread_id according to the thread_dim_mapping
-/// attribute.
-/// Sibling `scf.forall` are supported in which case, the union of the number of
-/// threads is computed and may result in predication.
+DiagnosedSilenceableFailure
+mapForallToBlocksImpl(RewriterBase &rewriter, TransformOpInterface transformOp,
+                      scf::ForallOp forallOp,
+                      SmallVectorImpl<int64_t> &gridDims,
+                      const GpuIdBuilder &gpuIdBuilder);
+
+/// Search `scf.forall` ops nested under `target` and map each such op to an
+/// explicit GPU implementation along `availableMappingSizes`.
+/// The mapping is one-to-one and the induction variables of `scf.forall` are
+/// rewritten to gpuIdBuilder.idBuilder according to the
+/// gpuIdBuilder.mappingAttributes attribute.
+///
 /// Dynamic, `scf.forall` trip counts are currently not supported.
-/// Dynamic block dim sizes are currently not supported.
+/// Dynamic `availableMappingSizes` sizes are currently not supported.
+/// `availableMappingSizes` is expected to be of size 3.
+DiagnosedSilenceableFailure mapOneForallToThreadsImpl(
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
+    bool syncAfterDistribute, const GpuIdBuilder &gpuIdBuilder);
+
+/// Search `scf.forall` ops nested under `target` and map each such op to an
+/// explicit GPU implementation along blockDims and warpDims.
+/// The mapping is one-to-one and the induction variables of `scf.forall` are
+/// rewritten to threads and warps ids according to the mapping attribute.
+///
+/// Dynamic, `scf.forall` trip counts are currently not supported.
+/// Dynamic `blockDims` or `warpDims` or `linearDims` sizes are currently not
+/// supported.
+/// `blockDims` is expected to be of size 3.
+/// `warpDims` is expected to be empty or of size 3.
+///
+/// The insertion point of the `rewriter` is expected to be set at the
+/// beginning of the `target` body block and dominate all other blocks.
 DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
-    bool syncAfterDistribute,
-    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        threadIdGenerator);
+    Operation *target, ArrayRef<int64_t> blockDimsOfr,
+    ArrayRef<int64_t> warpDims, bool syncAfterDistribute);
 
 /// Find the unique top level scf::ForallOp within a given target op.
 DiagnosedSilenceableFailure

diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 46f0e186741e8..c719fedc90e33 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -22,21 +22,26 @@ def MapNestedForallToThreads :
      TransformEachOpTrait,
      TransformOpInterface]> {
   let description = [{
-      Target the `gpu.launch op` and rewrite all `scf.forall`
-      nested in it to distributed `gpu.thread_id` attribute.
-
-      The operation searches for `scf.forall` ops nested under `target`
-      and maps each such op to GPU threads. Mapping is one-to-one and the
-      induction variables of `scf.forall` are rewritten to
-      `gpu.thread_id` according to the `mapping` attribute.
-
-      Sibling `scf.forall` are supported in which case, the union of
-      the number of threads is computed and may result in predication.
-
-      Multiple scf.forall are supported per `gpu.launch` in which case,
-      the max of all the threads is computed and taken for the global
-      `gpu.thread_id`. If necessary, `scf.forall` that do not use the
-      whole thread range result in predicated computations.
+      Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to 
+      distributed `gpu.thread_id` attribute.
+
+      The operation searches for `scf.forall` ops nested under `target` and maps
+      each such op to GPU threads. 
+      
+      `scf.forall` induction variables are rewritten to `gpu.thread_id` according
+      to the `mapping` attribute.
+
+      Different types of mapping attributes are supported:
+        - the block_dims is a list of integers that specifies the number of
+          threads in each dimension. This is a mandatory attribute that is used
+          to constrain the number of threads in each dimension. If an 
+          `scf.forall` op is mapped to fewer threads, predication occurs.
+        - the warp_dims is a list of integers that specifies the number of
+          warps in each dimension. This is an optional attribute that is used
+          to constrain the number of warps in each dimension. When present, this
+          attribute must be specified in a way that is compatible with the 
+          block_dims attribute. If an `scf.forall` op is mapped to fewer warps,
+          predication occurs.
 
       Dynamic `scf.forall` trip counts are currently not supported.
       Dynamic block dim sizes are currently not supported.
@@ -45,10 +50,12 @@ def MapNestedForallToThreads :
       Only `scf.forall` distributed to **at most 3 dimensions** are
       currently supported.
 
-      Barriers are inserted after each scf.forall op for now.
+      The `sync_after_distribute` attribute controls whether a `gpu.barrier` is
+      inserted after each scf.forall op. At this time, this is an all or nothing
+      choice. This will need to be tightened in the future.
 
-      The operation alters the block size of the given gpu_launch using
-      blockDim argument.
+      The operation alters the block size of the given gpu_launch using the 
+      mandatory block_dims argument.
 
       #### Return modes:
 
@@ -83,6 +90,7 @@ def MapNestedForallToThreads :
         gpu.terminator
       }
       ```
+
       is translated to:
 
       ```
@@ -104,11 +112,18 @@ def MapNestedForallToThreads :
     }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$blockDim,
-                   DefaultValuedAttr<BoolAttr, "true">:$syncAfterDistribute);
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$block_dims,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$warp_dims,
+                   DefaultValuedAttr<BoolAttr, "true">:$sync_after_distribute);
   let results = (outs PDL_Operation:$result);
 
-  let assemblyFormat = "$target attr-dict";
+  let assemblyFormat = [{
+    $target
+    `block_dims` `=` $block_dims
+    (`warp_dims` `=` $warp_dims^)?
+    (`sync_after_distribute` `=` $sync_after_distribute^)?
+    attr-dict
+  }];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::Operation *target,
@@ -117,7 +132,6 @@ def MapNestedForallToThreads :
   }];
 }
 
-
 def MapForallToBlocks :
   Op<Transform_Dialect, "gpu.map_forall_to_blocks",
     [FunctionalStyleTransformOpTrait,
@@ -142,8 +156,8 @@ def MapForallToBlocks :
     Only scf.forall distributed to **at most 3 dimensions** are
     currently supported.
 
-    The operation alters the block size of the given gpu_launch using
-    gridDim argument.
+    The operation alters the block size of the given gpu_launch using the 
+    grid_dims argument.
 
     #### Return modes:
 
@@ -162,11 +176,16 @@ def MapForallToBlocks :
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$gridDim,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$grid_dims,
                    UnitAttr:$generate_gpu_launch);
   let results = (outs PDL_Operation:$result);
 
-  let assemblyFormat = "$target attr-dict";
+  let assemblyFormat = [{
+    $target
+    (`generate_gpu_launch` $generate_gpu_launch^)?
+    (`grid_dims` `=` $grid_dims^)?
+    attr-dict
+  }];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::Operation *target,

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 0ec5877f80361..f9d929d163445 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -46,6 +46,10 @@ int64_t GPUWarpMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getWarp());
 }
 
+int64_t GPULinearIdMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getLinearId());
+}
+
 int64_t GPUThreadMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getThread());
 }

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 748d9e46ac153..f1559970d36d9 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -8,7 +8,9 @@
 
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -16,9 +18,14 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -27,6 +34,7 @@
 using namespace mlir;
 using namespace mlir::gpu;
 using namespace mlir::transform;
+using namespace mlir::transform::gpu;
 
 #define DEBUG_TYPE "gpu-transforms"
 
@@ -35,68 +43,200 @@ using namespace mlir::transform;
 
 namespace {
 
-/// Helper type for functions that generate ids for the mapping of a scf.forall.
-using IdGeneratorFnType = llvm::function_ref<void(RewriterBase &, scf::ForallOp,
-                                                  SmallVectorImpl<Value> &)>;
+/// Return a flattened thread id for the workgroup with given sizes.
+static Value buildLinearThreadId(RewriterBase &rewriter, Location loc,
+                                 ArrayRef<OpFoldResult> blockDimsOfr) {
+  LLVM_DEBUG(llvm::interleaveComma(
+                 blockDimsOfr,
+                 DBGS() << "----buildLinearThreadId with blockDimsOfr:  ");
+             llvm::dbgs() << "\n");
+  assert(blockDimsOfr.size() == 3 && "expected 3 workgroup sizes");
+  AffineExpr tx, ty, tz, BDX, BDY;
+  bindDims(rewriter.getContext(), tx, ty, tz);
+  bindSymbols(rewriter.getContext(), BDX, BDY);
+  IndexType indexType = rewriter.getIndexType();
+  SmallVector<OpFoldResult> threadsAndWorkGroups{
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x).getResult(),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y).getResult(),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z).getResult()};
+  threadsAndWorkGroups.push_back(blockDimsOfr[0]);
+  threadsAndWorkGroups.push_back(blockDimsOfr[1]);
+  OpFoldResult ofr = makeComposedFoldedAffineApply(
+      rewriter, loc, tx + ty * BDX + tz * BDX * BDY, threadsAndWorkGroups);
+  return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+}
 
-struct MappingToGpuHelper {
-  MappingToGpuHelper(SmallVector<DeviceMappingAttrInterface> mappingAttributes,
-                     IdGeneratorFnType idGenerator)
-      : mappingAttributes(mappingAttributes), idGenerator(idGenerator) {}
+/// Builder for gpu::BlockIdOps used in mapping scf.forall to blocks.
+/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
+/// as 3-D sizes for predicate generation.
+struct GpuBlockIdBuilder : public GpuIdBuilder {
+
+  GpuBlockIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
+                    ArrayRef<int64_t> mappingSizes)
+      : GpuIdBuilder(blockDims, mappingSizes) {
+    mappingAttributes = {GPUBlockMappingAttr::get(ctx, Blocks::DimX),
+                         GPUBlockMappingAttr::get(ctx, Blocks::DimY),
+                         GPUBlockMappingAttr::get(ctx, Blocks::DimZ)},
+    idBuilder = [](RewriterBase &rewriter, Location loc,
+                   ArrayRef<int64_t> forallMappingSizes) {
+      IndexType indexType = rewriter.getIndexType();
+      SmallVector<Value> ids{
+          rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
+          rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
+          rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
+      // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
+      // predicate generation.
+      return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes},
+                             ids};
+    };
+  }
+};
 
-  SmallVector<DeviceMappingAttrInterface> mappingAttributes;
-  IdGeneratorFnType idGenerator;
+/// Builder for gpu::ThreadIdOp used in mapping scf.forall to thread ids without
+/// any reindexing.
+/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
+/// as 3-D sizes for predicate generation.
+struct GpuThreadIdBuilder : public GpuIdBuilder {
+  GpuThreadIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
+                     ArrayRef<int64_t> mappingSizes)
+      : GpuIdBuilder(blockDims, mappingSizes) {
+    mappingAttributes = {GPUThreadMappingAttr::get(ctx, Threads::DimX),
+                         GPUThreadMappingAttr::get(ctx, Threads::DimY),
+                         GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
+    idBuilder = [](RewriterBase &rewriter, Location loc,
+                   ArrayRef<int64_t> forallMappingSizes) {
+      IndexType indexType = rewriter.getIndexType();
+      SmallVector<Value> ids{
+          rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
+          rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
+          rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
+      // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
+      // predicate generation.
+      return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes},
+                             ids};
+    };
+  }
 };
 
-struct MappingToGpuBlocksHelper : public MappingToGpuHelper {
-
-  MappingToGpuBlocksHelper(MLIRContext *ctx)
-      : MappingToGpuHelper(
-            SmallVector<DeviceMappingAttrInterface>{
-                GPUBlockMappingAttr::get(ctx, Blocks::DimX),
-                GPUBlockMappingAttr::get(ctx, Blocks::DimY),
-                GPUBlockMappingAttr::get(ctx, Blocks::DimZ)},
-            IdGeneratorFnType{[](RewriterBase &rewriter, scf::ForallOp forallOp,
-                                 SmallVectorImpl<Value> &ids) {
-              OpBuilder::InsertionGuard guard(rewriter);
-              rewriter.setInsertionPoint(forallOp);
-              IndexType indexType = rewriter.getIndexType();
-              auto loc = forallOp->getLoc();
-              ids.assign(
-                  {rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
-                   rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
-                   rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)});
-            }}) {}
+/// Builder for warp ids used in mapping scf.forall to warps.
+/// This builder requires a specification of the number of warps along each
+/// dimension to more finely control mapping to warps as well a predication than
+/// by solely analyzing the IR.
+/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
+/// as 3-D sizes for predicate generation.
+struct GpuWarpIdBuilder : public GpuIdBuilder {
+  GpuWarpIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
+                   ArrayRef<int64_t> mappingSizes)
+      : GpuIdBuilder(blockDims, mappingSizes) {
+    mappingAttributes = {GPUWarpMappingAttr::get(ctx, Warps::DimX),
+                         GPUWarpMappingAttr::get(ctx, Warps::DimY),
+                         GPUWarpMappingAttr::get(ctx, Warps::DimZ)};
+    idBuilder = [this](RewriterBase &rewriter, Location loc,
+                       ArrayRef<int64_t> forallMappingSizes) {
+      // Build the linear warp id and decompose it in the basis of
+      // `this->availableMappingSizes`.
+      Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr);
+      AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+      OpFoldResult warpIdOfr = makeComposedFoldedAffineApply(
+          rewriter, loc, d0.floorDiv(kWarpSize), {linearId});
+      Value warpId = getValueOrCreateConstantIndexOp(rewriter, loc, warpIdOfr);
+      SmallVector<int64_t> reverseBasisSizes(
+          llvm::reverse(this->availableMappingSizes));
+      SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
+      SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
+      SmallVector<Value> ids;
+      for (AffineExpr e : delinearizingExprs)
+        ids.push_back(makeComposedAffineApply(rewriter, loc, e, warpId));
+
+      // clang-format off
+      LDBG("----linearId: " << linearId);
+          LDBG("----warpId: " << warpId);
+      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
+                                       DBGS() << "--delinearization basis: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(strides,
+                                       DBGS() << "--delinearization strides: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(delinearizingExprs,
+                                       DBGS() << "--delinearization exprs: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
+                 llvm::dbgs() << "\n";);
+      // clang-format on
+
+      // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
+      // predicate generation.
+      return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes},
+                             ids};
+    };
+  }
+
+  /// Static specification of the warp size.
+  /// In the future this may be configured by the transformation.
+  static constexpr int64_t kWarpSize = 32;
 };
 
-struct MappingToGpuThreadsHelper : public MappingToGpuHelper {
-  MappingToGpuThreadsHelper(MLIRContext *ctx)
-      : MappingToGpuHelper(
-            SmallVector<DeviceMappingAttrInterface>{
-                GPUThreadMappingAttr::get(ctx, Threads::DimX),
-                GPUThreadMappingAttr::get(ctx, Threads::DimY),
-                GPUThreadMappingAttr::get(ctx, Threads::DimZ)},
-            IdGeneratorFnType{[](RewriterBase &rewriter, scf::ForallOp forallOp,
-                                 SmallVectorImpl<Value> &ids) {
-              OpBuilder::InsertionGuard guard(rewriter);
-              rewriter.setInsertionPoint(forallOp);
-              IndexType indexType = rewriter.getIndexType();
-              auto loc = forallOp->getLoc();
-              ids.assign(
-                  {rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
-                   rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
-                   rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)});
-            }}) {}
+/// Builder for linear ids used in mapping scf.forall to reindexed threads.
+/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
+/// as 1-D sizes for predicate generation.
+struct GpuLinearIdBuilder : public GpuIdBuilder {
+  GpuLinearIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
+                     ArrayRef<int64_t> mappingSizes)
+      : GpuIdBuilder(blockDims, mappingSizes) {
+    mappingAttributes = {GPULinearIdMappingAttr::get(ctx, LinearId::DimX),
+                         GPULinearIdMappingAttr::get(ctx, LinearId::DimY),
+                         GPULinearIdMappingAttr::get(ctx, LinearId::DimZ)};
+    idBuilder = [this](RewriterBase &rewriter, Location loc,
+                       ArrayRef<int64_t> forallMappingSizes) {
+      // Build the linear thread id and decompose it in the basis of
+      // `forallMappingSizes`.
+      Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr);
+      SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
+      SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
+      AffineExpr d0;
+      bindDims(rewriter.getContext(), d0);
+      SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
+      SmallVector<Value> ids;
+      for (AffineExpr e : delinearizingExprs)
+        ids.push_back(makeComposedAffineApply(rewriter, loc, e, linearId));
+
+      // clang-format off
+      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
+                                       DBGS() << "--delinearization basis: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(strides,
+                                       DBGS() << "--delinearization strides: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(delinearizingExprs,
+                                       DBGS() << "--delinearization exprs: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
+                 llvm::dbgs() << "\n";);
+      // clang-format on
+
+      // Compute and return the 1-D actual mapping size spanned by the linearId,
+      // it will be used to predicate against the linearized total number of
+      // threads.
+      int64_t actualMappingSize = 1;
+      for (int64_t s : forallMappingSizes)
+        actualMappingSize *= s;
+
+      // Return 3-D ids for indexing rewrite and 1-D size and id for
+      // predicate generation.
+      return IdBuilderResult{ids, SmallVector<int64_t>{actualMappingSize},
+                             SmallVector<Value>{linearId}};
+    };
+  }
 };
 
 } // namespace
 
 static DiagnosedSilenceableFailure
-failureHelper(std::optional<TransformOpInterface> transformOp,
-              scf::ForallOp forallOp, const Twine &message) {
+definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
+                      Operation *target, const Twine &message) {
   if (transformOp.has_value())
-    return emitDefiniteFailure(*transformOp, message);
-  return emitDefiniteFailure(forallOp, message);
+    return transformOp->emitDefiniteFailure() << message;
+  return emitDefiniteFailure(target, message);
 }
 
 /// Check if given mapping attributes are one of the desired attributes
@@ -104,7 +244,8 @@ static DiagnosedSilenceableFailure
 checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                            scf::ForallOp forallOp) {
   if (!forallOp.getMapping().has_value())
-    return failureHelper(transformOp, forallOp, "mapping must be present");
+    return definiteFailureHelper(transformOp, forallOp,
+                                 "mapping must be present");
 
   bool hasBlockMapping =
       llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
@@ -114,20 +255,32 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
       llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
         return attr.isa<GPUThreadMappingAttr>();
       });
+  bool hasWarpMapping =
+      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
+        return attr.isa<GPUWarpMappingAttr>();
+      });
+  bool hasLinearMapping =
+      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
+        return attr.isa<GPULinearIdMappingAttr>();
+      });
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
   countMappingTypes += hasThreadMapping ? 1 : 0;
+  countMappingTypes += hasWarpMapping ? 1 : 0;
+  countMappingTypes += hasLinearMapping ? 1 : 0;
   if (countMappingTypes > 1) {
-    return failureHelper(transformOp, forallOp,
-                         "cannot mix different mapping types, use nesting");
+    return definiteFailureHelper(
+        transformOp, forallOp,
+        "cannot mix different mapping types, use nesting");
   }
 
   DenseSet<Attribute> seen;
   for (Attribute map : forallOp.getMapping()->getValue()) {
     if (seen.contains(map)) {
-      return failureHelper(transformOp, forallOp,
-                           "duplicated attribute, cannot map different loops "
-                           "to the same processor");
+      return definiteFailureHelper(
+          transformOp, forallOp,
+          "duplicated attribute, cannot map different loops "
+          "to the same processor");
     }
     seen.insert(map);
   }
@@ -146,26 +299,26 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
 
   // Perform other non-types verifications.
   if (!forallOp.isNormalized())
-    return failureHelper(transformOp, forallOp,
-                         "unsupported non-normalized loops");
+    return definiteFailureHelper(transformOp, forallOp,
+                                 "unsupported non-normalized loops");
   if (forallOp.getNumResults() > 0)
-    return failureHelper(transformOp, forallOp,
-                         "only bufferized scf.forall can be mapped");
+    return definiteFailureHelper(transformOp, forallOp,
+                                 "only bufferized scf.forall can be mapped");
   if (forallOp.getRank() > 3)
-    return failureHelper(transformOp, forallOp,
-                         "scf.forall with rank > 3 does not lower");
+    return definiteFailureHelper(transformOp, forallOp,
+                                 "scf.forall with rank > 3 does not lower");
   if (llvm::any_of(forallOp.getMixedUpperBound(), [&](OpFoldResult ofr) {
         return !getConstantIntValue(ofr).has_value();
       })) {
-    return failureHelper(transformOp, forallOp,
-                         "unsupported dynamic sizes in forall op");
+    return definiteFailureHelper(transformOp, forallOp,
+                                 "unsupported dynamic sizes in forall op");
   }
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Determines if the size of the kernel configuration is supported by the GPU
-/// architecture being used. It presently makes use of CUDA limitations, however
-/// that aspect may be enhanced for other GPUs.
+/// Determines if the size of the kernel configuration is supported by the
+/// GPU architecture being used. It presently makes use of CUDA limitations,
+/// however that aspect may be enhanced for other GPUs.
+///
+/// Dimensions passed as std::nullopt are treated as 1. If a dimension exceeds
+/// its per-dimension maximum, a silenceable error listing all grid and block
+/// dimensions is emitted on `transformOp`; otherwise success is returned.
 static DiagnosedSilenceableFailure checkGpuLimits(
     TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
     std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
@@ -192,17 +345,17 @@ static DiagnosedSilenceableFailure checkGpuLimits(
       gridDimZ.value_or(1) > maxGriddimz ||
       gridDimX.value_or(1) > maxGriddimx) {
     return transformOp.emitSilenceableError()
-           << "Trying to launch a GPU kernel with gridDim = ("
+           << "Trying to launch a GPU kernel with grid_dims = ("
            << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
-           << gridDimZ.value_or(1) << ") blockDim = (" << blockDimX.value_or(1)
-           << ", " << blockDimY.value_or(1) << ", " << blockDimZ.value_or(1)
-           << "). It is larger than the limits.";
+           << gridDimZ.value_or(1) << ") block_dims = ("
+           << blockDimX.value_or(1) << ", " << blockDimY.value_or(1) << ", "
+           << blockDimZ.value_or(1) << "). It is larger than the limits.";
   }
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Creates an empty-body gpu::LaunchOp using the provided kernel settings and
-/// put a terminator within.
+/// Creates an empty-body gpu::LaunchOp using the provided kernel settings
+/// and put a terminator within.
 static DiagnosedSilenceableFailure
 createGpuLaunch(RewriterBase &rewriter, Location loc,
                 TransformOpInterface transformOp, LaunchOp &launchOp,
@@ -278,24 +431,36 @@ alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch,
   return DiagnosedSilenceableFailure::success();
 }
 
-//===----------------------------------------------------------------------===//
-// MapForallToBlocks
-//===----------------------------------------------------------------------===//
+/// Struct to return the result of the rewrite of a forall operation.
+struct ForallRewriteResult {
+  /// Mapping sizes of the rewritten forall: its own sizes completed with 1s
+  /// for the builder's mapping attributes it did not use, sorted by mapping
+  /// id.
+  SmallVector<int64_t> mappingSizes;
+  /// Id values produced by the GPU id builder, indexed by mapping id.
+  SmallVector<Value> mappingIds;
+};
 
-static FailureOr<SmallVector<int64_t>> rewriteOneForallCommonImpl(
+/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
+///
+/// Walks `parent` and, for every nested `OpTy` whose dimension has size 1 in
+/// `availableMappingSizes`, replaces all uses of its result with
+/// `replacement` (the callers visible in this file pass an index constant 0).
+/// The `OpTy` ops themselves are not erased.
+///
+/// `availableMappingSizes` is indexed by `idOp.getDimension()`, so it must
+/// have an entry for every dimension value that op can report.
+/// NOTE(review): `loc` is currently unused.
+template <typename OpTy, typename OperationOrBlock>
+static void
+replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
+                            OperationOrBlock *parent, Value replacement,
+                            ArrayRef<int64_t> availableMappingSizes) {
+  parent->walk([&](OpTy idOp) {
+    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
+      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
+  });
+}
+
+/// Common implementation for mapping one scf.forall to GPU ids.
+///
+/// Steps (numbered as in the body):
+///  0. Runs `verifyGpuMapping` and returns its diagnostic on failure.
+///  1-2. Completes the forall's mapping with the builder's remaining mapping
+///     attributes (size 1) and sorts the sizes by mapping id.
+///  3-4. Builds ids with `gpuIdBuilder` right before `forallOp` and uses them
+///     in place of the induction variables.
+///  5. When `availableMappingSizes` is non-empty, guards the body with an
+///     scf.if if some predicate size is below its available size; a
+///     predicate size above its available size is a definite failure.
+///  6-8. Splices the body in place of `forallOp` (inside the scf.if when
+///     predicated) and erases `forallOp`.
+///
+/// On success, `result` holds the sorted mapping sizes and the id values.
+static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    scf::ForallOp forallOp,
-    const SmallVectorImpl<int64_t> &availableMappingSizes,
-    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
-    IdGeneratorFnType idGenerator) {
-  LDBG("Start rewriteOneForallCommonImpl");
+    scf::ForallOp forallOp, ForallRewriteResult &result,
+    ArrayRef<int64_t> availableMappingSizes, const GpuIdBuilder &gpuIdBuilder) {
+  LDBG("--start rewriteOneForallCommonImpl");
 
   // Step 0. GPU-specific verifications. There is no better place to anchor
-  // those right now: the ForallOp is target-independent and the transform op
-  // does not apply to individual ForallOp.
+  // those right now: the ForallOp is target-independent and the transform
+  // op does not apply to individual ForallOp.
   DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
   if (!diag.succeeded())
-    return failure();
+    return diag;
 
   // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
   SmallVector<int64_t> tmpMappingSizes = llvm::to_vector(
@@ -304,97 +469,108 @@ static FailureOr<SmallVector<int64_t>> rewriteOneForallCommonImpl(
         assert(maybeStaticValue && "expected static value");
         return maybeStaticValue.value();
       }));
-  SmallVector<Attribute> forallMappings =
+  SmallVector<Attribute> forallMappingAttrs =
       llvm::to_vector(forallOp.getMapping()->getValue());
-  for (auto attr : allMappingAttributes) {
-    if (llvm::is_contained(forallMappings, attr))
+  for (auto attr : gpuIdBuilder.mappingAttributes) {
+    if (llvm::is_contained(forallMappingAttrs, attr))
       continue;
-    forallMappings.push_back(attr);
+    forallMappingAttrs.push_back(attr);
     tmpMappingSizes.push_back(1);
   }
+  LLVM_DEBUG(
+      llvm::interleaveComma(
+          tmpMappingSizes,
+          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
+      llvm::dbgs() << "\n");
 
   // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
   auto comparator = [&](DeviceMappingAttrInterface a,
                         DeviceMappingAttrInterface b) -> bool {
     return a.getMappingId() < b.getMappingId();
   };
-  SmallVector<int64_t> mappingSizes =
-      getValuesSortedByKey(forallMappings, tmpMappingSizes, comparator);
-  LLVM_DEBUG(llvm::interleaveComma(mappingSizes, DBGS() << "mappingSizes: ");
-             llvm::dbgs() << "\n";
-             llvm::interleaveComma(forallMappings, DBGS() << "mappingAttrs: ");
+  SmallVector<int64_t> forallMappingSizes =
+      getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator);
+  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
+                                   DBGS() << "----forallMappingSizes: ");
+             llvm::dbgs() << "\n"; llvm::interleaveComma(
+                 forallMappingAttrs, DBGS() << "----mappingAttrs: ");
              llvm::dbgs() << "\n");
 
-  // Step 3. Generate the mappingIdOps using the provided generator and map the
-  // induction variables to the newly created ops. Replace ids of dimension
-  // known to be of size 1 by zero to simplify the IR.
-  SmallVector<Value> mappingIdOps;
+  // Step 3. Generate the mappingIdOps using the provided generator.
   Location loc = forallOp.getLoc();
-  idGenerator(rewriter, forallOp, mappingIdOps);
-  LLVM_DEBUG(llvm::interleaveComma(mappingIdOps, DBGS() << "mappingIdOps: ");
-             llvm::dbgs() << "\n");
-  assert(mappingIdOps.size() == mappingSizes.size() && "expect equal sizes");
-  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  if (!availableMappingSizes.empty()) {
-    for (size_t i : llvm::seq(size_t(0), availableMappingSizes.size())) {
-      if (availableMappingSizes[i] == 1)
-        mappingIdOps[i] = zero;
-    }
-  }
+  // Create the ids (and, below, any predicate and scf.if ops) right before
+  // forallOp; the guard restores the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+  IdBuilderResult builderResult =
+      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes);
 
+  // Step 4. Map the induction variables to the mappingIdOps, this may involve a
+  // permutation.
+  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
   IRMapping bvm;
   for (auto [iv, dim] :
        llvm::zip_equal(forallOp.getInductionVars(),
-                       ArrayRef<Attribute>{forallMappings}.take_front(
+                       ArrayRef<Attribute>{forallMappingAttrs}.take_front(
                            forallOp.getInductionVars().size()))) {
     Value peIdOp = mappingIdOps[static_cast<int64_t>(
         dim.cast<DeviceMappingAttrInterface>().getMappingId())];
     bvm.map(iv, peIdOp);
   }
 
-  // Step 4. Maybe create conditionals to predicate the region.
-  // Skip this step when availableMappingSizes is empty.
+  // Step 5. If the availableMappingSizes are already known, create conditionals
+  // to predicate the region. Otherwise, the current forall determines the
+  // availableMappingSizes and no predication occurs.
   Value predicate;
   if (!availableMappingSizes.empty()) {
-    LLVM_DEBUG(llvm::interleaveComma(availableMappingSizes,
-                                     DBGS() << "availableMappingSizes: ");
-               llvm::dbgs() << "\n");
-    for (auto [id, mappingSize, availableMappingSize] :
-         llvm::zip_equal(mappingIdOps, mappingSizes, availableMappingSizes)) {
+    SmallVector<int64_t> predicateMappingSizes =
+        builderResult.predicateMappingSizes;
+    SmallVector<Value> predicateIdOps = builderResult.predicateIdOps;
+    // clang-format off
+    LLVM_DEBUG(
+        llvm::interleaveComma(
+          predicateMappingSizes, DBGS() << "----predicateMappingSizes: ");
+        llvm::dbgs() << "\n"; 
+        llvm::interleaveComma(
+          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
+        llvm::dbgs() << "\n";
+        llvm::interleaveComma(predicateIdOps, DBGS() << "----predicateIdOps: ");
+        llvm::dbgs() << "\n");
+    // clang-format on
+    for (auto [id, mappingSize, availableMappingSize] : llvm::zip_equal(
+             predicateIdOps, predicateMappingSizes, availableMappingSizes)) {
       if (mappingSize > availableMappingSize) {
-        (void)failureHelper(
+        // NOTE(review): ops already created above (the builder's ids and any
+        // predicate ops from earlier iterations) are not removed on this path.
+        return definiteFailureHelper(
             transformOp, forallOp,
             "Trying to map to fewer GPU threads than loop iterations but "
             "overprovisioning is not yet supported. "
             "Try additional tiling of the before mapping or map to more "
             "threads.");
-        return failure();
       }
       if (mappingSize == availableMappingSize)
         continue;
       Value idx = rewriter.create<arith::ConstantIndexOp>(loc, mappingSize);
       Value tmpPredicate = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::ult, id, idx);
-      LDBG("predicate: " << tmpPredicate);
+      LDBG("----predicate: " << tmpPredicate);
       predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                              tmpPredicate)
                             : tmpPredicate;
     }
   }
 
-  // Step 5. Move the body of forallOp.
+  // Step 6. Move the body of forallOp.
   // Erase the terminator first, it will not be used.
   rewriter.eraseOp(forallOp.getTerminator());
   Block *targetBlock;
   Block::iterator insertionPoint;
   if (predicate) {
-    // Step 5.a. If predicated, move at the beginning.
-    auto ifOp =
-        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
+    // Step 6.a. If predicated, move at the beginning.
+    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
+                                           /*withElseRegion=*/false);
     targetBlock = ifOp.thenBlock();
     insertionPoint = ifOp.thenBlock()->begin();
   } else {
-    // Step 5.b. Otherwise, move inline just at the rewriter insertion point.
+    // Step 6.b. Otherwise, move inline just at the rewriter insertion
+    // point.
     targetBlock = forallOp->getBlock();
     insertionPoint = rewriter.getInsertionPoint();
   }
@@ -402,32 +578,59 @@ static FailureOr<SmallVector<int64_t>> rewriteOneForallCommonImpl(
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 6. RAUW thread indices to thread ops.
+  // Step 7. RAUW indices.
   for (Value loopIndex : forallOp.getInductionVars()) {
     Value threadIdx = bvm.lookup(loopIndex);
     rewriter.replaceAllUsesWith(loopIndex, threadIdx);
   }
 
-  // Step 7. Erase old op.
+  // Step 8. Erase old op.
   rewriter.eraseOp(forallOp);
 
-  return mappingSizes;
+  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
+  return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MapForallToBlocks
+//===----------------------------------------------------------------------===//
+
 DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
     RewriterBase &rewriter, TransformOpInterface transformOp,
     scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
-    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
-    IdGeneratorFnType idGenerator) {
-  // Pass an empty anyAvailableMappingSizes.
+    const GpuIdBuilder &gpuIdBuilder) {
+  LDBG("Start mapForallToBlocksImpl");
+
+  Location loc = forallOp.getLoc();
+  Block *parentBlock = forallOp->getBlock();
+  Value zero;
+  {
+    // Create an early zero index value for replacements and immediately reset
+    // the insertion point.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(parentBlock);
+    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  }
+
   SmallVector<int64_t> anyAvailableMappingSizes;
-  FailureOr<SmallVector<int64_t>> maybeMappingSizes =
-      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
-                                 anyAvailableMappingSizes, allMappingAttributes,
-                                 idGenerator);
-  if (failed(maybeMappingSizes))
-    return DiagnosedSilenceableFailure::definiteFailure();
-  gridDims = *maybeMappingSizes;
+  ForallRewriteResult rewriteResult;
+  // Pass an empty anyAvailableMappingSizes.
+  DiagnosedSilenceableFailure diag =
+      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult,
+                                 anyAvailableMappingSizes, gpuIdBuilder);
+
+  // Return if anything goes wrong, use silenceable failure as a match failure.
+  if (!diag.succeeded())
+    return diag;
+
+  // Set the gridDims that act as a return.
+  gridDims = rewriteResult.mappingSizes;
+
+  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
+  // Here, the result of mapping determines the available mapping sizes.
+  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
+                                          gridDims);
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -476,7 +679,7 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
     return diag;
   }
 
-  SmallVector<int64_t> gridDims = extractFromI64ArrayAttr(getGridDim());
+  SmallVector<int64_t> gridDims{getGridDims()};
   if (!getGenerateGpuLaunch() && gridDims.size() != 3)
     return transformOp.emitDefiniteFailure("transform require size-3 mapping");
 
@@ -496,17 +699,14 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
     topLevelForallOp = cast<scf::ForallOp>(newForallOp);
   }
 
-  diag = verifyGpuMapping(transformOp, topLevelForallOp);
-  if (!diag.succeeded())
-    return diag;
-
-  MappingToGpuBlocksHelper helper(getContext());
+  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), {}, {});
   diag = mlir::transform::gpu::mapForallToBlocksImpl(
-      rewriter, transformOp, topLevelForallOp, gridDims,
-      helper.mappingAttributes, helper.idGenerator);
+      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
   if (!diag.succeeded())
     return diag;
 
+  // Set the GPU launch configuration for the grid dims late, this is subject to
+  // IR inspection.
   diag = alterGpuLaunch(rewriter, gpuLaunch,
                         cast<TransformOpInterface>(getOperation()), gridDims[0],
                         gridDims[1], gridDims[2]);
@@ -519,37 +719,133 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
 // MapNestedForallToThreads
 //===----------------------------------------------------------------------===//
 
+/// Maps one `forallOp` using `gpuIdBuilder`, predicating it against
+/// `availableMappingSizes` when needed.
+///
+/// Returns a silenceable failure, before any rewrite, when `forallOp` uses a
+/// mapping attribute that `gpuIdBuilder` does not support; callers use this
+/// to try another builder. A forall without any mapping is a definite
+/// failure. Failures from the common rewrite are propagated as-is. When
+/// `syncAfterDistribute` is set, a BarrierOp is created right after the
+/// position of the rewritten loop.
+DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
+    bool syncAfterDistribute, const GpuIdBuilder &gpuIdBuilder) {
+  // The attribute check below dereferences the mapping, so reject foralls
+  // without one up front rather than reading an empty std::optional.
+  if (!forallOp.getMapping().has_value())
+    return definiteFailureHelper(transformOp, forallOp,
+                                 "mapping must be present");
+
+  // Ignore cases with different attributes than this builder supports.
+  for (Attribute map : forallOp.getMapping()->getValue()) {
+    if (!llvm::is_contained(gpuIdBuilder.mappingAttributes, map)) {
+      LDBG("--skip " << map);
+      LLVM_DEBUG(llvm::interleaveComma(gpuIdBuilder.mappingAttributes,
+                                       DBGS() << "----not in: ");
+                 llvm::dbgs() << "\n";);
+      return emitSilenceableFailure(forallOp);
+    }
+  }
+
+  Location loc = forallOp.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Insert after to allow for syncthreads after `forall` is erased.
+  rewriter.setInsertionPointAfter(forallOp);
+  ForallRewriteResult rewriteResult;
+  DiagnosedSilenceableFailure diag =
+      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult,
+                                 availableMappingSizes, gpuIdBuilder);
+
+  // Return if anything goes wrong, use silenceable failure as a match failure.
+  if (!diag.succeeded())
+    return diag;
+
+  // Add a syncthreads if needed. TODO: warpsync
+  if (syncAfterDistribute)
+    rewriter.create<BarrierOp>(loc);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Maps every scf.forall nested under `target` (as visited by
+/// `Operation::walk`) to GPU ids. Each forall is offered, in order, to a warp
+/// id builder (only when `warpDims` is non-empty), a linear id builder over
+/// the product of `blockDims`, and a thread id builder over `blockDims`. A
+/// builder that does not support the forall's mapping attributes reports a
+/// silenceable failure and the next one is tried; a forall no builder
+/// supports is left in place. Any definite failure stops the walk and is
+/// returned.
+///
+/// `blockDims` must have 3 entries and `warpDims` 0 or 3. On success, thread
+/// ids along dimensions of size 1 in `blockDims` are replaced by a constant 0
+/// created at the rewriter's insertion point on entry.
+/// NOTE(review): that constant is left unused when the walk is interrupted.
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
-    bool syncAfterDistribute,
-    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
-    IdGeneratorFnType idGenerator) {
+    Operation *target, ArrayRef<int64_t> blockDims, ArrayRef<int64_t> warpDims,
+    bool syncAfterDistribute) {
+  LDBG("Start mapNestedForallToThreadsImpl");
+  MLIRContext *ctx = rewriter.getContext();
+  SmallVector<OpFoldResult> blockDimsOfr =
+      getAsIndexOpFoldResult(ctx, blockDims);
+
+  if (blockDims.size() != 3)
+    return definiteFailureHelper(transformOp, target,
+                                 "requires size-3 thread mapping");
+  if (!warpDims.empty()) {
+    if (warpDims.size() != 3)
+      return definiteFailureHelper(transformOp, target,
+                                   "requires empty or size-3 warp mapping");
+  }
+
+  // Create an early zero index value for replacements.
+  Location loc = target->getLoc();
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
-  target->walk([&](scf::ForallOp forallOp) {
-    // Ignore cases with different attributes.
-    for (Attribute map : forallOp.getMapping()->getValue()) {
-      if (!llvm::is_contained(allMappingAttributes, map)) {
+  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
+    //===--------------------------------------------------------------------===//
+    // Mapping to warp ids.
+    //===--------------------------------------------------------------------===//
+    if (!warpDims.empty()) {
+      LLVM_DEBUG(
+          llvm::interleaveComma(
+              warpDims, DBGS() << "+mapNestedForallToThreadsImpl warpDims: ");
+          llvm::dbgs() << "\n");
+      LLVM_DEBUG(llvm::interleaveComma(
+                     blockDimsOfr, DBGS() << "--warpDims with blockDimsOfr:  ");
+                 llvm::dbgs() << "\n");
+      GpuWarpIdBuilder gpuWarpIdBuilder(ctx, blockDimsOfr, warpDims);
+      diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
+          rewriter, transformOp, forallOp, warpDims, syncAfterDistribute,
+          gpuWarpIdBuilder);
+      // Use silenceable failure to encode "failure to match" and pass
+      // through.
+      if (diag.isDefiniteFailure())
+        return WalkResult::interrupt();
+      if (diag.succeeded())
         return WalkResult::skip();
-      }
-    }
-    diag = verifyGpuMapping(transformOp, forallOp);
-    if (diag.succeeded()) {
-      // Take the loc ahead of time
-      Location loc = forallOp.getLoc();
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPointAfter(forallOp);
-      if (failed(rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
-                                            kernelBlockDims,
-                                            allMappingAttributes, idGenerator)))
-        diag = DiagnosedSilenceableFailure::definiteFailure();
-      // Add a syncthreads if needed. TODO: warpsync
-      if (syncAfterDistribute)
-        rewriter.create<BarrierOp>(loc);
     }
-    return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
+
+    //===--------------------------------------------------------------------===//
+    // Mapping to linear ids.
+    //===--------------------------------------------------------------------===//
+    LDBG("+mapNestedForallToThreadsImpl linearDims");
+    LLVM_DEBUG(llvm::interleaveComma(
+                   blockDimsOfr, DBGS() << "--linearDims with blockDimsOfr:  ");
+               llvm::dbgs() << "\n");
+    // The linear builder sees a single dimension holding all threads of the
+    // block.
+    int64_t numThreads = 1;
+    for (int64_t b : blockDims)
+      numThreads *= b;
+    GpuLinearIdBuilder gpuLinearIdBuilder(ctx, blockDimsOfr, numThreads);
+    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
+        rewriter, transformOp, forallOp, numThreads, syncAfterDistribute,
+        gpuLinearIdBuilder);
+    // Use silenceable failure to encode "failure to match" and pass through.
+    if (diag.isDefiniteFailure())
+      return WalkResult::interrupt();
+    if (diag.succeeded())
+      return WalkResult::skip();
+
+    //===--------------------------------------------------------------------===//
+    // Mapping to thread ids (happens last so we can replay ThreadIdOp).
+    //===--------------------------------------------------------------------===//
+    LLVM_DEBUG(
+        llvm::interleaveComma(
+            blockDimsOfr, DBGS() << "mapNestedForallToThreadsImpl blockDims: ");
+        llvm::dbgs() << "\n");
+    GpuThreadIdBuilder gpuThreadIdBuilder(ctx, blockDimsOfr, blockDims);
+    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
+        rewriter, transformOp, forallOp, blockDims, syncAfterDistribute,
+        gpuThreadIdBuilder);
+    // Use silenceable failure to encode "failure to match" and pass through.
+    if (diag.isDefiniteFailure())
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
   });
-  return diag;
+  if (walkResult.wasInterrupted())
+    return diag;
+
+  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
+  // Here, the provided blockDims bound the thread ids.
+  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
+                                          blockDims);
+
+  return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
@@ -561,32 +857,29 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
   if (!gpuLaunch)
     return emitSilenceableError() << "Given target is not a gpu.launch";
 
-  SmallVector<int64_t> blockDims = extractFromI64ArrayAttr(getBlockDim());
-  if (blockDims.size() != 3)
-    return transformOp.emitDefiniteFailure("transform require size-3 mapping");
+  // Mapping to block ids.
+  SmallVector<int64_t> blockDims{getBlockDims()};
 
   DiagnosedSilenceableFailure diag =
       checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                      blockDims[0], blockDims[1], blockDims[2]);
   if (diag.isSilenceableFailure()) {
-    diag.attachNote(getLoc()) << getBlockDimAttrName() << " is too large";
+    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
     return diag;
   }
 
-  MLIRContext *ctx = getContext();
-  IRRewriter rewriter(ctx);
-  MappingToGpuThreadsHelper helper(ctx);
-  diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
-      rewriter, transformOp, target, blockDims, getSyncAfterDistribute(),
-      helper.mappingAttributes, helper.idGenerator);
-
-  if (!diag.succeeded())
-    return diag;
-
+  // Set the GPU launch configuration for the block dims early, this is not
+  // subject to IR inspection.
+  IRRewriter rewriter(getContext());
   diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                         std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                         blockDims[2]);
 
+  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
+  diag =
+      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
+                                   getWarpDims(), getSyncAfterDistribute());
+
   results.push_back(gpuLaunch.getOperation());
   return diag;
 }

diff  --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
index 50f49727d3e68..459b800f76d35 100644
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -8,7 +8,7 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{Given target is not a gpu.launch}}
-  %1 = transform.gpu.map_nested_forall_to_threads %funcop
+  %1 = transform.gpu.map_nested_forall_to_threads %funcop block_dims = [1, 1, 1]
 }
 
 // -----
@@ -47,9 +47,9 @@ func.func @map_nested_forall_to_threads_excessive_threads(%x: memref<2 x 32 x f3
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
+  // block_dims above the GPU limits must be rejected with a note naming the attribute.
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{Trying to launch a GPU kernel with gridDim = (1, 1, 1) blockDim = (1200, 9, 1). It is larger than the limits.}}
-  // expected-note @below {{"blockDim" is too large}}
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [1200, 9, 1] }
+  // expected-error @below {{Trying to launch a GPU kernel with grid_dims = (1, 1, 1) block_dims = (1200, 9, 1). It is larger than the limits.}}
+  // expected-note @below {{"block_dims" is too large}}
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [1200, 9, 1]
 }
 
 // -----
@@ -90,7 +90,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{Trying to map to fewer GPU threads than loop iterations but overprovisioning is not yet supported. Try additional tiling of the before mapping or map to more threads.}}
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 4, 1]
 }
 
 // -----
@@ -116,7 +116,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{unsupported dynamic sizes}}
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 4, 1]
 }
 
 // -----
@@ -138,7 +138,7 @@ transform.sequence failures(propagate) {
   %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{only bufferized scf.forall can be mapped}}
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 4, 1]
 }
 
 // -----
@@ -243,8 +243,8 @@ func.func @map_forall_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
+  // A forall whose sizes exceed the grid limits must be rejected when generating the gpu.launch.
   %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{Trying to launch a GPU kernel with gridDim = (65535, 65535, 1) blockDim = (1, 1, 1). It is larger than the limits.}}
-  %1 = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch }
+  // expected-error @below {{Trying to launch a GPU kernel with grid_dims = (65535, 65535, 1) block_dims = (1, 1, 1). It is larger than the limits.}}
+  %1 = transform.gpu.map_forall_to_blocks %funcop generate_gpu_launch
 }
 
 // -----
@@ -271,7 +271,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{duplicated attribute, cannot map 
diff erent loops to the same processor}}
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32, 32, 1]}
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 32, 1]
 }
 
 // -----

diff  --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index 447ff1597657d..fcf56c8024bfa 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -33,7 +33,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
+  // Map the foralls in the existing gpu.launch to blocks with an explicit grid.
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_forall_to_blocks %funcop { gridDim = [12, 9, 1]}
+  transform.gpu.map_forall_to_blocks %funcop grid_dims = [12, 9, 1]
 }
 
 // -----
@@ -87,7 +87,7 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
+  // Map the nested foralls of the gpu.launch to threads of a 12x9x1 block.
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [12, 9, 1]
 }
 
 // -----
@@ -127,7 +127,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   %gpuLaunch = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch }
-  transform.gpu.map_nested_forall_to_threads %gpuLaunch { blockDim = [32, 4, 1] }
+  transform.gpu.map_nested_forall_to_threads %gpuLaunch block_dims = [32, 4, 1]
 }
 
 // -----
@@ -160,7 +160,7 @@ func.func @saxpy2d_no_barrier(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [12, 9, 1] sync_after_distribute = false
 }
 
 // -----
@@ -192,7 +192,7 @@ func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token)
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32, 1, 1]}
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 1, 1]
 }
 
 // -----
@@ -228,7 +228,7 @@ func.func @saxpy3d_fold_id_z(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %s
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [12, 9, 1] sync_after_distribute = false
 }
 
 // -----
@@ -236,29 +236,64 @@ transform.sequence failures(propagate) {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) floordiv 32) floordiv 4)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 12) floordiv 32) mod 4) floordiv 2)>
+
+// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 12)>
+// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 12) floordiv 20)>
+// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) mod 20) floordiv 10)>
+
 // CHECK-LABEL: func.func @map_multi_level(
 func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
   %one = arith.constant 1 : index
-  %c12 = arith.constant 12 : index
+  %c10 = arith.constant 10 : index
   %c9 = arith.constant 9 : index
   %c7 = arith.constant 7 : index
-// check that the thread level got distributed but not the warp level.
-//  CHECK-NOT:  {mapping = #gpu.thread
-//      CHECK:  {mapping = [#gpu.warp<x>]}
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C11:.*]] = arith.constant 11 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+
+  // check that both the thread level and the warp level got distributed.
+  //  CHECK-NOT: #gpu.thread
+  //  CHECK-NOT: #gpu.warp
   %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
             threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
   {
+    // CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id  x
+    // CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id  y
     scf.forall (%i, %j) in (%c7, %c9) {
-        %4 = memref.load %x[%i, %j] : !type
-        %5 = memref.load %y[%i, %j] : !type
-        %6 = math.fma %alpha, %4, %5 : f32
-        memref.store %6, %y[%i, %j] : !type
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
-     scf.forall (%i) in (%c12) {
+      %4 = memref.load %x[%i, %j] : !type
+      %5 = memref.load %y[%i, %j] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j] : !type
+    }  { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+
+    // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[CMPX:.*]] = arith.cmpi ult, %[[WIDX]], %[[C1]] : index
+    // CHECK-DAG: %[[CMPY:.*]] = arith.cmpi ult, %[[WIDY]], %[[C1]] : index
+    //     CHECK: %[[COND:.*]] = arith.andi %[[CMPY]], %[[CMPX]] : i1
+    //     CHECK: scf.if %[[COND]]
+    scf.forall (%i) in (%c1) {
         %7 = memref.load %t[%i] : !type1d
         %8 = arith.addf %alpha, %7 : f32
         memref.store %8, %t[%i] : !type1d
      }  {mapping = [#gpu.warp<x>] }
+
+    // CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[LIDZ:.*]] = affine.apply #[[$MAPLX]](%[[TIDX]], %[[TIDY]])
+    // CHECK-DAG: %[[COND:.*]] = arith.cmpi ult, %[[LIN]], %[[C20]] : index
+    //     CHECK: scf.if %[[COND]]
+    scf.forall (%i, %j) in (%c10, %c2) {
+        %7 = memref.load %t[%i] : !type1d
+        %8 = arith.addf %alpha, %7 : f32
+        memref.store %8, %t[%j] : !type1d
+     }  {mapping = [#gpu.linear<x>, #gpu.linear<y>] }
     gpu.terminator
   }
   return %y : !type
@@ -267,5 +302,6 @@ func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %str
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop
+    block_dims = [12, 11, 1] warp_dims = [2, 2, 1]
 }


        


More information about the Mlir-commits mailing list