[Mlir-commits] [mlir] [mlir][gpu][transforms] Add support for mapping to lanes (PR #146912)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 7 05:57:31 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/146912

>From a9cf08972549990f7013cd6e190725ff6fa62fa5 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 3 Jul 2025 17:29:10 +0200
Subject: [PATCH 1/2] [mlir][gpu][transforms] Add support for mapping to lanes

Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
---
 .../Dialect/GPU/IR/GPUDeviceMappingAttr.td    | 24 +++++++
 .../mlir/Dialect/GPU/TransformOps/Utils.h     |  9 +++
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 14 ++++
 .../GPU/TransformOps/GPUTransformOps.cpp      | 17 ++++-
 mlir/lib/Dialect/GPU/TransformOps/Utils.cpp   | 67 +++++++++++++++++++
 mlir/test/Dialect/GPU/transform-gpu.mlir      | 63 +++++++++++++++++
 6 files changed, 193 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
index 6e0f6f1d78eda..63f228ca3157f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
@@ -228,6 +228,30 @@ def GPUThreadMappingAttr
   }];
 }
 
+def GPULaneMappingAttr
+    : GPU_Attr<"GPULaneMapping", "lane", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
+  let parameters = (ins
+    EnumParameter<MappingIdEnum>:$lane
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining lane parallelism for GPU devices.
+
+    It can be consumed by lowering to generate GPU code.
+
+    #### 3D mapping mode
+
+    Unsupported
+
+    #### Linear mapping mode
+
+    The linear lane id is obtained by linearizing the index of the lane.
+    If required, predication occurs on the linear id. This allows specifying
+    predication on a 1D subset of the (linearized) lanes.
+  }];
+}
+
 def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
   DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index 52fc6f4d5c71b..111c67638efc8 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -117,6 +117,15 @@ struct GpuThreadIdBuilder : public GpuIdBuilder {
   GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
 };
 
+/// Builder for lane id.
+/// The `idBuilder` method returns nD values used for indexing rewrites as well
+/// as 1D sizes for predicate generation.
+/// Linear mapping is the only supported case; the `unused` flag is ignored.
+struct GpuLaneIdBuilder : public GpuIdBuilder {
+  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
+  int64_t warpSize = 32;
+};
+
 /// Determine if the size of the kernel configuration is supported by the
 /// GPU architecture being used.
 /// TODO this is currently hardwired to CUDA, parameterize and generalize.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a5eb62ce66e0b..56631f1aac084 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -106,6 +106,20 @@ int64_t GPUThreadMappingAttr::getRelativeIndex() const {
              : getMappingId();
 }
 
+int64_t GPULaneMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getLane());
+}
+
+bool GPULaneMappingAttr::isLinearMapping() const {
+  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
+}
+
+int64_t GPULaneMappingAttr::getRelativeIndex() const {
+  return isLinearMapping()
+             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
+             : getMappingId();
+}
+
 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getAddressSpace());
 }
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 6446235c06fb2..20d1c94409238 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -313,11 +313,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                                      llvm::IsaPred<GPUWarpMappingAttr>);
   bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                        llvm::IsaPred<GPUThreadMappingAttr>);
+  bool hasLaneMapping = llvm::any_of(forallOp.getMapping().value(),
+                                     llvm::IsaPred<GPULaneMappingAttr>);
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
   countMappingTypes += hasWarpgroupMapping ? 1 : 0;
   countMappingTypes += hasWarpMapping ? 1 : 0;
   countMappingTypes += hasThreadMapping ? 1 : 0;
+  countMappingTypes += hasLaneMapping ? 1 : 0;
   if (countMappingTypes > 1) {
     return definiteFailureHelper(
         transformOp, forallOp,
@@ -330,7 +333,8 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
         "scf.forall op requires a mapping attribute of kind 'block'");
   }
   if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
-      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
+      !hasLaneMapping && !hasThreadMapping && !hasWarpMapping &&
+      !hasWarpgroupMapping) {
     return definiteFailureHelper(transformOp, forallOp,
                                  "scf.forall op requires a mapping attribute "
                                  "of kind 'thread' or 'warp'");
@@ -473,10 +477,17 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
   SmallVector<int64_t> originalBasis(availableMappingSizes);
   bool originalBasisWasProvided = !originalBasis.empty();
   if (!originalBasisWasProvided) {
+    LDBG("----originalBasis was not provided, deriving it and there will be no "
+         "predication");
     originalBasis = forallMappingSizes;
     while (originalBasis.size() < 3)
       originalBasis.push_back(1);
+  } else {
+    LDBG("----originalBasis was provided, using it, there will be predication");
   }
+  LLVM_DEBUG(
+      llvm::interleaveComma(originalBasis, DBGS() << "------originalBasis: ");
+      llvm::dbgs() << "\n");
 
   IdBuilderResult builderResult =
       gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
@@ -490,6 +501,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
            forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
     auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
     Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
+    LDBG("----map: " << iv << " to" << peIdOp);
     bvm.map(iv, peIdOp);
   }
 
@@ -790,6 +802,9 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
           .Case([&](GPUThreadMappingAttr) {
             return GpuThreadIdBuilder(ctx, useLinearMapping);
           })
+          .Case([&](GPULaneMappingAttr) {
+            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
+          })
           .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
             llvm_unreachable("unknown mapping attribute");
           });
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 9853e80828390..c693a2fa01e89 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -156,6 +156,63 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
   return res;
 }
 
+/// Create a lane id builder that takes the `originalBasis` and decomposes
+/// it in the basis of `forallMappingSizes`. The linear id builder returns an
+/// n-D vector of ids for indexing and 1-D size + id for predicate generation.
+static GpuIdBuilderFnType laneIdBuilderFn(int64_t periodicity) {
+  auto res = [periodicity](RewriterBase &rewriter, Location loc,
+                           ArrayRef<int64_t> forallMappingSizes,
+                           ArrayRef<int64_t> originalBasis) {
+    SmallVector<OpFoldResult> originalBasisOfr =
+        getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
+    OpFoldResult linearId =
+        buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    linearId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 % periodicity, {linearId});
+
+    // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
+    // "row-major" order.
+    SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
+    SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
+    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
+    SmallVector<Value> ids;
+    // Reverse back to be in [0 .. n] order.
+    for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
+      ids.push_back(
+          affine::makeComposedAffineApply(rewriter, loc, e, {linearId}));
+    }
+
+    // clang-format off
+      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
+                                       DBGS() << "--delinearization basis: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(strides,
+                                       DBGS() << "--delinearization strides: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(delinearizingExprs,
+                                       DBGS() << "--delinearization exprs: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
+                 llvm::dbgs() << "\n";);
+    // clang-format on
+
+    // Return n-D ids for indexing and 1-D size + id for predicate generation.
+    return IdBuilderResult{
+        /*mappingIdOps=*/ids,
+        /*availableMappingSizes=*/
+        SmallVector<int64_t>{computeProduct(originalBasis)},
+        // `forallMappingSizes` iterate in the scaled basis, they need to be
+        // scaled back into the original basis to provide tight
+        // activeMappingSizes quantities for predication.
+        /*activeMappingSizes=*/
+        SmallVector<int64_t>{computeProduct(forallMappingSizes)},
+        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
+  };
+
+  return res;
+}
+
 namespace mlir {
 namespace transform {
 namespace gpu {
@@ -221,6 +278,16 @@ GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
                   : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
 }
 
+GpuLaneIdBuilder::GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize,
+                                   bool unused)
+    : GpuIdBuilder(ctx, /*useLinearMapping=*/true,
+                   [](MLIRContext *ctx, MappingId id) {
+                     return GPULaneMappingAttr::get(ctx, id);
+                   }),
+      warpSize(warpSize) {
+  idBuilder = laneIdBuilderFn(/*periodicity=*/warpSize);
+}
+
 DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
                                            std::optional<int64_t> gridDimX,
                                            std::optional<int64_t> gridDimY,
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index 09ae0f4af686f..fe5d451408355 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -691,3 +691,66 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0 *  128)>
+#map1 = affine_map<(d0) -> (d0 * 32)>
+
+// CHECK-DAG: #[[$MAPB:.*]] = affine_map<()[s0] -> (s0 * 128)>
+// CHECK-DAG: #[[$MAPLANE:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 73) mod 32)>
+// CHECK-DAG: #[[$MAPI:.*]] = affine_map<()[s0, s1] -> (s0 * 32 + s1 * 2336 - ((s0 + s1 * 73) floordiv 2) * 64)>
+// CHECK-DAG: #[[$MAPJ:.*]] = affine_map<()[s0, s1] -> ((((s0 + s1 * 73) mod 32) floordiv 2) * 32)>
+
+// CHECK-LABEL: func.func @simple_fill(
+func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+//       CHECK:   %[[C6:.*]] = arith.constant 6 : index
+//       CHECK:   gpu.launch
+  scf.forall (%arg1) in (1) {
+//       CHECK:     %[[BIDX:.*]] = gpu.block_id  x
+//       CHECK:     %[[BLX:.*]] = affine.apply #[[$MAPB]]()[%[[BIDX]]]
+    %0 = affine.apply #map(%arg1)
+    %subview = memref.subview %arg0[%0] [128] [1] : memref<128xf32> to memref<128xf32, strided<[1], offset: ?>>
+
+    // %arg2 and %arg3 map to lanes [0, 6) and are turned into expressions
+    // involving threadIdx.x/y by the map_nested_forall_to_threads
+    // transformation. This results in an if (lane_id < 6) conditional.
+    scf.forall (%arg2, %arg3) in (2, 3) {
+      //       CHECK:     %[[TIDX:.*]] = gpu.thread_id  x
+      //       CHECK:     %[[TIDY:.*]] = gpu.thread_id  y
+      //       CHECK:     %[[LID:.*]] = affine.apply #[[$MAPLANE]]()[%[[TIDX]], %[[TIDY]]]
+      //       CHECK:     %[[COND:.*]] = arith.cmpi ult, %[[LID]], %[[C6]]
+      //       CHECK:     scf.if %[[COND]]
+      //       CHECK:       %[[I:.*]] = affine.apply #[[$MAPI]]()[%[[TIDX]], %[[TIDY]]]
+      //       CHECK:       %[[J:.*]] = affine.apply #[[$MAPJ]]()[%[[TIDX]], %[[TIDY]]]
+      //       CHECK:       memref.subview %{{.*}}[%[[I]]] [%[[J]]]
+      %1 = affine.apply #map1(%arg2)
+      %2 = affine.apply #map1(%arg3)
+      %subview_0 = memref.subview %subview[%1] [%2] [1] : memref<128xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+      vector.transfer_write %cst, %subview_0[%c0] {in_bounds = [true]} : vector<32xf32>, memref<?xf32, strided<[1], offset: ?>>
+
+    // This could be obtained e.g. if a previous transformation mapped this loop
+    // to lanes. This can also be written by hand as valid IR.
+    } {mapping = [#gpu.lane<linear_dim_0>, #gpu.lane<linear_dim_1>]}
+
+    memref.copy %subview, %subview : memref<128xf32, strided<[1], offset: ?>> to memref<128xf32, strided<[1], offset: ?>>
+  } {mapping = [#gpu.block<x>]}
+  return %arg0 : memref<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    %gpu_launch = transform.gpu.map_forall_to_blocks %func generate_gpu_launch
+      : (!transform.any_op) -> !transform.any_op
+
+    // This transformation maps scf.forall ivs to a particular mapping of thread
+    // ids (laneid, threadid, warpid or warpgroupid).
+    transform.gpu.map_nested_forall_to_threads %gpu_launch block_dims = [73, 5, 1]
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}

>From 3e86c1f1e5cf8ef4ad7a982dad9e69c47b7ee890 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <Nico.Vasilache at amd.com>
Date: Mon, 7 Jul 2025 14:57:21 +0200
Subject: [PATCH 2/2] Update mlir/test/Dialect/GPU/transform-gpu.mlir

Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
---
 mlir/test/Dialect/GPU/transform-gpu.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index fe5d451408355..157dfa2c3e297 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -694,7 +694,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map = affine_map<(d0) -> (d0 *  128)>
+#map = affine_map<(d0) -> (d0 * 128)>
 #map1 = affine_map<(d0) -> (d0 * 32)>
 
 // CHECK-DAG: #[[$MAPB:.*]] = affine_map<()[s0] -> (s0 * 128)>



More information about the Mlir-commits mailing list