[Mlir-commits] [mlir] 44e6318 - [mlir][transforms] Revamp the implementation of mapping loops to GPUs

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jul 25 15:09:13 PDT 2023


Author: Nicolas Vasilache
Date: 2023-07-26T00:09:08+02:00
New Revision: 44e6318ceacdc00d4f9b0fbb2814d6dc03e27f7d

URL: https://github.com/llvm/llvm-project/commit/44e6318ceacdc00d4f9b0fbb2814d6dc03e27f7d
DIFF: https://github.com/llvm/llvm-project/commit/44e6318ceacdc00d4f9b0fbb2814d6dc03e27f7d.diff

LOG: [mlir][transforms] Revamp the implementation of mapping loops to GPUs

This revision significantly simplifies the specification and implementation of mapping loops to GPU ids.

Each type of mapping (block, warpgroup, warp, thread) now comes with 2 mapping modes:
  1. a 3-D "grid-like" mode, subject to alignment considerations on threadIdx.x, in which predication
     may occur on a per-dimension 3-D sub-rectangle basis.
  2. an n-D linearized mode, in which predication may only occur on a linear basis.

In the process, better inference of size and alignment requirements is introduced, along with improved runtime verification messages.

The `warp_dims` attribute was deemed confusing and is removed from the transform in favor of better size inference.

Differential Revision: https://reviews.llvm.org/D155941

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
    mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
    mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
    mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
    mlir/test/Dialect/GPU/transform-gpu-failing.mlir
    mlir/test/Dialect/GPU/transform-gpu.mlir
    mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
index 699390c2f2959a..6e0f6f1d78eda7 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
@@ -20,107 +20,214 @@ include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 def DimX : I64EnumAttrCase<"DimX", 0, "x">;
 def DimY : I64EnumAttrCase<"DimY", 1, "y">;
 def DimZ : I64EnumAttrCase<"DimZ", 2, "z">;
-
-def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [
-    DimX, DimY, DimZ]> {
+def LinearDim0 : I64EnumAttrCase<"LinearDim0", 3, "linear_dim_0">;
+def LinearDim1 : I64EnumAttrCase<"LinearDim1", 4, "linear_dim_1">;
+def LinearDim2 : I64EnumAttrCase<"LinearDim2", 5, "linear_dim_2">;
+def LinearDim3 : I64EnumAttrCase<"LinearDim3", 6, "linear_dim_3">;
+def LinearDim4 : I64EnumAttrCase<"LinearDim4", 7, "linear_dim_4">;
+def LinearDim5 : I64EnumAttrCase<"LinearDim5", 8, "linear_dim_5">;
+def LinearDim6 : I64EnumAttrCase<"LinearDim6", 9, "linear_dim_6">;
+def LinearDim7 : I64EnumAttrCase<"LinearDim7", 10, "linear_dim_7">;
+def LinearDim8 : I64EnumAttrCase<"LinearDim8", 11, "linear_dim_8">;
+def LinearDim9 : I64EnumAttrCase<"LinearDim9", 12, "linear_dim_9">;
+
+// TODO: This would be better represented with separate Grid and Linear Mapping
+// ids. Unfortunately it is not yet possible to have an optional EnumParameter
+// so we currently embed the 2 modes in the same enum.
+def MappingIdEnum : I64EnumAttr<"MappingId", "Mapping ids for loop mapping", [
+    DimX, DimY, DimZ,
+    LinearDim0, LinearDim1, LinearDim2, LinearDim3, LinearDim4, 
+    LinearDim5, LinearDim6, LinearDim7, LinearDim8, LinearDim9]> {
   let cppNamespace = "::mlir::gpu";
 }
 
-def GPUThreadMappingAttr
-    : GPU_Attr<"GPUThreadMapping", "thread", [
-      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
+def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
+  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
-    EnumParameter<ThreadsEnum>:$thread
+    EnumParameter<MappingIdEnum>:$block
   );
   let assemblyFormat = "`<` params `>`";
   let description = [{
-    An attribute that allows defining thread parallelism for GPU devices.
+    An attribute that allows defining thread block parallelism for GPU devices.
 
-    Thread (aka work item) are grouped into a thread blocks where block may be
-    described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
-    that thread parallelism is desired. It can be consumed by lowering to
-    generate GPU.
-  }];
-}
+    Thread blocks (aka workgroup) are grouped into a grid described by a 
+    3-dimensional rectangle.
+    This attribute indicates that thread block parallelism is desired.
+    It can be consumed by lowering to generate GPU code.
+    2 modes are supported: (1) 3D mapping mode and (2) linear mapping mode.
 
-def WarpsEnum : I64EnumAttr<"Warps", "threads for loop mapping", [
-    DimX, DimY, DimZ]> {
-  let cppNamespace = "::mlir::gpu";
+    #### 3D mapping mode
+
+    The 3D block id is simply the 3D index of the block `(bidx, bidy, bidz)`. 
+    If required, predication occurs on a per-dimension basis. This allows 
+    specifying predication on a 3D sub-rectangle of the grid.
+
+    #### Linear mapping mode
+
+    The linear block id is obtained by linearizing the index of the block. 
+    If required, predication occurs on the linear id. This allows specifying
+    predication on a 1D subset of the (linearized) grid.
+
+    For instance, if the basis is denoted as (GX, GY, GZ) and the block id is
+    denoted by (bx, by, bz), the linear block id is:
+      `linear_id = bx + by * GX + bz * GX * GY`.
+    The linear block id is fixed for the duration of a GPU kernel.
+    
+    This linear id mapping attribute indicates a different linearization relation
+    is applied locally to a loop nest. 
+    
+    For instance, if the new basis is denoted as (LBD0, LBD1, LBD2, LBD3) the 
+    block id in the new basis is:
+      ```(linear_id mod LBD0 , 
+          (linear_id / LBD0) mod LBD1, 
+          (linear_id / (LBD0 * LBD1)) mod LBD2, 
+          (linear_id / (LBD0 * LBD1 * LBD2)) mod LBD3)```.
+    This reinterpretation is only fixed for the duration of a loop nest.
+  }];
 }
 
-def GPUWarpMappingAttr : GPU_Attr<"GPUWarpMapping", "warp", [
-  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
+def GPUWarpgroupMappingAttr
+    : GPU_Attr<"GPUWarpgroupMapping", "warpgroup", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
   let parameters = (ins
-    EnumParameter<WarpsEnum>:$warp
+    EnumParameter<MappingIdEnum>:$warpgroup
   );
   let assemblyFormat = "`<` params `>`";
   let description = [{
-    An attribute that allows defining thread block parallelism for GPU devices.
+    An attribute that allows defining warpgroup parallelism for GPU devices.
 
-    Warp (aka subgroup) are grouped into a grid where grid may be
-    described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
-    that thread block parallelism is desired. It can be consumed by lowering to
-    generate GPU code.
-  }];
-}
+    Threads of proper granularity (e.g. multiple of 
+    "kNumWarpsPerGroup * kWarpSize" on CUDA devices) can be grouped into
+    warpgroups described by a 3-dimensional rectangle. 
+    This attribute indicates that warpgroup parallelism is desired. 
+    It can be consumed by lowering to generate GPU code.
+    2 modes are supported: (1) 3D mapping mode and (2) linear mapping mode.
 
-def LinearIdEnum : I64EnumAttr<"LinearId", "linear ids for loop mapping", [
-    DimX, DimY, DimZ]> {
-  let cppNamespace = "::mlir::gpu";
+    #### 3D mapping mode
+
+    The 3D warpgroup id is simply the adjusted 3D index of the thread 
+    `(tidx / (kNumWarpsPerGroup * kWarpSize), tidy, tidz)`.
+    If required, predication occurs on a per-dimension basis. This allows 
+    specifying predication on a 3D sub-rectangle of the warpgroups.
+
+    #### Linear mapping mode
+
+    The linear warpgroup id is obtained by linearizing the index of the warpgroup.
+    If required, predication occurs on the linear id. This allows specifying
+    predication on a 1D "kNumWarpsPerGroup * kWarpSize"-aligned subset of the 
+    (linearized) block.
+
+    For instance, if the basis is denoted as (BX, BY, BZ) and the thread id is
+    denoted by (tx, ty, tz), the linear warpgroup id is:
+      ```linear_id = (tx + ty * BX + tz * BX * BY) 
+                 / (kNumWarpsPerGroup * kWarpSize)```.
+    The linear warpgroup id is fixed for the duration of a GPU kernel.
+    
+    This linear id mapping attribute indicates a different linearization relation
+    is applied locally to a loop nest. 
+    
+    For instance, if the new basis is denoted as (LWGD0, LWGD1, LWGD2, LWGD3) the 
+    warpgroup id in the new basis is:
+      ```(linear_id mod LWGD0 , 
+          (linear_id / LWGD0) mod LWGD1, 
+          (linear_id / (LWGD0 * LWGD1)) mod LWGD2, 
+          (linear_id / (LWGD0 * LWGD1 * LWGD2)) mod LWGD3)```.
+    This reinterpretation is only fixed for the duration of a loop nest.
+  }];
 }
 
-def GPULinearIdMapping : GPU_Attr<"GPULinearIdMapping", "linear", [
-  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
+def GPUWarpMappingAttr
+    : GPU_Attr<"GPUWarpMapping", "warp", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
   let parameters = (ins
-    EnumParameter<LinearIdEnum>:$linear_id
+    EnumParameter<MappingIdEnum>:$warp
   );
   let assemblyFormat = "`<` params `>`";
   let description = [{
-    An attribute to allow re-interpreting the linear mapping for threads in GPU
-    devices.
+    An attribute that allows defining warp parallelism for GPU devices.
 
-    Threads (aka work item) are grouped into a thread block where block may be
-    described by a 1-, 2-, or 3-dimensional rectangular basis.
-    The linear thread id is obtained by linearizing the 1-, 2- or 3-dimensional
-    index. For instance, if the basis is denoted as (BX, BY, BZ) and the thread
-    id is denoted by (tx, ty, tz), the linear thread id is:
-      `linear_id = tx + ty * BX + tz * BX * BY)`.
-    The linear thread id is fixed for the duration of a GPU kernel.
+    Threads of proper granularity (e.g. multiple of "warp size" on CUDA devices) 
+    can be grouped into warps described by a 3-dimensional rectangle. 
+    This attribute indicates that warp parallelism is desired.
+    It can be consumed by lowering to generate GPU code.
+    2 modes are supported: (1) 3D mapping mode and (2) linear mapping mode.
+
+    #### 3D mapping mode
+
+    The 3D warp id is simply the adjusted 3D index of the thread 
+    `(tidx / kWarpSize, tidy, tidz)`.
+    If required, predication occurs on a per-dimension basis. This allows 
+    specifying predication on a 3D sub-rectangle of the warps.
+
+    #### Linear mapping mode
+
+    The linear warp id is obtained by linearizing the index of the warp.
+    If required, predication occurs on the linear id. This allows specifying
+    predication on a 1D "kWarpSize"-aligned subset of the (linearized) block.
+
+    For instance, if the basis is denoted as (BX, BY, BZ) and the thread id is
+    denoted by (tx, ty, tz), the linear warp id is:
+      `linear_id = (tx + ty * BX + tz * BX * BY) / kWarpSize`.
+    The linear warp id is fixed for the duration of a GPU kernel.
     
     This linear id mapping attribute indicates a different linearization relation
     is applied locally to a loop nest. 
     
-    For instance, if the new basis is denoted as (LBX, LBY, LBZ) the thread id
-    in the new basis is:
-      `(linear_id mod LBX , (linear_id / LBX) mod * LBY, linear_id / (LBX * LBY))`.
-    This reinterpretation is only fixe for the duration of a loop nest.
-    
-    It can be consumed by lowering to generate GPU code.
+    For instance, if the new basis is denoted as (LWD0, LWD1, LWD2, LWD3) the 
+    warp id in the new basis is:
+      ```(linear_id mod LWD0 , 
+          (linear_id / LWD0) mod LWD1, 
+          (linear_id / (LWD0 * LWD1)) mod LWD2, 
+          (linear_id / (LWD0 * LWD1 * LWD2)) mod LWD3)```.
+    This reinterpretation is only fixed for the duration of a loop nest.
   }];
 }
 
-def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
-    DimX, DimY, DimZ]> {
-  let cppNamespace = "::mlir::gpu";
-}
-
-def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
-  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
+def GPUThreadMappingAttr
+    : GPU_Attr<"GPUThreadMapping", "thread", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
   let parameters = (ins
-    EnumParameter<BlocksEnum>:$block
+    EnumParameter<MappingIdEnum>:$thread
   );
   let assemblyFormat = "`<` params `>`";
   let description = [{
-    An attribute that allows defining thread block parallelism for GPU devices.
+    An attribute that allows defining thread parallelism for GPU devices.
+
+    Threads (aka work items) are grouped into thread blocks described by a 
+    3-dimensional rectangle.
+    This attribute indicates that thread parallelism is desired.
+    It can be consumed by lowering to generate GPU code.
+
+    #### 3D mapping mode
+
+    The 3D thread id is simply the 3D index of the thread `(tidx, tidy, tidz)`. 
+    If required, predication occurs on a per-dimension basis. This allows 
+    specifying predication on a 3D sub-rectangle of the block.
+
+    #### Linear mapping mode
 
-    Thread blocks (aka work-group) are grouped into a grid where grid may be
-    described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
-    that thread block parallelism is desired. It can be consumed by lowering to
-    generate GPU code.
+    The linear thread id is obtained by linearizing the index of the thread. 
+    If required, predication occurs on the linear id. This allows specifying
+    predication on a 1D subset of the (linearized) block.
+
+    For instance, if the basis is denoted as (BX, BY, BZ) and the thread id is
+    denoted by (tx, ty, tz), the linear thread id is:
+      ```linear_id = (tx + ty * BX + tz * BX * BY)```.
+    The linear thread id is fixed for the duration of a GPU kernel.
+    
+    This linear id mapping attribute indicates a different linearization relation
+    is applied locally to a loop nest. 
+    
+    For instance, if the new basis is denoted as (LTD0, LTD1, LTD2, LTD3) the 
+    thread id in the new basis is:
+      ```(linear_id mod LTD0 , 
+          (linear_id / LTD0) mod LTD1, 
+          (linear_id / (LTD0 * LTD1)) mod LTD2, 
+          (linear_id / (LTD0 * LTD1 * LTD2)) mod LTD3)```.
+    This reinterpretation is only fixed for the duration of a loop nest.
   }];
 }
 
-
 def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
   DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
@@ -138,5 +245,4 @@ def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space"
   }];
 }
 
-
 #endif // GPU_DEVICE_MAPPING_ATTR

diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index a1cfa406c60ceb..d6612c7c0b7ff4 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -33,12 +33,12 @@ namespace transform {
 namespace gpu {
 struct GpuIdBuilder;
 
-/// Map the top level `scf.forall` op to GPU Thread Blocks.
+/// Map the top level `scf.forall` op to GPU blocks.
 /// Mapping is one-to-one and the induction variables of `scf.forall` are
 /// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
 ///
 /// Dynamic, `scf.forall` trip counts are currently not supported.
-/// Dynamic block dim sizes are currently not supported.
+/// Dynamic `gridDims` are currently not supported.
 DiagnosedSilenceableFailure
 mapForallToBlocksImpl(RewriterBase &rewriter, TransformOpInterface transformOp,
                       scf::ForallOp forallOp,
@@ -46,36 +46,36 @@ mapForallToBlocksImpl(RewriterBase &rewriter, TransformOpInterface transformOp,
                       const GpuIdBuilder &gpuIdBuilder);
 
 /// Search `scf.forall` ops nested under `target` and map each such op to an
-/// explicit GPU implementation along `availableMappingSizes`.
+/// explicit GPU implementation along `blockDims`.
 /// The mapping is one-to-one and the induction variables of `scf.forall` are
 /// rewritten to gpuIdBuilder.idBuilder according to the
 /// gpuIdBuilder.mappingAttributes attribute.
 ///
 /// Dynamic, `scf.forall` trip counts are currently not supported.
-/// Dynamic `availableMappingSizes` sizes are currently not supported.
-/// `availableMappingSizes` is expected to be of size 3.
-DiagnosedSilenceableFailure mapOneForallToThreadsImpl(
-    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
-    bool syncAfterDistribute, const GpuIdBuilder &gpuIdBuilder);
+/// Dynamic `blockDims` sizes are currently not supported.
+/// `blockDims` is expected to be of size 3.
+DiagnosedSilenceableFailure
+mapOneForallToThreadsImpl(RewriterBase &rewriter,
+                          std::optional<TransformOpInterface> transformOp,
+                          scf::ForallOp forallOp, ArrayRef<int64_t> blockDims,
+                          int64_t warpSize, bool syncAfterDistribute);
 
 /// Search `scf.forall` ops nested under `target` and map each such op to an
-/// explicit GPU implementation along blockDims and warpDims.
+/// explicit GPU implementation along `blockDims`.
 /// The mapping is one-to-one and the induction variables of `scf.forall` are
-/// rewritten to threads and warps ids according to the mapping attribute.
+/// rewritten to appropriate ids according to the mapping attribute.
 ///
 /// Dynamic, `scf.forall` trip counts are currently not supported.
-/// Dynamic `blockDims` or `warpDims` or `linearDims` sizes are currently not
-/// supported.
-/// `blockDims` is expected to be of size 3.
-/// `warpDims` is expected to be empty or of size 3.
+/// Dynamic `blockDims` or `newBasis` entries are currently not
+/// supported. `blockDims` is expected to be of size 3.
 ///
 /// The insertion point of the `rewriter` is expected to be set at the
 /// beginning of the `target` body block and dominate all other blocks.
-DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
-    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    Operation *target, ArrayRef<int64_t> blockDimsOfr,
-    ArrayRef<int64_t> warpDims, bool syncAfterDistribute);
+DiagnosedSilenceableFailure
+mapNestedForallToThreadsImpl(RewriterBase &rewriter,
+                             std::optional<TransformOpInterface> transformOp,
+                             Operation *target, ArrayRef<int64_t> blockDims,
+                             int64_t warpSize, bool syncAfterDistribute);
 
 } // namespace gpu
 } // namespace transform

diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index c7ffbafeefd023..7f25bf7d980c12 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -167,15 +167,15 @@ def MapNestedForallToThreads :
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$block_dims,
-                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$warp_dims,
-                   DefaultValuedAttr<BoolAttr, "true">:$sync_after_distribute);
+                   DefaultValuedAttr<BoolAttr, "true">:$sync_after_distribute,
+                   DefaultValuedAttr<I64Attr, "32">:$warp_size);
   let results = (outs TransformHandleTypeInterface:$result);
 
   let assemblyFormat = [{
     $target
     `block_dims` `=` $block_dims
-    (`warp_dims` `=` $warp_dims^)?
     (`sync_after_distribute` `=` $sync_after_distribute^)?
+    (`warp_size` `=` $warp_size^)?
     attr-dict
     `:` functional-type($target, $result)
   }];

diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index ac10f5c5008eff..6ec5fc53d81eca 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -27,17 +27,26 @@ class ForallOp;
 namespace transform {
 namespace gpu {
 
-/// Helper type for functions that generate ids for the mapping of a
-/// scf.forall.
+/// Helper type for functions that generate ids for the mapping of a scf.forall.
+/// Operates on both 1) an "original" basis that represents the individual
+/// thread and block ids and 2) a "scaled" basis that represents grouped ids
+/// (e.g. block clusters, warpgroups and warps).
+/// The mapping of ids is done in the "scaled" basis (i.e. when mapping to warps
+/// a division by 32 occurs).
+/// The predication is in the "original" basis using the "active" quantities
+/// (`activeMappingSizes`, `availableMappingSizes` and `activeIdOps`).
 struct IdBuilderResult {
   // Ops used to replace the forall induction variables.
   SmallVector<Value> mappingIdOps;
+  // Available mapping sizes used to predicate the forall body when they are
+  // larger than the active mapping sizes.
+  SmallVector<int64_t> availableMappingSizes;
   // Actual mapping sizes used to predicate the forall body when they are
   // smaller than the available mapping sizes.
-  SmallVector<int64_t> predicateMappingSizes;
-  // Ops used to predicate the forall body when predicateMappingSizes is smaller
+  SmallVector<int64_t> activeMappingSizes;
+  // Ops used to predicate the forall body when activeMappingSizes is smaller
   // than the available mapping sizes.
-  SmallVector<Value> predicateIdOps;
+  SmallVector<Value> activeIdOps;
 };
 
 /// Common gpu id builder type, allows the configuration of lowering for various
@@ -46,24 +55,18 @@ struct IdBuilderResult {
 ///   - The loc of the forall op to rewrite.
 ///   - A list of positive integers carrying the mapping sizes for the current
 ///     forall op to rewrite.
-using GpuIdBuilderFnType =
-    std::function<IdBuilderResult(RewriterBase &, Location, ArrayRef<int64_t>)>;
+using GpuIdBuilderFnType = std::function<IdBuilderResult(
+    RewriterBase &, Location, ArrayRef<int64_t>, ArrayRef<int64_t>)>;
 
 /// Helper struct for configuring the rewrite of mapped scf.forall ops to
 /// various gpu id configurations.
 struct GpuIdBuilder {
-  GpuIdBuilder(ArrayRef<OpFoldResult> blockDims, ArrayRef<int64_t> mappingSizes)
-      : blockDimsOfr(blockDims), availableMappingSizes(mappingSizes),
-        mappingAttributes(), idBuilder() {}
+  using MappingIdBuilderFnType = std::function<DeviceMappingAttrInterface(
+      MLIRContext *, mlir::gpu::MappingId)>;
 
-  /// List of OpFoldResult carrying the  multi-dimensional number of
-  /// threads available in the current kernel (i.e. the current blockDims in
-  /// CUDA parlance).
-  ArrayRef<OpFoldResult> blockDimsOfr;
-
-  /// A list of positive integers carrying the number of available mapping
-  /// resources that can trigger predication,
-  ArrayRef<int64_t> availableMappingSizes;
+  GpuIdBuilder() = default;
+  GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
+               MappingIdBuilderFnType builder);
 
   /// The mapping attributes targeted by this generator.
   SmallVector<DeviceMappingAttrInterface> mappingAttributes;
@@ -72,43 +75,46 @@ struct GpuIdBuilder {
   GpuIdBuilderFnType idBuilder;
 };
 
-/// Builder for gpu::BlockIdOps used in mapping scf.forall to blocks.
-/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
-/// as 3-D sizes for predicate generation.
+/// Builder for gpu::BlockIdOps used to map scf.forall to blocks.
+/// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
+/// used for indexing rewrites as well as 3D sizes for predicate generation.
+/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
+/// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuBlockIdBuilder : public GpuIdBuilder {
-  GpuBlockIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
-                    ArrayRef<int64_t> mappingSizes);
+  GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
 };
 
-/// Builder for gpu::ThreadIdOp used in mapping scf.forall to thread ids without
-/// any reindexing.
-/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
-/// as 3-D sizes for predicate generation.
-struct GpuThreadIdBuilder : public GpuIdBuilder {
-  GpuThreadIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
-                     ArrayRef<int64_t> mappingSizes);
+/// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
+/// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
+/// used for indexing rewrites as well as 3D sizes for predicate generation.
+/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
+/// used for indexing rewrites as well as 1D sizes for predicate generation.
+struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
+  GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
+                        bool useLinearMapping = false);
+  int64_t warpSize = 32;
+  /// In the future this may be configured by the transformation.
+  static constexpr int64_t kNumWarpsPerGroup = 4;
 };
 
-/// Builder for warp ids used in mapping scf.forall to warps.
-/// This builder requires a specification of the number of warps along each
-/// dimension to more finely control mapping to warps as well a predication than
-/// by solely analyzing the IR.
-/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
-/// as 3-D sizes for predicate generation.
+/// Builder for warp ids used to map scf.forall to reindexed warps.
+/// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
+/// used for indexing rewrites as well as 3D sizes for predicate generation.
+/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
+/// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuWarpIdBuilder : public GpuIdBuilder {
-  GpuWarpIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
-                   ArrayRef<int64_t> mappingSizes);
-  /// Static specification of the warp size.
-  /// In the future this may be configured by the transformation.
-  static constexpr int64_t kWarpSize = 32;
+  GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
+                   bool useLinearMapping = false);
+  int64_t warpSize = 32;
 };
 
-/// Builder for linear ids used in mapping scf.forall to reindexed threads.
-/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
-/// as 1-D sizes for predicate generation.
-struct GpuLinearIdBuilder : public GpuIdBuilder {
-  GpuLinearIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
-                     ArrayRef<int64_t> mappingSizes);
+/// Builder for thread ids used to map scf.forall to reindexed threads.
+/// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
+/// used for indexing rewrites as well as 3D sizes for predicate generation.
+/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
+/// used for indexing rewrites as well as 1D sizes for predicate generation.
+struct GpuThreadIdBuilder : public GpuIdBuilder {
+  GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
 };
 
 /// Determine if the size of the kernel configuration is supported by the

diff  --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
index 8d07f791d3a8ff..96db2a40cf58e8 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -36,10 +36,26 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
   }];
 
  let methods = [
-    InterfaceMethod<[{
-        Returns mapping as an integer from the attribute.
+    InterfaceMethod<
+      /*desc=*/"Return mapping as an integer from the attribute.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getMappingId",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return true if the attribute specifies a linear mapping",
+      /*retTy=*/"bool",
+      /*methodName=*/"isLinearMapping",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the [0..n) relative index of the attribute depending on its group.
+        This can be used to index into a contiguous array.
       }],
-      "int64_t", "getMappingId", (ins)
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getRelativeIndex", 
+      /*args=*/(ins)
     >
   ];
 }

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index b8f7a26ab6a178..91c43a81feea44 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 
 namespace mlir {
 
@@ -84,7 +85,7 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
 
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
-/// If all ifs are constant integers or IntegerAttrs, return the integers.
+/// If all ofrs are constant integers or IntegerAttrs, return the integers.
 std::optional<SmallVector<int64_t>>
 getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
 

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index f809a9627de26a..c9f378c181e36d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -28,6 +28,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::gpu;
@@ -42,22 +43,70 @@ int64_t GPUBlockMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getBlock());
 }
 
+bool GPUBlockMappingAttr::isLinearMapping() const {
+  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
+}
+
+int64_t GPUBlockMappingAttr::getRelativeIndex() const {
+  return isLinearMapping()
+             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
+             : getMappingId();
+}
+
+int64_t GPUWarpgroupMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getWarpgroup());
+}
+
+bool GPUWarpgroupMappingAttr::isLinearMapping() const {
+  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
+}
+
+int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
+  return isLinearMapping()
+             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
+             : getMappingId();
+}
+
 int64_t GPUWarpMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getWarp());
 }
 
-int64_t GPULinearIdMappingAttr::getMappingId() const {
-  return static_cast<int64_t>(getLinearId());
+bool GPUWarpMappingAttr::isLinearMapping() const {
+  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
+}
+
+int64_t GPUWarpMappingAttr::getRelativeIndex() const {
+  return isLinearMapping()
+             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
+             : getMappingId();
 }
 
 int64_t GPUThreadMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getThread());
 }
 
+bool GPUThreadMappingAttr::isLinearMapping() const {
+  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
+}
+
+int64_t GPUThreadMappingAttr::getRelativeIndex() const {
+  return isLinearMapping()
+             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
+             : getMappingId();
+}
+
 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getAddressSpace());
 }
 
+bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
+  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
+}
+
+int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
+  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
+}
+
 //===----------------------------------------------------------------------===//
 // MMAMatrixType
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index ddbe5d47ff4456..6b3246d116dc4e 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -12,7 +12,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/GPU/TransformOps/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
@@ -34,6 +33,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::gpu;
@@ -770,23 +770,23 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
       llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
         return isa<GPUBlockMappingAttr>(attr);
       });
-  bool hasThreadMapping =
+  bool hasWarpgroupMapping =
       llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUThreadMappingAttr>(attr);
+        return isa<GPUWarpgroupMappingAttr>(attr);
       });
   bool hasWarpMapping =
       llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
         return isa<GPUWarpMappingAttr>(attr);
       });
-  bool hasLinearMapping =
+  bool hasThreadMapping =
       llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPULinearIdMappingAttr>(attr);
+        return isa<GPUThreadMappingAttr>(attr);
       });
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
-  countMappingTypes += hasThreadMapping ? 1 : 0;
+  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
   countMappingTypes += hasWarpMapping ? 1 : 0;
-  countMappingTypes += hasLinearMapping ? 1 : 0;
+  countMappingTypes += hasThreadMapping ? 1 : 0;
   if (countMappingTypes > 1) {
     return definiteFailureHelper(
         transformOp, forallOp,
@@ -798,12 +798,22 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
     if (seen.contains(map)) {
       return definiteFailureHelper(
           transformOp, forallOp,
-          "duplicated attribute, cannot map different loops "
-          "to the same processor");
+          "duplicate attribute, cannot map different loops "
+          "to the same mapping id");
     }
     seen.insert(map);
   }
 
+  auto isLinear = [](Attribute a) {
+    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
+  };
+  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
+      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
+    return definiteFailureHelper(
+        transformOp, forallOp,
+        "cannot mix linear and non-linear mapping modes");
+  }
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -823,14 +833,27 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
   if (forallOp.getNumResults() > 0)
     return definiteFailureHelper(transformOp, forallOp,
                                  "only bufferized scf.forall can be mapped");
-  if (forallOp.getRank() > 3)
+  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
+                              forallOp.getMapping()->getValue().front())
+                              .isLinearMapping();
+  // TODO: This would be more natural with support for Optional<EnumParameter>
+  // in GPUDeviceMappingAttr.
+  int64_t maxNumMappingsSupported =
+      useLinearMapping ? (getMaxEnumValForMappingId() -
+                          static_cast<uint64_t>(MappingId::DimZ))
+                       : 3;
+  if (forallOp.getRank() > maxNumMappingsSupported) {
     return definiteFailureHelper(transformOp, forallOp,
-                                 "scf.forall with rank > 3 does not lower");
-  if (llvm::any_of(forallOp.getMixedUpperBound(), [&](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      })) {
-    return definiteFailureHelper(transformOp, forallOp,
-                                 "unsupported dynamic sizes in forall op");
+                                 "scf.forall with rank > ")
+           << maxNumMappingsSupported
+           << " does not lower for the specified mapping attribute type";
+  }
+  auto numParallelIterations =
+      getConstantIntValues(forallOp.getMixedUpperBound());
+  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
+    return definiteFailureHelper(
+        transformOp, forallOp,
+        "requires statically sized, normalized forall op");
   }
   return DiagnosedSilenceableFailure::success();
 }
@@ -855,8 +878,8 @@ replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
 
 static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    scf::ForallOp forallOp, ForallRewriteResult &result,
-    ArrayRef<int64_t> availableMappingSizes, const GpuIdBuilder &gpuIdBuilder) {
+    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
+    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
   LDBG("--start rewriteOneForallCommonImpl");
 
   // Step 0. GPU-specific verifications. There is no better place to anchor
@@ -867,18 +890,36 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
     return diag;
 
   // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
-  SmallVector<int64_t> tmpMappingSizes = llvm::to_vector(
-      llvm::map_range(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
-        auto maybeStaticValue = getConstantIntValue(ofr);
-        assert(maybeStaticValue && "expected static value");
-        return maybeStaticValue.value();
-      }));
-  SmallVector<Attribute> forallMappingAttrs =
-      llvm::to_vector(forallOp.getMapping()->getValue());
+  auto numParallelIterations =
+      getConstantIntValues(forallOp.getMixedUpperBound());
+  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
+         "requires statically sized, normalized forall op");
+  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
+  SetVector<Attribute> forallMappingAttrs;
+  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
+                            forallOp.getMapping()->getValue().end());
+  auto comparator = [](Attribute a, Attribute b) -> bool {
+    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
+           cast<DeviceMappingAttrInterface>(b).getMappingId();
+  };
+
+  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
+  // mapping all dimensions. In the 3-D mapping case we need to map all
+  // dimensions.
+  DeviceMappingAttrInterface maxMapping =
+      cast<DeviceMappingAttrInterface>(*std::max_element(
+          forallMappingAttrs.begin(), forallMappingAttrs.end(), comparator));
+  DeviceMappingAttrInterface maxLinearMapping;
+  if (maxMapping.isLinearMapping())
+    maxLinearMapping = maxMapping;
   for (auto attr : gpuIdBuilder.mappingAttributes) {
-    if (llvm::is_contained(forallMappingAttrs, attr))
+    // If attr overflows, just skip.
+    if (maxLinearMapping && comparator(maxLinearMapping, attr))
       continue;
-    forallMappingAttrs.push_back(attr);
+    // Try to insert. If element was already present, just continue.
+    if (!forallMappingAttrs.insert(attr))
+      continue;
+    // Otherwise, we have a new insertion without a size -> use size 1.
     tmpMappingSizes.push_back(1);
   }
   LLVM_DEBUG(
@@ -888,60 +929,65 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
       llvm::dbgs() << "\n");
 
   // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
-  auto comparator = [&](Attribute a, Attribute b) -> bool {
-    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
-           cast<DeviceMappingAttrInterface>(b).getMappingId();
-  };
-  SmallVector<int64_t> forallMappingSizes =
-      getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator);
+  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
+      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
   LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                    DBGS() << "----forallMappingSizes: ");
              llvm::dbgs() << "\n"; llvm::interleaveComma(
-                 forallMappingAttrs, DBGS() << "----mappingAttrs: ");
+                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
              llvm::dbgs() << "\n");
 
   // Step 3. Generate the mappingIdOps using the provided generator.
   Location loc = forallOp.getLoc();
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(forallOp);
+  SmallVector<int64_t> originalBasis(availableMappingSizes);
+  bool originalBasisWasProvided = !originalBasis.empty();
+  if (!originalBasisWasProvided) {
+    originalBasis = forallMappingSizes;
+    while (originalBasis.size() < 3)
+      originalBasis.push_back(1);
+  }
+
   IdBuilderResult builderResult =
-      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes);
+      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
 
-  // Step 4. Map the induction variables to the mappingIdOps, this may involve a
-  // permutation.
+  // Step 4. Map the induction variables to the mappingIdOps, this may involve
+  // a permutation.
   SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
   IRMapping bvm;
-  for (auto [iv, dim] :
-       llvm::zip_equal(forallOp.getInductionVars(),
-                       ArrayRef<Attribute>{forallMappingAttrs}.take_front(
-                           forallOp.getInductionVars().size()))) {
-    Value peIdOp = mappingIdOps[static_cast<int64_t>(
-        cast<DeviceMappingAttrInterface>(dim).getMappingId())];
+  for (auto [iv, dim] : llvm::zip_equal(
+           forallOp.getInductionVars(),
+           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
+    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
+    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
     bvm.map(iv, peIdOp);
   }
 
-  // Step 5. If the availableMappingSizes are already known, create conditionals
-  // to predicate the region. Otherwise, the current forall determines the
-  // availableMappingSizes and no predication occurs.
+  // Step 5. If the originalBasis is already known, create conditionals to
+  // predicate the region. Otherwise, the current forall determines the
+  // originalBasis and no predication occurs.
   Value predicate;
-  if (!availableMappingSizes.empty()) {
-    SmallVector<int64_t> predicateMappingSizes =
-        builderResult.predicateMappingSizes;
-    SmallVector<Value> predicateIdOps = builderResult.predicateIdOps;
+  if (originalBasisWasProvided) {
+    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
+    SmallVector<int64_t> availableMappingSizes =
+        builderResult.availableMappingSizes;
+    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
     // clang-format off
     LLVM_DEBUG(
         llvm::interleaveComma(
-          predicateMappingSizes, DBGS() << "----predicateMappingSizes: ");
+          activeMappingSizes, DBGS() << "----activeMappingSizes: ");
         llvm::dbgs() << "\n"; 
         llvm::interleaveComma(
           availableMappingSizes, DBGS() << "----availableMappingSizes: ");
         llvm::dbgs() << "\n";
-        llvm::interleaveComma(predicateIdOps, DBGS() << "----predicateIdOps: ");
+        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
         llvm::dbgs() << "\n");
     // clang-format on
-    for (auto [id, mappingSize, availableMappingSize] : llvm::zip_equal(
-             predicateIdOps, predicateMappingSizes, availableMappingSizes)) {
-      if (mappingSize > availableMappingSize) {
+    for (auto [activeId, activeMappingSize, availableMappingSize] :
+         llvm::zip_equal(activeIdOps, activeMappingSizes,
+                         availableMappingSizes)) {
+      if (activeMappingSize > availableMappingSize) {
         return definiteFailureHelper(
             transformOp, forallOp,
             "Trying to map to fewer GPU threads than loop iterations but "
@@ -949,11 +995,12 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
             "Try additional tiling of the before mapping or map to more "
             "threads.");
       }
-      if (mappingSize == availableMappingSize)
+      if (activeMappingSize == availableMappingSize)
         continue;
-      Value idx = rewriter.create<arith::ConstantIndexOp>(loc, mappingSize);
+      Value idx =
+          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
       Value tmpPredicate = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::ult, id, idx);
+          loc, arith::CmpIPredicate::ult, activeId, idx);
       LDBG("----predicate: " << tmpPredicate);
       predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                              tmpPredicate)
@@ -991,6 +1038,12 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
   // Step 8. Erase old op.
   rewriter.eraseOp(forallOp);
 
+  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
+                                   DBGS() << "----result forallMappingSizes: ");
+             llvm::dbgs() << "\n"; llvm::interleaveComma(
+                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
+             llvm::dbgs() << "\n");
+
   result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
   return DiagnosedSilenceableFailure::success();
 }
@@ -1016,28 +1069,52 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
     zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   }
 
-  SmallVector<int64_t> anyAvailableMappingSizes;
   ForallRewriteResult rewriteResult;
-  // Pass an empty anyAvailableMappingSizes.
-  DiagnosedSilenceableFailure diag =
-      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult,
-                                 anyAvailableMappingSizes, gpuIdBuilder);
+  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
+      rewriter, transformOp, forallOp,
+      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);
 
-  // Return if anything goes wrong, use silenceable failure as a match failure.
+  // Return if anything goes wrong, use silenceable failure as a match
+  // failure.
   if (!diag.succeeded())
     return diag;
 
-  // Set the gridDims that act as a return.
-  gridDims = rewriteResult.mappingSizes;
+  // If gridDims was not provided already, set it from the return.
+  if (gridDims.empty()) {
+    gridDims = rewriteResult.mappingSizes;
+    while (gridDims.size() < 3)
+      gridDims.push_back(1);
+  }
+  assert(gridDims.size() == 3 && "Need 3-D gridDims");
 
   // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
   // Here, the result of mapping determines the available mapping sizes.
   replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
-                                          gridDims);
+                                          rewriteResult.mappingSizes);
 
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure
+mlir::transform::gpu::findTopLevelForallOp(Operation *target,
+                                           scf::ForallOp &topLevelForallOp,
+                                           TransformOpInterface transformOp) {
+  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
+    if (forallOp->getParentOfType<scf::ForallOp>())
+      return WalkResult::advance();
+    if (topLevelForallOp)
+      // TODO: Handle multiple forall if they are independent.
+      return WalkResult::interrupt();
+    topLevelForallOp = forallOp;
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted())
+    return transformOp.emitSilenceableError()
+           << "could not find a unique topLevel scf.forall";
+  return DiagnosedSilenceableFailure::success();
+}
+
 DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     ApplyToEachResultList &results, transform::TransformState &state) {
@@ -1072,23 +1149,28 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
   if (getGenerateGpuLaunch()) {
     DiagnosedSilenceableFailure diag =
         createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
-    if (!diag.succeeded()) {
+    if (!diag.succeeded())
       return diag;
-    }
+
     rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
     Operation *newForallOp = rewriter.clone(*topLevelForallOp);
     rewriter.eraseOp(topLevelForallOp);
     topLevelForallOp = cast<scf::ForallOp>(newForallOp);
   }
 
-  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), {}, {});
+  // The BlockIdBuilder adapts to whatever is thrown at it.
+  auto mappingAttr = cast<DeviceMappingAttrInterface>(
+      topLevelForallOp.getMapping()->getValue().front());
+  bool useLinearMapping = mappingAttr.isLinearMapping();
+  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
+
   diag = mlir::transform::gpu::mapForallToBlocksImpl(
       rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
   if (!diag.succeeded())
     return diag;
 
-  // Set the GPU launch configuration for the grid dims late, this is subject to
-  // IR inspection.
+  // Set the GPU launch configuration for the grid dims late, this is
+  // subject to IR inspection.
   diag = alterGpuLaunch(rewriter, gpuLaunch,
                         cast<TransformOpInterface>(getOperation()), gridDims[0],
                         gridDims[1], gridDims[2]);
@@ -1101,19 +1183,90 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
 // MapNestedForallToThreads
 //===----------------------------------------------------------------------===//
 
+static DiagnosedSilenceableFailure checkMappingSpec(
+    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
+    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
+    int factor, bool useLinearMapping = false) {
+  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
+    auto diag = definiteFailureHelper(
+        transformOp, forallOp,
+        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
+            std::to_string(factor));
+    return diag;
+  }
+  if (computeProduct(numParallelIterations) * factor >
+      computeProduct(blockOrGridSizes)) {
+    auto diag = definiteFailureHelper(
+        transformOp, forallOp,
+        Twine(
+            "the number of required parallel resources (blocks or threads) ") +
+            std::to_string(computeProduct(numParallelIterations) * factor) +
+            std::string(" overflows the number of available resources ") +
+            std::to_string(computeProduct(blockOrGridSizes)));
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+static DiagnosedSilenceableFailure
+getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
+                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
+                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
+  auto mappingAttr = cast<DeviceMappingAttrInterface>(
+      forallOp.getMapping()->getValue().front());
+  bool useLinearMapping = mappingAttr.isLinearMapping();
+
+  // Sanity checks that may result in runtime verification errors.
+  auto numParallelIterations =
+      getConstantIntValues((forallOp.getMixedUpperBound()));
+  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
+    return definiteFailureHelper(
+        transformOp, forallOp,
+        "requires statically sized, normalized forall op");
+  }
+  int64_t factor = 1;
+  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
+    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
+  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
+    factor = warpSize;
+  }
+  DiagnosedSilenceableFailure diag =
+      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
+                       blockSizes, factor, useLinearMapping);
+  if (!diag.succeeded())
+    return diag;
+
+  // Start mapping.
+  MLIRContext *ctx = forallOp.getContext();
+  gpuIdBuilder =
+      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
+          .Case([&](GPUWarpgroupMappingAttr) {
+            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
+          })
+          .Case([&](GPUWarpMappingAttr) {
+            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
+          })
+          .Case([&](GPUThreadMappingAttr) {
+            return GpuThreadIdBuilder(ctx, useLinearMapping);
+          })
+          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
+            llvm_unreachable("unknown mapping attribute");
+          });
+  return DiagnosedSilenceableFailure::success();
+}
+
 DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
-    bool syncAfterDistribute, const GpuIdBuilder &gpuIdBuilder) {
-  // Ignore cases with different attributes than this builder supports.
-  for (Attribute map : forallOp.getMapping()->getValue()) {
-    if (!llvm::is_contained(gpuIdBuilder.mappingAttributes, map)) {
-      LDBG("--skip " << map);
-      LLVM_DEBUG(llvm::interleaveComma(gpuIdBuilder.mappingAttributes,
-                                       DBGS() << "----not in: ");
-                 llvm::dbgs() << "\n";);
-      return emitSilenceableFailure(forallOp);
-    }
+    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
+    bool syncAfterDistribute) {
+
+  GpuIdBuilder gpuIdBuilder;
+  {
+    // Try to construct the id builder, if it fails, return.
+    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
+        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
+    if (!diag.succeeded())
+      return diag;
   }
 
   Location loc = forallOp.getLoc();
@@ -1121,14 +1274,10 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
   // Insert after to allow for syncthreads after `forall` is erased.
   rewriter.setInsertionPointAfter(forallOp);
   ForallRewriteResult rewriteResult;
-  DiagnosedSilenceableFailure diag =
-      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult,
-                                 availableMappingSizes, gpuIdBuilder);
-
-  // Return if anything goes wrong, use silenceable failure as a match failure.
+  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
+      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
   if (!diag.succeeded())
     return diag;
-
   // Add a syncthreads if needed. TODO: warpsync
   if (syncAfterDistribute)
     rewriter.create<BarrierOp>(loc);
@@ -1138,20 +1287,12 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
 
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
-    Operation *target, ArrayRef<int64_t> blockDims, ArrayRef<int64_t> warpDims,
+    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
     bool syncAfterDistribute) {
   LDBG("Start mapNestedForallToThreadsImpl");
-  MLIRContext *ctx = rewriter.getContext();
-  SmallVector<OpFoldResult> blockDimsOfr =
-      getAsIndexOpFoldResult(ctx, blockDims);
-
-  if (blockDims.size() != 3)
+  if (blockDims.size() != 3) {
     return definiteFailureHelper(transformOp, target,
                                  "requires size-3 thread mapping");
-  if (!warpDims.empty()) {
-    if (warpDims.size() != 3)
-      return definiteFailureHelper(transformOp, target,
-                                   "requires empty or size-3 warp mapping");
   }
 
   // Create an early zero index value for replacements.
@@ -1159,64 +1300,13 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
-    //===--------------------------------------------------------------------===//
-    // Mapping to warp ids.
-    //===--------------------------------------------------------------------===//
-    if (!warpDims.empty()) {
-      LLVM_DEBUG(
-          llvm::interleaveComma(
-              warpDims, DBGS() << "+mapNestedForallToThreadsImpl warpDims: ");
-          llvm::dbgs() << "\n");
-      LLVM_DEBUG(llvm::interleaveComma(
-                     blockDimsOfr, DBGS() << "--warpDims with blockDimsOfr:  ");
-                 llvm::dbgs() << "\n");
-      GpuWarpIdBuilder gpuWarpIdBuilder(ctx, blockDimsOfr, warpDims);
-      diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
-          rewriter, transformOp, forallOp, warpDims, syncAfterDistribute,
-          gpuWarpIdBuilder);
-      // Use silenceable failure to encode "failure to match" and pass
-      // through.
-      if (diag.isDefiniteFailure())
-        return WalkResult::interrupt();
-      if (diag.succeeded())
-        return WalkResult::skip();
-    }
-
-    //===--------------------------------------------------------------------===//
-    // Mapping to linear ids.
-    //===--------------------------------------------------------------------===//
-    LDBG("+mapNestedForallToThreadsImpl linearDims");
-    LLVM_DEBUG(llvm::interleaveComma(
-                   blockDimsOfr, DBGS() << "--linearDims with blockDimsOfr:  ");
-               llvm::dbgs() << "\n");
-    int64_t numThreads = 1;
-    for (int64_t b : blockDims)
-      numThreads *= b;
-    GpuLinearIdBuilder gpuLinearIdBuilder(ctx, blockDimsOfr, numThreads);
     diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
-        rewriter, transformOp, forallOp, numThreads, syncAfterDistribute,
-        gpuLinearIdBuilder);
-    // Use silenceable failure to encode "failure to match" and pass through.
+        rewriter, transformOp, forallOp, blockDims, warpSize,
+        syncAfterDistribute);
     if (diag.isDefiniteFailure())
       return WalkResult::interrupt();
     if (diag.succeeded())
       return WalkResult::skip();
-
-    //===--------------------------------------------------------------------===//
-    // Mapping to block ids (happens last so we can replay ThreadIdOp).
-    //===--------------------------------------------------------------------===//
-    LLVM_DEBUG(
-        llvm::interleaveComma(
-            blockDimsOfr, DBGS() << "mapNestedForallToThreadsImpl blockDims: ");
-        llvm::dbgs() << "\n");
-    GpuThreadIdBuilder gpuThreadIdBuilder(ctx, blockDimsOfr, blockDims);
-    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
-        rewriter, transformOp, forallOp, blockDims, syncAfterDistribute,
-        gpuThreadIdBuilder);
-    // Use silenceable failure to encode "failure to match" and pass through.
-    if (diag.isDefiniteFailure())
-      return WalkResult::interrupt();
-
     return WalkResult::advance();
   });
   if (walkResult.wasInterrupted())
@@ -1242,7 +1332,6 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
 
   // Mapping to block ids.
   SmallVector<int64_t> blockDims{getBlockDims()};
-
   DiagnosedSilenceableFailure diag =
       checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                      blockDims[0], blockDims[1], blockDims[2]);
@@ -1260,7 +1349,7 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
   rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
   diag =
       mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
-                                   getWarpDims(), getSyncAfterDistribute());
+                                   getWarpSize(), getSyncAfterDistribute());
 
   results.push_back(gpuLaunch.getOperation());
   return diag;

diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 3ba9bf1a5a14a4..55683aebebfc17 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -46,103 +46,59 @@ using namespace mlir::transform::gpu;
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
 /// Return a flattened thread id for the workgroup with given sizes.
-static Value buildLinearThreadId(RewriterBase &rewriter, Location loc,
-                                 ArrayRef<OpFoldResult> blockDimsOfr) {
+template <typename ThreadOrBlockIdOp>
+static Value buildLinearId(RewriterBase &rewriter, Location loc,
+                           ArrayRef<OpFoldResult> originalBasisOfr) {
   LLVM_DEBUG(llvm::interleaveComma(
-                 blockDimsOfr,
-                 DBGS() << "----buildLinearThreadId with blockDimsOfr:  ");
+                 originalBasisOfr,
+                 DBGS() << "----buildLinearId with originalBasisOfr:  ");
              llvm::dbgs() << "\n");
-  assert(blockDimsOfr.size() == 3 && "expected 3 workgroup sizes");
+  assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
+  IndexType indexType = rewriter.getIndexType();
   AffineExpr tx, ty, tz, BDX, BDY;
   bindDims(rewriter.getContext(), tx, ty, tz);
   bindSymbols(rewriter.getContext(), BDX, BDY);
-  IndexType indexType = rewriter.getIndexType();
-  SmallVector<OpFoldResult> threadsAndWorkGroups{
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x).getResult(),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y).getResult(),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z).getResult()};
-  threadsAndWorkGroups.push_back(blockDimsOfr[0]);
-  threadsAndWorkGroups.push_back(blockDimsOfr[1]);
+  SmallVector<OpFoldResult> vals{
+      rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x)
+          .getResult(),
+      rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y)
+          .getResult(),
+      rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)
+          .getResult(),
+      originalBasisOfr[0], originalBasisOfr[1]};
   OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-      rewriter, loc, tx + ty * BDX + tz * BDX * BDY, threadsAndWorkGroups);
+      rewriter, loc, tx + ty * BDX + tz * BDX * BDY, vals);
   return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
 }
 
-namespace mlir {
-namespace transform {
-namespace gpu {
-
-GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx,
-                                     ArrayRef<OpFoldResult> blockDims,
-                                     ArrayRef<int64_t> mappingSizes)
-    : GpuIdBuilder(blockDims, mappingSizes) {
-  mappingAttributes = {GPUBlockMappingAttr::get(ctx, Blocks::DimX),
-                       GPUBlockMappingAttr::get(ctx, Blocks::DimY),
-                       GPUBlockMappingAttr::get(ctx, Blocks::DimZ)},
-  idBuilder = [](RewriterBase &rewriter, Location loc,
-                 ArrayRef<int64_t> forallMappingSizes) {
-    IndexType indexType = rewriter.getIndexType();
-    SmallVector<Value> ids{
-        rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
-        rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
-        rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
-    // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
-    // predicate generation.
-    return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes}, ids};
-  };
-}
-
-GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx,
-                                       ArrayRef<OpFoldResult> blockDims,
-                                       ArrayRef<int64_t> mappingSizes)
-    : GpuIdBuilder(blockDims, mappingSizes) {
-  mappingAttributes = {GPUThreadMappingAttr::get(ctx, Threads::DimX),
-                       GPUThreadMappingAttr::get(ctx, Threads::DimY),
-                       GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
-  idBuilder = [](RewriterBase &rewriter, Location loc,
-                 ArrayRef<int64_t> forallMappingSizes) {
-    IndexType indexType = rewriter.getIndexType();
-    SmallVector<Value> ids{
-        rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
-        rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
-        rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
-    // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
-    // predicate generation.
-    return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes}, ids};
-  };
-}
-
-GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx,
-                                   ArrayRef<OpFoldResult> blockDims,
-                                   ArrayRef<int64_t> mappingSizes)
-    : GpuIdBuilder(blockDims, mappingSizes) {
-  mappingAttributes = {GPUWarpMappingAttr::get(ctx, Warps::DimX),
-                       GPUWarpMappingAttr::get(ctx, Warps::DimY),
-                       GPUWarpMappingAttr::get(ctx, Warps::DimZ)};
-  idBuilder = [this](RewriterBase &rewriter, Location loc,
-                     ArrayRef<int64_t> forallMappingSizes) {
-    // Build the linear warp id and decompose it in the basis of
-    // `forallMappingSizes`.
-    Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr);
-    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    OpFoldResult warpIdOfr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0.floorDiv(kWarpSize), {linearId});
-    Value warpId = getValueOrCreateConstantIndexOp(rewriter, loc, warpIdOfr);
-    // Sizes in [x, y, z] -> [z, y x] order to properly compute strides in
+/// Create a linear id builder that takes the `originalBasisOfr` and decompose
+/// it in the basis of `forallMappingSizes`. The linear id builder returns an
+/// n-D vector of ids for indexing and 1-D size + id for predicate generation.
+template <typename ThreadOrBlockIdOp>
+static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
+  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
+                            ArrayRef<int64_t> forallMappingSizes,
+                            ArrayRef<int64_t> originalBasis) {
+    SmallVector<OpFoldResult> originalBasisOfr =
+        getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
+    OpFoldResult linearId =
+        buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
+    // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
     // "row-major" order.
-    SmallVector<int64_t> reverseBasisSizes(
-        llvm::reverse(this->availableMappingSizes));
+    SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
     SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0.floorDiv(multiplicity), {linearId});
     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
     SmallVector<Value> ids;
-    // Reverse back to be in [x, y, z] order.
-    for (AffineExpr e : llvm::reverse(delinearizingExprs))
+    // Reverse back to be in [0 .. n] order.
+    for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
       ids.push_back(
-          affine::makeComposedAffineApply(rewriter, loc, e, {warpId}));
+          affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
+    }
 
     // clang-format off
-      LDBG("----linearId: " << linearId);
-          LDBG("----warpId: " << warpId);
       LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
                                        DBGS() << "--delinearization basis: ");
                  llvm::dbgs() << "\n";
@@ -156,63 +112,121 @@ GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx,
                  llvm::dbgs() << "\n";);
     // clang-format on
 
-    // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
-    // predicate generation.
-    return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes}, ids};
+    // Return n-D ids for indexing and 1-D size + id for predicate generation.
+    return IdBuilderResult{
+        /*mappingIdOps=*/ids,
+        /*availableMappingSizes=*/
+        SmallVector<int64_t>{computeProduct(originalBasis)},
+        // `forallMappingSizes` iterate in the scaled basis, they need to be
+        // scaled back into the original basis to provide tight
+        // activeMappingSizes quantities for predication.
+        /*activeMappingSizes=*/
+        SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
+        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
   };
+
+  return res;
 }
 
-GpuLinearIdBuilder::GpuLinearIdBuilder(MLIRContext *ctx,
-                                       ArrayRef<OpFoldResult> blockDims,
-                                       ArrayRef<int64_t> mappingSizes)
-    : GpuIdBuilder(blockDims, mappingSizes) {
-  mappingAttributes = {GPULinearIdMappingAttr::get(ctx, LinearId::DimX),
-                       GPULinearIdMappingAttr::get(ctx, LinearId::DimY),
-                       GPULinearIdMappingAttr::get(ctx, LinearId::DimZ)};
-  idBuilder = [this](RewriterBase &rewriter, Location loc,
-                     ArrayRef<int64_t> forallMappingSizes) {
-    // Build the linear thread id and decompose it in the basis of
-    // `forallMappingSizes`.
-    Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr);
-    // Sizes in [x, y, z] -> [z, y x] order to properly compute strides in
-    // "row-major" order.
-    SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
-    SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
-    AffineExpr d0;
-    bindDims(rewriter.getContext(), d0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
-    SmallVector<Value> ids;
-    // Reverse back to be in [x, y, z] order.
-    for (AffineExpr e : llvm::reverse(delinearizingExprs))
-      ids.push_back(
-          affine::makeComposedAffineApply(rewriter, loc, e, {linearId}));
+/// Create a simple 3-D id builder that takes the `originalBasisOfr`
+/// The 3-D id builder returns a 3-D vector of ids for indexing and 3-D sizes
+/// + ids for predicate generation.
+template <typename ThreadOrBlockIdOp>
+static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
+  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
+                            ArrayRef<int64_t> forallMappingSizes,
+                            ArrayRef<int64_t> originalBasis) {
+    IndexType indexType = rewriter.getIndexType();
+    SmallVector<Value> ids{
+        rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x),
+        rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y),
+        rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)};
+    // In the 3-D mapping case, scale the first dimension by the multiplicity.
+    SmallVector<Value> scaledIds = ids;
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    scaledIds[0] = affine::makeComposedFoldedAffineApply(
+                       rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]})
+                       .get<Value>();
+    // In the 3-D mapping case, unscale the first dimension by the multiplicity.
+    SmallVector<int64_t> forallMappingSizeInOriginalBasis(
+        forallMappingSizes.begin(), forallMappingSizes.end());
+    forallMappingSizeInOriginalBasis[0] *= multiplicity;
+    return IdBuilderResult{
+        /*mappingIdOps=*/scaledIds,
+        /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
+        // `forallMappingSizes` iterate in the scaled basis, they need to be
+        // scaled back into the original basis to provide tight
+        // activeMappingSizes quantities for predication.
+        /*activeMappingSizes=*/
+        SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
+        /*activeIdOps=*/ids};
+  };
+  return res;
+}
 
-    // clang-format off
-      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
-                                       DBGS() << "--delinearization basis: ");
-                 llvm::dbgs() << "\n";
-                 llvm::interleaveComma(strides,
-                                       DBGS() << "--delinearization strides: ");
-                 llvm::dbgs() << "\n";
-                 llvm::interleaveComma(delinearizingExprs,
-                                       DBGS() << "--delinearization exprs: ");
-                 llvm::dbgs() << "\n";
-                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
-                 llvm::dbgs() << "\n";);
-    // clang-format on
+namespace mlir {
+namespace transform {
+namespace gpu {
 
-    // Compute and return the 1-D actual mapping size spanned by the linearId,
-    // it will be used to predicate against the linearized total number of
-    // threads.
-    int64_t actualMappingSize = 1;
-    for (int64_t s : forallMappingSizes)
-      actualMappingSize *= s;
+GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
+                           MappingIdBuilderFnType fn)
+    : mappingAttributes(), idBuilder() {
+  if (useLinearMapping) {
+    for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0),
+                  e = getMaxEnumValForMappingId();
+         d <= e; ++d)
+      mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
+  } else {
+    for (uint64_t d = static_cast<uint64_t>(MappingId::DimX),
+                  e = static_cast<uint64_t>(MappingId::DimZ);
+         d <= e; ++d)
+      mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
+  }
+}
 
-    // Return 3-D ids for indexing rewrite and 1-D size and id for
-    // predicate generation.
-    return IdBuilderResult{ids, SmallVector<int64_t>{actualMappingSize},
-                           SmallVector<Value>{linearId}};
-  };
+GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping)
+    : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
+        return GPUBlockMappingAttr::get(ctx, id);
+      }) {
+  idBuilder = useLinearMapping
+                  ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1)
+                  : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1);
+}
+
+GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
+                                             bool useLinearMapping)
+    : GpuIdBuilder(ctx, useLinearMapping,
+                   [](MLIRContext *ctx, MappingId id) {
+                     return GPUWarpgroupMappingAttr::get(ctx, id);
+                   }),
+      warpSize(warpSize) {
+  idBuilder = useLinearMapping
+                  ? commonLinearIdBuilderFn<ThreadIdOp>(
+                        /*multiplicity=*/kNumWarpsPerGroup * warpSize)
+                  : common3DIdBuilderFn<ThreadIdOp>(
+                        /*multiplicity=*/kNumWarpsPerGroup * warpSize);
+}
+
+GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
+                                   bool useLinearMapping)
+    : GpuIdBuilder(ctx, useLinearMapping,
+                   [](MLIRContext *ctx, MappingId id) {
+                     return GPUWarpMappingAttr::get(ctx, id);
+                   }),
+      warpSize(warpSize) {
+  idBuilder =
+      useLinearMapping
+          ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize)
+          : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize);
+}
+
+GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
+    : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
+        return GPUThreadMappingAttr::get(ctx, id);
+      }) {
+  idBuilder = useLinearMapping
+                  ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1)
+                  : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
 }
 
 DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
@@ -322,25 +336,6 @@ DiagnosedSilenceableFailure alterGpuLaunch(
   return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilenceableFailure
-findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
-                     TransformOpInterface transformOp) {
-  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
-    if (forallOp->getParentOfType<scf::ForallOp>())
-      return WalkResult::advance();
-    if (topLevelForallOp)
-      // TODO: Handle multiple forall if they are independent.
-      return WalkResult::interrupt();
-    topLevelForallOp = forallOp;
-    return WalkResult::advance();
-  });
-
-  if (walkResult.wasInterrupted())
-    return transformOp.emitSilenceableError()
-           << "could not find a unique topLevel scf.forall";
-  return DiagnosedSilenceableFailure::success();
-}
-
 } // namespace gpu
 } // namespace transform
 } // namespace mlir

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
index e4fa8d6bc74b74..38abdd122d633d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
@@ -24,14 +24,14 @@ using namespace mlir;
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-static Attribute linearIdX(MLIRContext *ctx) {
-  return gpu::GPULinearIdMappingAttr::get(ctx, gpu::LinearId::DimX);
+static Attribute linearId0(MLIRContext *ctx) {
+  return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0);
 }
-static Attribute linearIdY(MLIRContext *ctx) {
-  return gpu::GPULinearIdMappingAttr::get(ctx, gpu::LinearId::DimY);
+static Attribute linearId1(MLIRContext *ctx) {
+  return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim1);
 }
-static Attribute linearIdZ(MLIRContext *ctx) {
-  return gpu::GPULinearIdMappingAttr::get(ctx, gpu::LinearId::DimZ);
+static Attribute linearId2(MLIRContext *ctx) {
+  return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim2);
 }
 
 transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
@@ -78,8 +78,8 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
         std::tie(size, numThreads) = pair;
         return mlir::ceilDiv(size, numThreads);
       }));
-  SmallVector<Attribute> allThreadMappings{linearIdZ(ctx), linearIdY(ctx),
-                                           linearIdX(ctx)};
+  SmallVector<Attribute> allThreadMappings{linearId2(ctx), linearId1(ctx),
+                                           linearId0(ctx)};
 
   // Set the thread mapping.
   this->threadMapping =

diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
index b087816fdc7c98..564fdd46e91b17 100644
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -89,7 +89,7 @@ func.func @map_nested_forall_to_threads_fewer_threads(%x: memref<2 x 32 x f32>,
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{Trying to map to fewer GPU threads than loop iterations but overprovisioning is not yet supported. Try additional tiling of the before mapping or map to more threads.}}
+  // expected-error @below {{the number of required parallel resources (blocks or threads) 6300 overflows the number of available resources 512}}
   transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 4, 1] : (!transform.any_op) -> !transform.any_op
 }
 
@@ -115,7 +115,7 @@ func.func @map_nested_forall_to_threads_dynamic_trip_count(%x: memref<2 x 32 x f
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{unsupported dynamic sizes}}
+  // expected-error @below {{requires statically sized, normalized forall op}}
   transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 4, 1] : (!transform.any_op) -> !transform.any_op
 }
 
@@ -135,11 +135,11 @@ func.func @map_nested_forall_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: te
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-  %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
+  %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [2, 3, 1] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{only bufferized scf.forall can be mapped}}
-  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 4, 1] : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [96, 4, 1] : (!transform.any_op) -> !transform.any_op
 }
 
 // -----
@@ -250,6 +250,33 @@ transform.sequence failures(propagate) {
 
 // -----
 
+!type = memref<32x32xf32>
+func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token) -> !type {
+  %c32 = arith.constant 32 : index
+  %one = arith.constant 1 : index
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i, %j) in (%c32, %c32) {
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = arith.mulf %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.thread<x>, #gpu.warp<y>] }
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{cannot mix different mapping types, use nesting}}
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 32, 1] : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
 !type = memref<32x32xf32>
 func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token) -> !type {
   %c32 = arith.constant 32 : index
@@ -271,6 +298,33 @@ func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token)
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{duplicated attribute, cannot map different loops to the same processor}}
+  // expected-error @below {{duplicate attribute, cannot map different loops to the same mapping id}}
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 32, 1] : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+!type = memref<32x32xf32>
+func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token) -> !type {
+  %c32 = arith.constant 32 : index
+  %one = arith.constant 1 : index
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i, %j) in (%c32, %c32) {
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = arith.mulf %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.thread<x>, #gpu.thread<linear_dim_0>] }
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{cannot mix linear and non-linear mapping modes}}
   transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 32, 1] : (!transform.any_op) -> !transform.any_op
 }

diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index ba37a41eddc688..de42c266c34f45 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -3,11 +3,11 @@
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-LABEL: func.func @saxpy2dblock(
+// CHECK-LABEL: func.func @blocks_3d(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
 // CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
 // CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
-func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+func.func @blocks_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
   %c9 = arith.constant 9 : index
   %c7 = arith.constant 7 : index
   %one = arith.constant 1 : index
@@ -41,11 +41,112 @@ transform.sequence failures(propagate) {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-LABEL: func.func @saxpy2d(
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 128)> 
+
+// CHECK-LABEL: func.func @warpgroup_3d(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @warpgroup_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %one = arith.constant 1 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index
+  // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
+
+//      CHECK:   gpu.launch
+//      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
+//      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
+//  CHECK-DAG:   %[[WG:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+//  CHECK-DAG:   %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C384]] : index
+//  CHECK-DAG:   %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
+//      CHECK:   %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
+//      CHECK:   scf.if %[[COND]]
+//      CHECK:     memref.load %[[ARGX]][%[[WG]], %[[TIDY]]]
+//      CHECK:     memref.load %[[ARGY]][%[[WG]], %[[TIDY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i, %j) in (%c3, %c1) {
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.warpgroup<x>, #gpu.warpgroup<y>]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [512, 2, 1] : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> 
+
+// CHECK-LABEL: func.func @warp_3d(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @warp_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %one = arith.constant 1 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK-DAG: %[[c64:.*]] = arith.constant 64 : index
+
+//      CHECK:   gpu.launch
+//      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
+//      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
+//  CHECK-DAG:   %[[W:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+//  CHECK-DAG:   %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C32]] : index
+//  CHECK-DAG:   %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C3]] : index
+//      CHECK:   %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
+//      CHECK:   scf.if %[[COND]]
+//      CHECK:     memref.load %[[ARGX]][%[[W]], %[[TIDY]]]
+//      CHECK:     memref.load %[[ARGY]][%[[W]], %[[TIDY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i, %j, %k) in (%c2, %c3, %c3) {
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.warp<x>, #gpu.warp<y>, #gpu.warp<z>]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [64, 4, 3] warp_size = 16: (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @threads_3d(
 // CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
 // CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
 // CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
-func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+func.func @threads_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
   %one = arith.constant 1 : index
   %c12 = arith.constant 12 : index
   %c9 = arith.constant 9 : index
@@ -231,31 +332,142 @@ transform.sequence failures(propagate) {
   transform.gpu.map_nested_forall_to_threads %funcop block_dims = [12, 9, 1] sync_after_distribute = false : (!transform.any_op) -> !transform.any_op
 }
 
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+// Maps a 2-D scf.forall onto linearized warpgroup ids (#gpu.warpgroup<linear_dim_*>) inside an existing gpu.launch.
+// CHECK-DAG: #[[$MAPWGLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
+// CHECK-DAG: #[[$MAPWGX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 32) floordiv 128) mod 2)>
+// CHECK-DAG: #[[$MAPWGY:.*]] = affine_map<(d0, d1, d2) -> (d2 + ((d0 + d1 * 32) floordiv 128) floordiv 2)>
+
+// CHECK-LABEL: func.func @warpgroup_linear(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @warpgroup_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %one = arith.constant 1 : index
+
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C768:.*]] = arith.constant 768 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+
+// CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id  x
+// CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id  y
+// CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id  z
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWGLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWGX]](%[[TIDX]], %[[TIDY]])
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWGY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C768]] : index
+//     CHECK: scf.if %[[CMPLIN]]
+//      CHECK:   memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
+//      CHECK:   memref.load %[[ARGY]][%[[WIDX]], %[[WIDY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i, %j) in (%c2, %c3) {  // 2 x 3 = 6 warpgroups.
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.warpgroup<linear_dim_0>, #gpu.warpgroup<linear_dim_1>]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 8, 4] : (!transform.any_op) -> !transform.any_op  // 32 * 8 * 4 = 1024 threads; 6 warpgroups of 128 threads cover the first 768.
+}
+
 // -----
 
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) floordiv 32) mod 3)>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 12) floordiv 32) mod 6) floordiv 3)>
+// CHECK-DAG: #[[$MAPWLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) mod 2)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) floordiv 2)>
+// Maps a 2-D scf.forall onto linearized warp ids (#gpu.warp<linear_dim_*>) inside an existing gpu.launch.
+// CHECK-LABEL: func.func @warp_linear(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @warp_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %one = arith.constant 1 : index
 
-// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 12)>
-// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 12) mod 10)>
-// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) mod 20) floordiv 10)>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C192:.*]] = arith.constant 192 : index
+
+// CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id  x
+// CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id  y
+// CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id  z
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C192]] : index
+//     CHECK: scf.if %[[CMPLIN]]
+//      CHECK:   memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
+//      CHECK:   memref.load %[[ARGY]][%[[WIDX]], %[[WIDY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i, %j) in (%c2, %c3) {  // 2 x 3 = 6 warps.
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]}
+    gpu.terminator
+  }
+  return %y : !type
+}
 
-// CHECK-LABEL: func.func @map_multi_level(
-func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 8, 4] : (!transform.any_op) -> !transform.any_op  // 1024 threads; 6 warps of 32 threads cover the first 192.
+}
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 18) floordiv 32) mod 3)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 18) floordiv 32) mod 6) floordiv 3)>
+
+// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 18)>
+// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) mod 10)>
+// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) floordiv 10)>
+
+// CHECK-LABEL: func.func @map_multi_level_linear(
+func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
   %one = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   %c9 = arith.constant 9 : index
   %c7 = arith.constant 7 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
 
   // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C11:.*]] = arith.constant 11 : index
-  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
   // CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+  // CHECK-DAG: %[[C192:.*]] = arith.constant 192 : index
 
   // check that both the thread level and the warp level got distributed.
   //  CHECK-NOT: #gpu.thread
@@ -272,19 +484,17 @@ func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %str
       memref.store %6, %y[%i, %j] : !type
     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
 
+    // CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]](%[[TIDX]], %[[TIDY]])
     // CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]])
     // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]])
-    // CHECK-DAG: %[[CMPX:.*]] = arith.cmpi ult, %[[WIDX]], %[[C1]] : index
-    // CHECK-DAG: %[[CMPY:.*]] = arith.cmpi ult, %[[WIDY]], %[[C1]] : index
-    //     CHECK: %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
-    //     CHECK: scf.if %[[COND]]
-    scf.forall (%i) in (%c1) {
-        %7 = memref.load %t[%i] : !type1d
+    // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[LIN]], %[[C192]] : index
+    //     CHECK: scf.if %[[CMPLIN]]
+    scf.forall (%i, %j, %k) in (%c3, %c2, %c1) {
+        %7 = memref.load %x[%i, %j] : !type
         %8 = arith.addf %alpha, %7 : f32
-        memref.store %8, %t[%i] : !type1d
-     }  {mapping = [#gpu.warp<x>] }
+        memref.store %8, %y[%i, %j] : !type
+     }  {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_2>] }
 
-    // CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]](%[[TIDX]], %[[TIDY]])
     // CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]](%[[TIDX]], %[[TIDY]])
     // CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]](%[[TIDX]], %[[TIDY]])
     // CHECK-DAG: %[[COND:.*]] = arith.cmpi ult, %[[LIN]], %[[C20]] : index
@@ -295,7 +505,7 @@ func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %str
         %7 = memref.load %t[%i] : !type1d
         %8 = arith.addf %alpha, %7 : f32
         memref.store %8, %t[%j] : !type1d
-     }  {mapping = [#gpu.linear<x>, #gpu.linear<y>] }
+     }  {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>] }
     gpu.terminator
   }
   return %y : !type
@@ -305,41 +515,151 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
   transform.gpu.map_nested_forall_to_threads %funcop
-    block_dims = [12, 11, 1] warp_dims = [3, 2, 1] : (!transform.any_op) -> !transform.any_op
+    block_dims = [18, 11, 1] : (!transform.any_op) -> !transform.any_op
 }
 
 // -----
 
-// CHECK-LABEL: func.func @tiling_buffer_semantic_op(
-//       CHECK:   gpu.launch {{.*}} {
-//       CHECK:     scf.forall {{.*}} {
-//       CHECK:       memref.subview
-//       CHECK:       memref.subview
-//       CHECK:       linalg.generic
-//       CHECK:     }
-//       CHECK:   }
-func.func @tiling_buffer_semantic_op(%x: memref<32x32xf32>, %y: memref<32x32xf32>, %stream : !gpu.async.token) {
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+// Maps a 2-D scf.forall onto linearized block ids (#gpu.block<linear_dim_*>) of an existing gpu.launch.
+// CHECK-DAG: #[[$MAPBLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 12 + d2 * 108)>
+// CHECK-DAG: #[[$MAPBX:.*]] = affine_map<(d0, d1, d2) -> ((d0 + d1 * 12 + d2 * 108) mod 7)>
+// CHECK-DAG: #[[$MAPBY:.*]] = affine_map<(d0, d1, d2) -> ((d0 + d1 * 12 + d2 * 108) floordiv 7)>
+
+// CHECK-LABEL: func.func @block_linear_existing_launch(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @block_linear_existing_launch(
+    %x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
   %one = arith.constant 1 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[C63:.*]] = arith.constant 63 : index
+//      CHECK:   gpu.launch async [{{.*}}] blocks({{.*}}) in (%{{.*}} = %[[C12]], %{{.*}} = %[[C9]], %{{.*}} = %[[C1]]) threads
+//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id  x
+//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id  y
+//  CHECK-DAG: %[[BIDZ:.*]] = gpu.block_id  z
+//  CHECK-DAG: %[[BIDLIN:.*]] = affine.apply #[[$MAPBLIN]](%[[BIDX]], %[[BIDY]], %[[BIDZ]])
+//  CHECK-DAG: %[[BLX:.*]] = affine.apply #[[$MAPBX]](%[[BIDX]], %[[BIDY]], %[[BIDZ]])
+//  CHECK-DAG: %[[BLY:.*]] = affine.apply #[[$MAPBY]](%[[BIDX]], %[[BIDY]], %[[BIDZ]])
+//  CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[BIDLIN]], %[[C63]] : index
+//     CHECK: scf.if %[[CMPLIN]]
+//      CHECK:   memref.load %[[ARGX]][%[[BLX]], %[[BLY]]]
+//      CHECK:   memref.load %[[ARGY]][%[[BLX]], %[[BLY]]]
   %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
             threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
   {
-    linalg.generic
-      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>],
-       iterator_types = ["parallel", "parallel"]}
-      ins(%x : memref<32x32xf32>)
-      outs(%y : memref<32x32xf32>) {
-        ^bb0(%in: f32, %out: f32):
-          linalg.yield %in : f32
-    }
+    scf.forall (%i, %j) in (%c7, %c9) {  // 7 x 9 = 63 blocks.
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.block<linear_dim_0>, #gpu.block<linear_dim_1>]}
     gpu.terminator
   }
-  return
+  return %y : !type
 }
 
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
-  %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-  %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
-    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_forall_to_blocks %funcop grid_dims = [12, 9, 1] : (!transform.any_op) -> !transform.any_op  // 12 * 9 * 1 = 108 blocks; only linear block ids below 63 run the body.
+}
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+// Maps a 2-D scf.forall onto linearized block ids and lets the transform generate the enclosing gpu.launch.
+// CHECK-DAG: #[[$MAPBX:.*]] = affine_map<(d0) -> (d0 mod 7)>
+// CHECK-DAG: #[[$MAPBY:.*]] = affine_map<(d0, d1, d2) -> (d1 + d2 * 9 + d0 floordiv 7)>
+
+// CHECK-LABEL: func.func @block_linear_generate_launch(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @block_linear_generate_launch(
+    %x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+  %one = arith.constant 1 : index
+
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
+  // CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+//      CHECK:   gpu.launch blocks({{.*}}) in (%{{.*}} = %[[C7]], %{{.*}} = %[[C9]], %{{.*}} = %[[C1]]) threads
+//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id  x
+//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id  y
+//  CHECK-DAG: %[[BIDZ:.*]] = gpu.block_id  z
+//  CHECK-DAG: %[[BLX:.*]] = affine.apply #[[$MAPBX]](%[[BIDX]])
+//  CHECK-DAG: %[[BLY:.*]] = affine.apply #[[$MAPBY]](%[[BIDX]], %[[BIDY]], %[[BIDZ]])
+//      CHECK:   memref.load %[[ARGX]][%[[BLX]], %[[BLY]]]
+//      CHECK:   memref.load %[[ARGY]][%[[BLX]], %[[BLY]]]
+  scf.forall (%i, %j) in (%c7, %c9) {  // 7 x 9 blocks; the generated launch grid is 7 x 9 x 1.
+    %4 = memref.load %x[%i, %j] : !type
+    %5 = memref.load %y[%i, %j] : !type
+    %6 = math.fma %alpha, %4, %5 : f32
+    memref.store %6, %y[%i, %j] : !type
+  }  { mapping = [#gpu.block<linear_dim_0>, #gpu.block<linear_dim_1>]}
+
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_forall_to_blocks %funcop generate_gpu_launch : (!transform.any_op) -> !transform.any_op  // No grid_dims given: the grid is derived from the forall sizes.
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+#map1 = affine_map<(d0) -> (d0 * 32)>
+// Nested block + linear-warp mapping: 4 warps of 32 threads exactly fill the 4x8x4 block, so no predicate.
+// CHECK-DAG: #[[$MAPB:.*]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK-DAG: #[[$MAPW:.*]] = affine_map<(d0, d1, d2) -> (d2 * 32 + ((d0 + d1 * 4) floordiv 32) * 32)>
+
+// CHECK-LABEL: func.func @simple_fill(
+func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+//       CHECK:   %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[C8:.*]] = arith.constant 8 : index
+//       CHECK:   gpu.launch
+  scf.forall (%arg1) in (1) {  // Single block.
+//       CHECK:     %[[BIDX:.*]] = gpu.block_id  x
+//       CHECK:     %[[BLX:.*]] = affine.apply #[[$MAPB]](%[[BIDX]])
+    %0 = affine.apply #map(%arg1)
+    %subview = memref.subview %arg0[%0] [128] [1] : memref<128xf32> to memref<128xf32, strided<[1], offset: ?>>
+    scf.forall (%arg2) in (4) {  // 4 warps, one 32-element slice each.
+//       CHECK:     %[[TIDX:.*]] = gpu.thread_id  x
+//       CHECK:     %[[TIDY:.*]] = gpu.thread_id  y
+//       CHECK:     %[[TIDZ:.*]] = gpu.thread_id  z
+//       CHECK:     %[[THX:.*]] = affine.apply #[[$MAPW]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+//   CHECK-NOT:     scf.if
+//       CHECK:       memref.subview %{{.*}}[%[[THX]]]
+      %1 = affine.apply #map1(%arg2)
+      %subview_0 = memref.subview %subview[%1] [32] [1] : memref<128xf32, strided<[1], offset: ?>> to memref<32xf32, strided<[1], offset: ?>>
+      vector.transfer_write %cst, %subview_0[%c0] {in_bounds = [true]} : vector<32xf32>, memref<32xf32, strided<[1], offset: ?>>
+      memref.copy %subview_0, %subview_0 : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32, strided<[1], offset: ?>>
+    } {mapping = [#gpu.warp<linear_dim_0>]}
+    memref.copy %subview, %subview : memref<128xf32, strided<[1], offset: ?>> to memref<128xf32, strided<[1], offset: ?>>
+  } {mapping = [#gpu.block<x>]}
+  return %arg0 : memref<128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %func = transform.structured.match ops{["func.func"]} in %module_op
+    : (!transform.any_op) -> !transform.any_op
+  %gpu_launch = transform.gpu.map_forall_to_blocks %func generate_gpu_launch
+    : (!transform.any_op) -> !transform.any_op
+  transform.gpu.map_nested_forall_to_threads %gpu_launch block_dims = [4, 8, 4]  // 4 * 8 * 4 = 128 threads = 4 warps of 32.
+    : (!transform.any_op) -> !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir b/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir
index d8e79880f69b9d..c9657493c245db 100644
--- a/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir
@@ -9,7 +9,7 @@ func.func @copy_1d_8xf16(%t0: !tt, %out: !tt) -> !tt {
   /// minor transfer size -> 1 thread.
   // CHECK: scf.forall {{.*}} in (1) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<8xf16>
-  // CHECK: {mapping = [#gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -36,7 +36,7 @@ func.func @pad_1d_8xf16(%t0: !tin, %sz: index) -> !tt {
   // CHECK: scf.forall {{.*}} in (1) {{.*}}
   // CHECK:   %[[padded:.*]] = tensor.pad {{.*}}
   // CHECK:   tensor.cast %[[padded]] : tensor<?xf16> to tensor<8xf16>
-  // CHECK: {mapping = [#gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_0>]}
   %0 = tensor.pad %t0 low[0] high[%sz] {
   ^bb0(%arg0: index):
     tensor.yield %cst : f16
@@ -63,7 +63,7 @@ func.func @copy_1d_16xf16(%t0: !tt, %out: !tt) -> !tt {
   /// minor transfer size -> 2 threads.
   // CHECK: scf.forall {{.*}} in (2) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<8xf16>
-  // CHECK: {mapping = [#gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -87,7 +87,7 @@ func.func @copy_1d_20xf16(%t0: !tt, %out: !tt) -> !tt {
   /// minor transfer size -> 5 threads.
   // CHECK: scf.forall {{.*}} in (5) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<4xf16>
-  // CHECK: {mapping = [#gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -112,7 +112,7 @@ func.func @copy_1d_20xf16(%t0: !tt, %out: !tt) -> !tt {
   /// minor transfer size -> 5 threads.
   // CHECK: scf.forall {{.*}} in (5) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<4xf16>
-  // CHECK: {mapping = [#gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -136,7 +136,7 @@ func.func @copy_1d_128xf16(%t0: !tt, %out: !tt) -> !tt {
   /// the transfer size to 4xf16.
   // CHECK: scf.forall {{.*}} in (32) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<4xf16>
-  // CHECK: {mapping = [#gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -159,7 +159,7 @@ func.func @copy_1d_256xf16(%t0: !tt, %out: !tt) -> !tt {
   /// Enough data for all threads and no need for predication.
   // CHECK: scf.forall {{.*}} in (32) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<8xf16>
-  // CHECK: {mapping = [#gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -181,7 +181,7 @@ transform.sequence failures(propagate) {
 func.func @copy_3d_16x32x64xi8(%t0: !tt, %out: !tt) -> !tt {
   // CHECK: scf.forall {{.*}} in (1, 8, 4) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<16x4x16xi8>
-  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -203,7 +203,7 @@ transform.sequence failures(propagate) {
 func.func @copy_3d_16x32x64xi8(%t0: !tt, %out: !tt) -> !tt {
   // CHECK: scf.forall {{.*}} in (1, 4, 8) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<16x8x8xi8>
-  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -225,7 +225,7 @@ transform.sequence failures(propagate) {
 func.func @copy_3d_4x8x16xi8(%t0: !tt, %out: !tt) -> !tt {
   // CHECK: scf.forall {{.*}} in (4, 8, 1) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<1x1x16xi8>
-  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -247,7 +247,7 @@ transform.sequence failures(propagate) {
 func.func @copy_3d_4x8x16xi8(%t0: !tt, %out: !tt) -> !tt {
   // CHECK: scf.forall {{.*}} in (1, 2, 16) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<4x4x1xi8>
-  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -273,7 +273,7 @@ func.func @copy_3d_3x5x7xi8(%t0: !tt, %out: !tt) -> !tt {
   // take 3.
   // CHECK: scf.forall {{.*}} in (3, 1, 7) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<1x5x1xi8>
-  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -296,7 +296,7 @@ func.func @copy_3d_16x15x5xi8(%t0: !tt, %out: !tt) -> !tt {
   // DP mapping: 5 mandated most minor, then 3 to allow 8 on the outermost.
   // CHECK: scf.forall {{.*}} in (8, 3, 5) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<2x5x1xi8>
-  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }
@@ -319,7 +319,7 @@ func.func @copy_3d_16x15x40xi8(%t0: !tt, %out: !tt) -> !tt {
   // DP mapping: 5 mandated most minor, then 3 to allow 8 on the outermost.
   // CHECK: scf.forall {{.*}} in (8, 3, 5) {{.*}}
   // CHECK:   linalg.copy {{.*}} -> tensor<2x5x8xi8>
-  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  // CHECK: {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
   %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
   return %0 : !tt
 }


        


More information about the Mlir-commits mailing list