[Mlir-commits] [mlir] [mlir][SCF][GPU] Add DeviceMaskingAttrInterface support to scf::Foral… (PR #146943)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 7 06:47:10 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/146943

>From a9cf08972549990f7013cd6e190725ff6fa62fa5 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 3 Jul 2025 17:29:10 +0200
Subject: [PATCH 1/4] [mlir][gpu][transforms] Add support for mapping to lanes

Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
---
 .../Dialect/GPU/IR/GPUDeviceMappingAttr.td    | 24 +++++++
 .../mlir/Dialect/GPU/TransformOps/Utils.h     |  9 +++
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 14 ++++
 .../GPU/TransformOps/GPUTransformOps.cpp      | 17 ++++-
 mlir/lib/Dialect/GPU/TransformOps/Utils.cpp   | 67 +++++++++++++++++++
 mlir/test/Dialect/GPU/transform-gpu.mlir      | 63 +++++++++++++++++
 6 files changed, 193 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
index 6e0f6f1d78eda..63f228ca3157f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
@@ -228,6 +228,30 @@ def GPUThreadMappingAttr
   }];
 }
 
+def GPULaneMappingAttr
+    : GPU_Attr<"GPULaneMapping", "lane", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
+  let parameters = (ins
+    EnumParameter<MappingIdEnum>:$lane
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining lane parallelism for GPU devices.
+
+    It can be consumed by lowering passes to generate GPU code.
+
+    #### 3D mapping mode
+
+    Unsupported
+
+    #### Linear mapping mode
+
+    The linear lane id is obtained by linearizing the index of the lane.
+    If required, predication occurs on the linear id. This allows specifying
+    predication on a 1D subset of the (linearized) lanes.
+  }];
+}
+
 def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
   DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index 52fc6f4d5c71b..111c67638efc8 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -117,6 +117,15 @@ struct GpuThreadIdBuilder : public GpuIdBuilder {
   GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
 };
 
+/// Builder for lane id.
+/// The `idBuilder` method returns nD values used for indexing rewrites as well
+/// as 1D sizes for predicate generation.
+/// This `useLinearMapping` case is the only supported case.
+struct GpuLaneIdBuilder : public GpuIdBuilder {
+  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
+  int64_t warpSize = 32;
+};
+
 /// Determine if the size of the kernel configuration is supported by the
 /// GPU architecture being used.
 /// TODO this is currently hardwired to CUDA, parameterize and generalize.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a5eb62ce66e0b..56631f1aac084 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -106,6 +106,20 @@ int64_t GPUThreadMappingAttr::getRelativeIndex() const {
              : getMappingId();
 }
 
+int64_t GPULaneMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getLane());
+}
+
+bool GPULaneMappingAttr::isLinearMapping() const {
+  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
+}
+
+int64_t GPULaneMappingAttr::getRelativeIndex() const {
+  return isLinearMapping()
+             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
+             : getMappingId();
+}
+
 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getAddressSpace());
 }
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 6446235c06fb2..20d1c94409238 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -313,11 +313,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                                      llvm::IsaPred<GPUWarpMappingAttr>);
   bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                        llvm::IsaPred<GPUThreadMappingAttr>);
+  bool hasLaneMapping = llvm::any_of(forallOp.getMapping().value(),
+                                     llvm::IsaPred<GPULaneMappingAttr>);
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
   countMappingTypes += hasWarpgroupMapping ? 1 : 0;
   countMappingTypes += hasWarpMapping ? 1 : 0;
   countMappingTypes += hasThreadMapping ? 1 : 0;
+  countMappingTypes += hasLaneMapping ? 1 : 0;
   if (countMappingTypes > 1) {
     return definiteFailureHelper(
         transformOp, forallOp,
@@ -330,7 +333,8 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
         "scf.forall op requires a mapping attribute of kind 'block'");
   }
   if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
-      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
+      !hasLaneMapping && !hasThreadMapping && !hasWarpMapping &&
+      !hasWarpgroupMapping) {
     return definiteFailureHelper(transformOp, forallOp,
                                  "scf.forall op requires a mapping attribute "
                                  "of kind 'thread' or 'warp'");
@@ -473,10 +477,17 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
   SmallVector<int64_t> originalBasis(availableMappingSizes);
   bool originalBasisWasProvided = !originalBasis.empty();
   if (!originalBasisWasProvided) {
+    LDBG("----originalBasis was not provided, deriving it and there will be no "
+         "predication");
     originalBasis = forallMappingSizes;
     while (originalBasis.size() < 3)
       originalBasis.push_back(1);
+  } else {
+    LDBG("----originalBasis was provided, using it, there will be predication");
   }
+  LLVM_DEBUG(
+      llvm::interleaveComma(originalBasis, DBGS() << "------originalBasis: ");
+      llvm::dbgs() << "\n");
 
   IdBuilderResult builderResult =
       gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
@@ -490,6 +501,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
            forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
     auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
     Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
+    LDBG("----map: " << iv << " to" << peIdOp);
     bvm.map(iv, peIdOp);
   }
 
@@ -790,6 +802,9 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
           .Case([&](GPUThreadMappingAttr) {
             return GpuThreadIdBuilder(ctx, useLinearMapping);
           })
+          .Case([&](GPULaneMappingAttr) {
+            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
+          })
           .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
             llvm_unreachable("unknown mapping attribute");
           });
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 9853e80828390..c693a2fa01e89 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -156,6 +156,63 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
   return res;
 }
 
+/// Create a lane id builder that takes the `originalBasis` and decompose
+/// it in the basis of `forallMappingSizes`. The linear id builder returns an
+/// n-D vector of ids for indexing and 1-D size + id for predicate generation.
+static GpuIdBuilderFnType laneIdBuilderFn(int64_t periodicity) {
+  auto res = [periodicity](RewriterBase &rewriter, Location loc,
+                           ArrayRef<int64_t> forallMappingSizes,
+                           ArrayRef<int64_t> originalBasis) {
+    SmallVector<OpFoldResult> originalBasisOfr =
+        getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
+    OpFoldResult linearId =
+        buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    linearId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 % periodicity, {linearId});
+
+    // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
+    // "row-major" order.
+    SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
+    SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
+    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
+    SmallVector<Value> ids;
+    // Reverse back to be in [0 .. n] order.
+    for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
+      ids.push_back(
+          affine::makeComposedAffineApply(rewriter, loc, e, {linearId}));
+    }
+
+    // clang-format off
+      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
+                                       DBGS() << "--delinearization basis: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(strides,
+                                       DBGS() << "--delinearization strides: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(delinearizingExprs,
+                                       DBGS() << "--delinearization exprs: ");
+                 llvm::dbgs() << "\n";
+                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
+                 llvm::dbgs() << "\n";);
+    // clang-format on
+
+    // Return n-D ids for indexing and 1-D size + id for predicate generation.
+    return IdBuilderResult{
+        /*mappingIdOps=*/ids,
+        /*availableMappingSizes=*/
+        SmallVector<int64_t>{computeProduct(originalBasis)},
+        // `forallMappingSizes` iterate in the scaled basis, they need to be
+        // scaled back into the original basis to provide tight
+        // activeMappingSizes quantities for predication.
+        /*activeMappingSizes=*/
+        SmallVector<int64_t>{computeProduct(forallMappingSizes)},
+        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
+  };
+
+  return res;
+}
+
 namespace mlir {
 namespace transform {
 namespace gpu {
@@ -221,6 +278,16 @@ GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
                   : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
 }
 
+GpuLaneIdBuilder::GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize,
+                                   bool unused)
+    : GpuIdBuilder(ctx, /*useLinearMapping=*/true,
+                   [](MLIRContext *ctx, MappingId id) {
+                     return GPULaneMappingAttr::get(ctx, id);
+                   }),
+      warpSize(warpSize) {
+  idBuilder = laneIdBuilderFn(/*periodicity=*/warpSize);
+}
+
 DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
                                            std::optional<int64_t> gridDimX,
                                            std::optional<int64_t> gridDimY,
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index 09ae0f4af686f..fe5d451408355 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -691,3 +691,66 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0 *  128)>
+#map1 = affine_map<(d0) -> (d0 * 32)>
+
+// CHECK-DAG: #[[$MAPB:.*]] = affine_map<()[s0] -> (s0 * 128)>
+// CHECK-DAG: #[[$MAPLANE:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 73) mod 32)>
+// CHECK-DAG: #[[$MAPI:.*]] = affine_map<()[s0, s1] -> (s0 * 32 + s1 * 2336 - ((s0 + s1 * 73) floordiv 2) * 64)>
+// CHECK-DAG: #[[$MAPJ:.*]] = affine_map<()[s0, s1] -> ((((s0 + s1 * 73) mod 32) floordiv 2) * 32)>
+
+// CHECK-LABEL: func.func @simple_fill(
+func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+//       CHECK:   %[[C6:.*]] = arith.constant 6 : index
+//       CHECK:   gpu.launch
+  scf.forall (%arg1) in (1) {
+//       CHECK:     %[[BIDX:.*]] = gpu.block_id  x
+//       CHECK:     %[[BLX:.*]] = affine.apply #[[$MAPB]]()[%[[BIDX]]]
+    %0 = affine.apply #map(%arg1)
+    %subview = memref.subview %arg0[%0] [128] [1] : memref<128xf32> to memref<128xf32, strided<[1], offset: ?>>
+
+    // %arg2 and %arg3 map to lanes [0, 6) and are turned into expressions
+    // involving threadIdx.x/y by the map_nested_forall_to_threads
+    // transformation. This results in an if (lane_id < 6) conditional.
+    scf.forall (%arg2, %arg3) in (2, 3) {
+      //       CHECK:     %[[TIDX:.*]] = gpu.thread_id  x
+      //       CHECK:     %[[TIDY:.*]] = gpu.thread_id  y
+      //       CHECK:     %[[LID:.*]] = affine.apply #[[$MAPLANE]]()[%[[TIDX]], %[[TIDY]]]
+      //       CHECK:     %[[COND:.*]] = arith.cmpi ult, %[[LID]], %[[C6]]
+      //       CHECK:     scf.if %[[COND]]
+      //       CHECK:       %[[I:.*]] = affine.apply #[[$MAPI]]()[%[[TIDX]], %[[TIDY]]]
+      //       CHECK:       %[[J:.*]] = affine.apply #[[$MAPJ]]()[%[[TIDX]], %[[TIDY]]]
+      //       CHECK:       memref.subview %{{.*}}[%[[I]]] [%[[J]]]
+      %1 = affine.apply #map1(%arg2)
+      %2 = affine.apply #map1(%arg3)
+      %subview_0 = memref.subview %subview[%1] [%2] [1] : memref<128xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+      vector.transfer_write %cst, %subview_0[%c0] {in_bounds = [true]} : vector<32xf32>, memref<?xf32, strided<[1], offset: ?>>
+
+    // This could be obtained e.g. if a previous transformation mapped this loop
+    // to lanes. This can also be written by hand as valid IR.
+    } {mapping = [#gpu.lane<linear_dim_0>, #gpu.lane<linear_dim_1>]}
+
+    memref.copy %subview, %subview : memref<128xf32, strided<[1], offset: ?>> to memref<128xf32, strided<[1], offset: ?>>
+  } {mapping = [#gpu.block<x>]}
+  return %arg0 : memref<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    %gpu_launch = transform.gpu.map_forall_to_blocks %func generate_gpu_launch
+      : (!transform.any_op) -> !transform.any_op
+
+    // This transformation maps scf.forall ivs to a particular mapping of thread
+    // ids (laneid, threadid, warpid or warpgroupid).
+    transform.gpu.map_nested_forall_to_threads %gpu_launch block_dims = [73, 5, 1]
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}

>From c88aee740d5d944364e79600bf3c01493a1c3fee Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 3 Jul 2025 18:32:59 +0200
Subject: [PATCH 2/4] [mlir] NFC - refactor id builder and avoid leaking impl
 details

---
 .../mlir/Dialect/GPU/TransformOps/Utils.h     |  31 ++-
 .../GPU/TransformOps/GPUTransformOps.cpp      |  33 +---
 mlir/lib/Dialect/GPU/TransformOps/Utils.cpp   | 176 +++++++++++-------
 3 files changed, 127 insertions(+), 113 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index 111c67638efc8..de512ded59fec 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -28,27 +28,24 @@ namespace transform {
 namespace gpu {
 
 /// Helper type for functions that generate ids for the mapping of a scf.forall.
-/// Operates on both 1) an "original" basis that represents the individual
-/// thread and block ids and 2) a "scaled" basis that represents grouped ids
-/// (e.g. block clusters, warpgroups and warps).
-/// The mapping of ids is done in the "scaled" basis (i.e. when mapping to warps
-/// a division by 32 occurs).
-/// The predication is in the "original" basis using the "active" quantities
-/// (`activeMappingSizes`, `availableMappingSizes` and `activeIdOps`).
 struct IdBuilderResult {
-  // Ops used to replace the forall induction variables.
+  /// Error message, if not empty then building the ids failed.
+  std::string errorMsg;
+  /// Values used to replace the forall induction variables.
   SmallVector<Value> mappingIdOps;
-  // Available mapping sizes used to predicate the forall body when they are
-  // larger than the predicate mapping sizes.
-  SmallVector<int64_t> availableMappingSizes;
-  // Actual mapping sizes used to predicate the forall body when they are
-  // smaller than the available mapping sizes.
-  SmallVector<int64_t> activeMappingSizes;
-  // Ops used to predicate the forall body when activeMappingSizes is smaller
-  // than the available mapping sizes.
-  SmallVector<Value> activeIdOps;
+  /// Values used to predicate the forall body when activeMappingSizes is
+  /// smaller than the available mapping sizes.
+  SmallVector<Value> predicateOps;
 };
 
+inline raw_ostream &operator<<(raw_ostream &os, const IdBuilderResult &res) {
+  llvm::interleaveComma(res.mappingIdOps, os << "----mappingIdOps: ");
+  os << "\n";
+  llvm::interleaveComma(res.predicateOps, os << "----predicateOps: ");
+  os << "\n";
+  return os;
+}
+
 /// Common gpu id builder type, allows the configuration of lowering for various
 /// mapping schemes. Takes:
 ///   - A rewriter with insertion point set before the forall op to rewrite.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 20d1c94409238..63f87d9b5877e 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -491,6 +491,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
 
   IdBuilderResult builderResult =
       gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
+  if (!builderResult.errorMsg.empty())
+    return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);
+
+  LLVM_DEBUG(DBGS() << builderResult);
 
   // Step 4. Map the induction variables to the mappingIdOps, this may involve
   // a permutation.
@@ -501,7 +505,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
            forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
     auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
     Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
-    LDBG("----map: " << iv << " to" << peIdOp);
+    LDBG("----map: " << iv << " to " << peIdOp);
     bvm.map(iv, peIdOp);
   }
 
@@ -510,32 +514,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
   // originalBasis and no predication occurs.
   Value predicate;
   if (originalBasisWasProvided) {
-    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
-    SmallVector<int64_t> availableMappingSizes =
-        builderResult.availableMappingSizes;
-    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
-    LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
-    LDBG("----availableMappingSizes: "
-         << llvm::interleaved(availableMappingSizes));
-    LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
-    for (auto [activeId, activeMappingSize, availableMappingSize] :
-         llvm::zip_equal(activeIdOps, activeMappingSizes,
-                         availableMappingSizes)) {
-      if (activeMappingSize > availableMappingSize) {
-        return definiteFailureHelper(
-            transformOp, forallOp,
-            "Trying to map to fewer GPU threads than loop iterations but "
-            "overprovisioning is not yet supported. "
-            "Try additional tiling of the before mapping or map to more "
-            "threads.");
-      }
-      if (activeMappingSize == availableMappingSize)
-        continue;
-      Value idx =
-          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
-      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::ult, activeId, idx);
-      LDBG("----predicate: " << tmpPredicate);
+    for (Value tmpPredicate : builderResult.predicateOps) {
       predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                              tmpPredicate)
                             : tmpPredicate;
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index c693a2fa01e89..795d643c05912 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -47,12 +47,57 @@ using namespace mlir::transform::gpu;
 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
+/// Build predicates to filter execution by only the activeIds. Along each
+/// dimension, 3 cases appear:
+///   1. activeMappingSize > availableMappingSize: this is an unsupported case
+///      as this requires additional looping. An error message is produced to
+///      advise the user to tile more or to use more threads.
+///   2. activeMappingSize == availableMappingSize: no predication is needed.
+///   3. activeMappingSize < availableMappingSize: only a subset of threads
+///      should be active and we produce the boolean `id < activeMappingSize`
+///      for further use in building predicated execution.
+static FailureOr<SmallVector<Value>>
+buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
+                ArrayRef<int64_t> activeMappingSizes,
+                ArrayRef<int64_t> availableMappingSizes,
+                std::string &errorMsg) {
+  // clang-format off
+  LLVM_DEBUG(
+    llvm::interleaveComma(
+      activeMappingSizes, DBGS() << "----activeMappingSizes: ");
+    DBGS() << "\n";
+    llvm::interleaveComma(
+      availableMappingSizes, DBGS() << "----availableMappingSizes: ");
+    DBGS() << "\n";);
+  // clang-format on
+
+  SmallVector<Value> predicateOps;
+  for (auto [activeId, activeMappingSize, availableMappingSize] :
+       llvm::zip_equal(activeIds, activeMappingSizes, availableMappingSizes)) {
+    if (activeMappingSize > availableMappingSize) {
+      errorMsg = "Trying to map to fewer GPU threads than loop iterations but "
+                 "overprovisioning is not yet supported. Try additional tiling "
+                 "before mapping or map to more threads.";
+      return failure();
+    }
+    if (activeMappingSize == availableMappingSize)
+      continue;
+    Value idx = rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
+    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                                activeId, idx);
+    predicateOps.push_back(pred);
+  }
+  return predicateOps;
+}
+
 /// Return a flattened thread id for the workgroup with given sizes.
 template <typename ThreadOrBlockIdOp>
 static Value buildLinearId(RewriterBase &rewriter, Location loc,
                            ArrayRef<OpFoldResult> originalBasisOfr) {
-  LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr:  "
-                    << llvm::interleaved(originalBasisOfr) << "\n");
+  LLVM_DEBUG(llvm::interleaveComma(
+                 originalBasisOfr,
+                 DBGS() << "----buildLinearId with originalBasisOfr:  ");
+             llvm::dbgs() << "\n");
   assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
   IndexType indexType = rewriter.getIndexType();
   AffineExpr tx, ty, tz, bdx, bdy;
@@ -79,44 +124,43 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
   auto res = [multiplicity](RewriterBase &rewriter, Location loc,
                             ArrayRef<int64_t> forallMappingSizes,
                             ArrayRef<int64_t> originalBasis) {
+    // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
-    OpFoldResult linearId =
+    Value physicalLinearId =
         buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
+
+    // 2. Compute scaledLinearId.
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
+
+    // 3. Compute remapped indices.
+    SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
     // "row-major" order.
     SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
     SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
-    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0.floorDiv(multiplicity), {linearId});
     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
-    SmallVector<Value> ids;
     // Reverse back to be in [0 .. n] order.
     for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
       ids.push_back(
           affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
     }
 
-    LLVM_DEBUG(DBGS() << "--delinearization basis: "
-                      << llvm::interleaved(reverseBasisSizes) << "\n";
-               DBGS() << "--delinearization strides: "
-                      << llvm::interleaved(strides) << "\n";
-               DBGS() << "--delinearization exprs: "
-                      << llvm::interleaved(delinearizingExprs) << "\n";
-               DBGS() << "--ids: " << llvm::interleaved(ids) << "\n");
-
-    // Return n-D ids for indexing and 1-D size + id for predicate generation.
-    return IdBuilderResult{
-        /*mappingIdOps=*/ids,
-        /*availableMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(originalBasis)},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
-        /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
+    // 4. Handle predicates using physicalLinearId.
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps =
+        buildPredicates(rewriter, loc, physicalLinearId,
+                        computeProduct(forallMappingSizes) * multiplicity,
+                        computeProduct(originalBasis), errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/ids,
+                           /*predicateOps=*/predicateOps};
   };
 
   return res;
@@ -143,15 +187,18 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
     // In the 3-D mapping case, unscale the first dimension by the multiplicity.
     SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
     forallMappingSizeInOriginalBasis[0] *= multiplicity;
-    return IdBuilderResult{
-        /*mappingIdOps=*/scaledIds,
-        /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
-        /*activeIdOps=*/ids};
+
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps =
+        buildPredicates(rewriter, loc, ids, forallMappingSizeInOriginalBasis,
+                        originalBasis, errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/scaledIds,
+                           /*predicateOps=*/predicateOps};
   };
   return res;
 }
@@ -159,55 +206,46 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
 /// Create a lane id builder that takes the `originalBasis` and decompose
 /// it in the basis of `forallMappingSizes`. The linear id builder returns an
 /// n-D vector of ids for indexing and 1-D size + id for predicate generation.
-static GpuIdBuilderFnType laneIdBuilderFn(int64_t periodicity) {
-  auto res = [periodicity](RewriterBase &rewriter, Location loc,
-                           ArrayRef<int64_t> forallMappingSizes,
-                           ArrayRef<int64_t> originalBasis) {
+static GpuIdBuilderFnType laneIdBuilderFn(int64_t warpSize) {
+  auto res = [warpSize](RewriterBase &rewriter, Location loc,
+                        ArrayRef<int64_t> forallMappingSizes,
+                        ArrayRef<int64_t> originalBasis) {
+    // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
-    OpFoldResult linearId =
+    Value physicalLinearId =
         buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
+
+    // 2. Compute laneId.
     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    linearId = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0 % periodicity, {linearId});
+    OpFoldResult laneId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 % warpSize, {physicalLinearId});
 
+    // 3. Compute remapped indices.
+    SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
     // "row-major" order.
     SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
     SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
-    SmallVector<Value> ids;
     // Reverse back to be in [0 .. n] order.
     for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
       ids.push_back(
-          affine::makeComposedAffineApply(rewriter, loc, e, {linearId}));
+          affine::makeComposedAffineApply(rewriter, loc, e, {laneId}));
     }
 
-    // clang-format off
-      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
-                                       DBGS() << "--delinearization basis: ");
-                 llvm::dbgs() << "\n";
-                 llvm::interleaveComma(strides,
-                                       DBGS() << "--delinearization strides: ");
-                 llvm::dbgs() << "\n";
-                 llvm::interleaveComma(delinearizingExprs,
-                                       DBGS() << "--delinearization exprs: ");
-                 llvm::dbgs() << "\n";
-                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
-                 llvm::dbgs() << "\n";);
-    // clang-format on
-
-    // Return n-D ids for indexing and 1-D size + id for predicate generation.
-    return IdBuilderResult{
-        /*mappingIdOps=*/ids,
-        /*availableMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(originalBasis)},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(forallMappingSizes)},
-        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
+    // 4. Handle predicates using laneId.
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps = buildPredicates(
+        rewriter, loc, cast<Value>(laneId), computeProduct(forallMappingSizes),
+        computeProduct(originalBasis), errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/ids,
+                           /*predicateOps=*/predicateOps};
   };
 
   return res;

>From 85aa5f8c72801f5a75142a663d6e89e83e63decc Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 3 Jul 2025 21:26:53 +0200
Subject: [PATCH 3/4] [mlir][SCF][GPU] Add DeviceMaskingAttrInterface support
 to scf::ForallOp and use it to implement warp specialization.

This revision adds DeviceMaskingAttrInterface and extends
DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface
and DeviceMaskingAttrInterface.

The first implementation is in the form of a GPUMappingMaskAttr, which
can be additionally passed to the scf.forall.mapping attribute to
specify a mask on compute resources that should be active.

Support is added to GPUTransformOps to take advantage of this
information and lower to block/warpgroup/warp/thread specialization when
mapped to linear ids.

Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
---
 .../Dialect/GPU/IR/GPUDeviceMappingAttr.td    |  18 ++++
 .../mlir/Dialect/GPU/TransformOps/Utils.h     |  15 ++-
 .../Dialect/SCF/IR/DeviceMappingInterface.td  |  45 +++++++-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  12 +++
 mlir/lib/Dialect/GPU/CMakeLists.txt           |   1 +
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  45 ++++++++
 .../GPU/TransformOps/GPUTransformOps.cpp      |  62 +++++++----
 mlir/lib/Dialect/GPU/TransformOps/Utils.cpp   | 102 +++++++++++++-----
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  43 ++++++--
 .../Dialect/GPU/transform-gpu-failing.mlir    |  61 +++++++++++
 mlir/test/Dialect/GPU/transform-gpu.mlir      |  81 ++++++++++++++
 mlir/test/Dialect/SCF/invalid.mlir            |  18 ++++
 12 files changed, 444 insertions(+), 59 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
index 63f228ca3157f..e8540027e7b77 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
@@ -252,6 +252,24 @@ def GPULaneMappingAttr
   }];
 }
 
+def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [
+  DeclareAttrInterfaceMethods<DeviceMaskingAttrInterface> ] >  {
+  let parameters = (ins "uint64_t":$mask);
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    Attribute describing how to filter the processing units that a
+    region is mapped to.
+
+    In the first implementation the masking is a bitfield that specifies for
+    each processing unit whether it is active or not.
+
+    In the future, we may want to implement this as a symbol to refer to
+    dynamically defined values.
+
+    Extending op semantics with an operand is deemed too intrusive at this time.
+  }];
+}
+
 def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
   DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index de512ded59fec..0a11b8f8d3fa0 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -78,7 +78,8 @@ struct GpuIdBuilder {
 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuBlockIdBuilder : public GpuIdBuilder {
-  GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+  GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+                    DeviceMaskingAttrInterface mask = nullptr);
 };
 
 /// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
@@ -88,7 +89,8 @@ struct GpuBlockIdBuilder : public GpuIdBuilder {
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
   GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                        bool useLinearMapping = false);
+                        bool useLinearMapping = false,
+                        DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
   /// In the future this may be configured by the transformation.
   static constexpr int64_t kNumWarpsPerGroup = 4;
@@ -101,7 +103,8 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuWarpIdBuilder : public GpuIdBuilder {
   GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                   bool useLinearMapping = false);
+                   bool useLinearMapping = false,
+                   DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
 };
 
@@ -111,7 +114,8 @@ struct GpuWarpIdBuilder : public GpuIdBuilder {
 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuThreadIdBuilder : public GpuIdBuilder {
-  GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+  GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+                     DeviceMaskingAttrInterface mask = nullptr);
 };
 
 /// Builder for lane id.
@@ -119,7 +123,8 @@ struct GpuThreadIdBuilder : public GpuIdBuilder {
 /// as 1D sizes for predicate generation.
 /// This `useLinearMapping` case is the only supported case.
 struct GpuLaneIdBuilder : public GpuIdBuilder {
-  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
+  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused,
+                   DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
 };
 
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
index 96db2a40cf58e..353aaf05bee0c 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -60,8 +60,51 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
   ];
 }
 
+def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Attribute interface describing how to filter the processing units that a
+    region is mapped to.
+
+    A popcount can be applied to determine the logical linear index that a
+    physical processing unit is responsible for.
+  }];
+
+ let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the logical active id for a given physical id.
+        Expects a physicalLinearMappingId of I64Type.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getLogicalLinearMappingId",
+      /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dynamic condition determining whether a given physical id is
+        active under the mask.
+        Expects a physicalLinearMappingId of I64Type.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getIsActiveIdPredicate",
+      /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the maximal number of physical ids supported.
+        This is to account for temporary implementation limitations (e.g. i64)
+        and fail gracefully with actionable error messages.
+      }],
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getMaxNumPhysicalIds",
+      /*args=*/(ins)
+    >,
+  ];
+}
+
 def DeviceMappingArrayAttr :
-  TypedArrayAttrBase<DeviceMappingAttrInterface,
+  TypedArrayAttrBase<AnyAttrOf<[DeviceMappingAttrInterface, DeviceMaskingAttrInterface]>,
   "Device Mapping array attribute"> { }
 
 #endif // MLIR_DEVICEMAPPINGINTERFACE
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8b14cef7437d4..2d15544e871b3 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -611,6 +611,18 @@ def ForallOp : SCF_Op<"forall", [
     /// Returns operations within scf.forall.in_parallel whose destination
     /// operand is the block argument `bbArg`.
     SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
+
+    /// Returns, in order, the subset of the mapping attributes that implement
+    /// DeviceMappingAttrInterface.
+    SmallVector<DeviceMappingAttrInterface> getDeviceMappingAttrs();
+
+    /// Returns the (at most one) DeviceMaskingAttrInterface in the mapping.
+    /// If more than one DeviceMaskingAttrInterface is specified, returns
+    /// failure. If no masking attribute is present, returns a null attribute.
+    FailureOr<DeviceMaskingAttrInterface> getDeviceMaskingAttr();
+
+    /// Returns true if the mapping specified for this forall op is linear.
+    bool usesLinearMapping();
   }];
 }
 
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index c8c53374d676b..4862d1f722785 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRGPUDialect
   MLIRFunctionInterfaces
   MLIRInferIntRangeInterface
   MLIRIR
+  MLIRMathDialect
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRSupport
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 56631f1aac084..9d74c23c24cc8 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -120,6 +121,50 @@ int64_t GPULaneMappingAttr::getRelativeIndex() const {
              : getMappingId();
 }
 
+int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
+
+///                 8       4       0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
+/// Logical id for e.g. 5 (2) constructs filter ((1 << 5) - 1).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 0 1 1 1 1 1
+/// Intersection  : 0 0 0 0 1 0 1 0 0
+/// PopCnt        : 2
+Value GPUMappingMaskAttr::getLogicalLinearMappingId(
+    OpBuilder &b, Value physicalLinearMappingId) const {
+  Location loc = physicalLinearMappingId.getLoc();
+  Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
+  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+  Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+  filter = b.create<arith::SubIOp>(loc, filter, one);
+  Value filteredId = b.create<arith::AndIOp>(loc, mask, filter);
+  return b.create<math::CtPopOp>(loc, filteredId);
+}
+
+///                 8       4       0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
+/// Activity check for e.g. 5 (2) constructs filter (1 << 5).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 1 0 0 0 0 0
+/// Intersection  : 0 0 0 1 0 0 0 0 0
+/// Cmp           : 1
+Value GPUMappingMaskAttr::getIsActiveIdPredicate(
+    OpBuilder &b, Value physicalLinearMappingId) const {
+  Location loc = physicalLinearMappingId.getLoc();
+  Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
+  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+  Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+  Value filtered = b.create<arith::AndIOp>(loc, mask, filter);
+  Value zero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
+  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, filtered, zero);
+}
+
 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getAddressSpace());
 }
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 63f87d9b5877e..a86fc47947130 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -351,16 +351,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
     seen.insert(map);
   }
 
-  auto isLinear = [](Attribute a) {
-    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
+  auto isLinear = [](DeviceMappingAttrInterface attr) {
+    return attr.isLinearMapping();
   };
-  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
-      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
+  if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
+      !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
     return definiteFailureHelper(
         transformOp, forallOp,
         "cannot mix linear and non-linear mapping modes");
   }
 
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      forallOp.getDeviceMaskingAttr();
+  if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
+      !forallOp.usesLinearMapping()) {
+    return definiteFailureHelper(
+        transformOp, forallOp,
+        "device masking is only available in linear mapping mode");
+  }
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -381,9 +390,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
   if (forallOp.getNumResults() > 0)
     return definiteFailureHelper(transformOp, forallOp,
                                  "only bufferized scf.forall can be mapped");
-  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
-                              forallOp.getMapping()->getValue().front())
-                              .isLinearMapping();
+  bool useLinearMapping = forallOp.usesLinearMapping();
   // TODO: This would be more natural with support for Optional<EnumParameter>
   // in GPUDeviceMappingAttr.
   int64_t maxNumMappingsSupported =
@@ -436,8 +443,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
   assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
          "requires statically sized, normalized forall op");
   SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
+  SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
+      forallOp.getDeviceMappingAttrs();
   SetVector<Attribute> forallMappingAttrs;
-  forallMappingAttrs.insert_range(forallOp.getMapping()->getValue());
+  forallMappingAttrs.insert_range(forallMappingAttrsVec);
   auto comparator = [](Attribute a, Attribute b) -> bool {
     return cast<DeviceMappingAttrInterface>(a).getMappingId() <
            cast<DeviceMappingAttrInterface>(b).getMappingId();
@@ -682,12 +691,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
 
   // The BlockIdBuilder adapts to whatever is thrown at it.
   bool useLinearMapping = false;
-  if (topLevelForallOp.getMapping()) {
-    auto mappingAttr = cast<DeviceMappingAttrInterface>(
-        topLevelForallOp.getMapping()->getValue().front());
-    useLinearMapping = mappingAttr.isLinearMapping();
-  }
-  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
+  if (topLevelForallOp.getMapping())
+    useLinearMapping = topLevelForallOp.usesLinearMapping();
+
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      topLevelForallOp.getDeviceMaskingAttr();
+  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+  assert((!*maybeMaskingAttr || useLinearMapping) &&
+         "masking requires linear mapping");
+
+  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
+                                      *maybeMaskingAttr);
 
   diag = mlir::transform::gpu::mapForallToBlocksImpl(
       rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +758,7 @@ static DiagnosedSilenceableFailure
 getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                    int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
-  auto mappingAttr = cast<DeviceMappingAttrInterface>(
-      forallOp.getMapping()->getValue().front());
+  auto mappingAttr = forallOp.getDeviceMappingAttrs().front();
   bool useLinearMapping = mappingAttr.isLinearMapping();
 
   // Sanity checks that may result in runtime verification errors.
@@ -768,21 +781,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
   if (!diag.succeeded())
     return diag;
 
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      forallOp.getDeviceMaskingAttr();
+  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+  assert((!*maybeMaskingAttr || useLinearMapping) &&
+         "masking requires linear mapping");
+
   // Start mapping.
   MLIRContext *ctx = forallOp.getContext();
   gpuIdBuilder =
       TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
           .Case([&](GPUWarpgroupMappingAttr) {
-            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
+                                         *maybeMaskingAttr);
           })
           .Case([&](GPUWarpMappingAttr) {
-            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
+                                    *maybeMaskingAttr);
           })
           .Case([&](GPUThreadMappingAttr) {
-            return GpuThreadIdBuilder(ctx, useLinearMapping);
+            return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
           })
           .Case([&](GPULaneMappingAttr) {
-            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
+                                    *maybeMaskingAttr);
           })
           .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
             llvm_unreachable("unknown mapping attribute");
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 795d643c05912..6f4ad27a72a94 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -44,7 +44,7 @@ using namespace mlir::transform::gpu;
 #define DEBUG_TYPE "gpu-transforms"
 
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
 /// Build predicates to filter execution by only the activeIds. Along each
@@ -120,10 +120,25 @@ static Value buildLinearId(RewriterBase &rewriter, Location loc,
 /// it in the basis of `forallMappingSizes`. The linear id builder returns an
 /// n-D vector of ids for indexing and 1-D size + id for predicate generation.
 template <typename ThreadOrBlockIdOp>
-static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
-  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
-                            ArrayRef<int64_t> forallMappingSizes,
-                            ArrayRef<int64_t> originalBasis) {
+static GpuIdBuilderFnType
+commonLinearIdBuilderFn(int64_t multiplicity = 1,
+                        DeviceMaskingAttrInterface mask = nullptr) {
+  auto res = [multiplicity, mask](RewriterBase &rewriter, Location loc,
+                                  ArrayRef<int64_t> forallMappingSizes,
+                                  ArrayRef<int64_t> originalBasis) {
+    // 0. Early-exit mask case.
+    if (mask) {
+      if (computeProduct(originalBasis) >
+          mask.getMaxNumPhysicalIds() * multiplicity) {
+        return IdBuilderResult{
+            /*errorMsg=*/std::string(
+                "mask representation too short to capture all physical ids: ") +
+                std::to_string(mask.getMaxNumPhysicalIds()),
+            /*mappingIdOps=*/{},
+            /*predicateOps=*/{}};
+      }
+    }
+
     // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
@@ -132,9 +147,25 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
 
     // 2. Compute scaledLinearId.
     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+    OpFoldResult scaledLinearIdOfr = affine::makeComposedFoldedAffineApply(
         rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
 
+    // 2.b. Adjust with mask if needed.
+    Value scaledLinearIdI64;
+    Value scaledLinearId =
+        getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr);
+    if (mask) {
+      // The mask operates on i64 physical ids; reuse the index value built
+      // above instead of materializing it a second time.
+      scaledLinearIdI64 = rewriter.create<arith::IndexCastUIOp>(
+          loc, rewriter.getI64Type(), scaledLinearId);
+      Value logicalLinearIdI64 =
+          mask.getLogicalLinearMappingId(rewriter, scaledLinearIdI64);
+      scaledLinearId = rewriter.create<arith::IndexCastUIOp>(
+          loc, rewriter.getIndexType(), logicalLinearIdI64);
+      LDBG("------adjusting linearId with mask: " << scaledLinearId);
+    }
+
     // 3. Compute remapped indices.
     SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
@@ -148,15 +179,23 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
           affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
     }
 
-    // 4. Handle predicates using physicalLinearId.
     std::string errorMsg;
     SmallVector<Value> predicateOps;
-    FailureOr<SmallVector<Value>> maybePredicateOps =
-        buildPredicates(rewriter, loc, physicalLinearId,
-                        computeProduct(forallMappingSizes) * multiplicity,
-                        computeProduct(originalBasis), errorMsg);
-    if (succeeded(maybePredicateOps))
-      predicateOps = *maybePredicateOps;
+    // 4. If mask present, it takes precedence to determine predication.
+    if (mask) {
+      Value isActiveIdPredicate =
+          mask.getIsActiveIdPredicate(rewriter, scaledLinearIdI64);
+      LDBG("------adjusting predicate with mask: " << isActiveIdPredicate);
+      predicateOps.push_back(isActiveIdPredicate);
+    } else {
+      // 4.b. Otherwise, handle predicates using physicalLinearId.
+      FailureOr<SmallVector<Value>> maybePredicateOps =
+          buildPredicates(rewriter, loc, physicalLinearId,
+                          computeProduct(forallMappingSizes) * multiplicity,
+                          computeProduct(originalBasis), errorMsg);
+      if (succeeded(maybePredicateOps))
+        predicateOps = *maybePredicateOps;
+    }
 
     return IdBuilderResult{/*errorMsg=*/errorMsg,
                            /*mappingIdOps=*/ids,
@@ -271,58 +310,67 @@ GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
   }
 }
 
-GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping)
+GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping,
+                                     DeviceMaskingAttrInterface mask)
     : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
         return GPUBlockMappingAttr::get(ctx, id);
       }) {
+  assert((!mask || useLinearMapping) && "mask requires linear mapping");
   idBuilder = useLinearMapping
-                  ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1)
+                  ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1, mask)
                   : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1);
 }
 
 GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                                             bool useLinearMapping)
+                                             bool useLinearMapping,
+                                             DeviceMaskingAttrInterface mask)
     : GpuIdBuilder(ctx, useLinearMapping,
                    [](MLIRContext *ctx, MappingId id) {
                      return GPUWarpgroupMappingAttr::get(ctx, id);
                    }),
       warpSize(warpSize) {
+  assert((!mask || useLinearMapping) && "mask requires linear mapping");
   idBuilder = useLinearMapping
                   ? commonLinearIdBuilderFn<ThreadIdOp>(
-                        /*multiplicity=*/kNumWarpsPerGroup * warpSize)
+                        /*multiplicity=*/kNumWarpsPerGroup * warpSize, mask)
                   : common3DIdBuilderFn<ThreadIdOp>(
                         /*multiplicity=*/kNumWarpsPerGroup * warpSize);
 }
 
 GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                                   bool useLinearMapping)
+                                   bool useLinearMapping,
+                                   DeviceMaskingAttrInterface mask)
     : GpuIdBuilder(ctx, useLinearMapping,
                    [](MLIRContext *ctx, MappingId id) {
                      return GPUWarpMappingAttr::get(ctx, id);
                    }),
       warpSize(warpSize) {
-  idBuilder =
-      useLinearMapping
-          ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize)
-          : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize);
+  assert((!mask || useLinearMapping) && "mask requires linear mapping");
+  idBuilder = useLinearMapping
+                  ? commonLinearIdBuilderFn<ThreadIdOp>(
+                        /*multiplicity=*/warpSize, mask)
+                  : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize);
 }
 
-GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
+GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping,
+                                       DeviceMaskingAttrInterface mask)
     : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
         return GPUThreadMappingAttr::get(ctx, id);
       }) {
-  idBuilder = useLinearMapping
-                  ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1)
-                  : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
+  idBuilder =
+      useLinearMapping
+          ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1, mask)
+          : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
 }
 
 GpuLaneIdBuilder::GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                                   bool unused)
+                                   bool unused, DeviceMaskingAttrInterface mask)
     : GpuIdBuilder(ctx, /*useLinearMapping=*/true,
                    [](MLIRContext *ctx, MappingId id) {
                      return GPULaneMappingAttr::get(ctx, id);
                    }),
       warpSize(warpSize) {
+  assert(!mask && "mask NYI for lanes, unclear it should be at all");
   idBuilder = laneIdBuilderFn(/*periodicity=*/warpSize);
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 79012dbd32f80..5a3bd984530db 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1175,13 +1175,11 @@ LogicalResult ForallOp::verify() {
       return emitOpError("type mismatch between ")
              << i << "-th output and corresponding block argument";
   if (getMapping().has_value() && !getMapping()->empty()) {
-    if (static_cast<int64_t>(getMapping()->size()) != numLoops)
+    if (static_cast<int64_t>(getDeviceMappingAttrs().size()) != numLoops)
       return emitOpError() << "mapping attribute size must match op rank";
-    for (auto map : getMapping()->getValue()) {
-      if (!isa<DeviceMappingAttrInterface>(map))
-        return emitOpError()
-               << getMappingAttrName() << " is not device mapping attribute";
-    }
+    if (failed(getDeviceMaskingAttr()))
+      return emitOpError() << getMappingAttrName()
+                           << " supports at most one device masking attribute";
   }
 
   // Verify mixed static/dynamic control variables.
@@ -1435,6 +1433,39 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
+SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
+  SmallVector<DeviceMappingAttrInterface> res;
+  if (!getMapping())
+    return res;
+  for (auto attr : getMapping()->getValue()) {
+    auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
+    if (m)
+      res.push_back(m);
+  }
+  return res;
+}
+
+FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
+  DeviceMaskingAttrInterface res;
+  if (!getMapping())
+    return res;
+  for (auto attr : getMapping()->getValue()) {
+    auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
+    if (m && res)
+      return failure();
+    if (m)
+      res = m;
+  }
+  return res;
+}
+
+bool ForallOp::usesLinearMapping() {
+  SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
+  if (ifaces.empty())
+    return false;
+  return ifaces.front().isLinearMapping();
+}
+
 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
   return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
 }
diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
index 8d7a1aa2a55fc..bc052a0230a8e 100644
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -405,6 +405,67 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @masking_mapping_attribute_requires_linear_mapping(
+    %x: memref<32xf32>, %y: memref<32xf32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<32xf32> {
+  %one = arith.constant 1 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i) in (%c7) {
+        %4 = memref.load %x[%i] : memref<32xf32>
+        %5 = memref.load %y[%i] : memref<32xf32>
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i] : memref<32xf32>
+     }  { mapping = [#gpu.warp<x>, #gpu.mask<0x33>] }
+    gpu.terminator
+  }
+
+  return %y : memref<32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{device masking is only available in linear mapping mode}}
+    transform.gpu.map_nested_forall_to_threads %funcop block_dims = [1, 1, 1] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @masking_mapping_attribute_too_many_physical_ids(
+    %x: memref<32xf32>, %y: memref<32xf32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<32xf32> {
+  %one = arith.constant 1 : index
+  %c99 = arith.constant 99 : index
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.forall (%i) in (%c99) {
+        %4 = memref.load %x[%i] : memref<32xf32>
+        %5 = memref.load %y[%i] : memref<32xf32>
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i] : memref<32xf32>
+     }  { mapping = [#gpu.thread<linear_dim_0>, #gpu.mask<0xff>] }
+    gpu.terminator
+  }
+
+  return %y : memref<32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{mask representation too short to capture all physical ids: 64}}
+    transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 1, 1] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func public @not_a_block_mapping_attribute(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
   scf.forall (%arg3, %arg4) in (1, 1) {
     linalg.matmul ins(%arg0, %arg1 : memref<32x32xf32>, memref<32x32xf32>) outs(%arg2 : memref<32x32xf32>)
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index fe5d451408355..71c7274aa7e67 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -754,3 +754,84 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0 *  128)>
+#map1 = affine_map<(d0) -> (d0 * 32)>
+
+// CHECK-DAG: #[[$MAPB:.*]] = affine_map<()[s0] -> (s0 * 128)>
+// CHECK-DAG: #[[$MAP_LIN_W:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 73) floordiv 32)>
+// CHECK-DAG: #[[$MAP_W0:.*]] = affine_map<()[s0] -> (s0 * 32 - (s0 floordiv 2) * 64)>
+// CHECK-DAG: #[[$MAP_W1:.*]] = affine_map<()[s0] -> ((s0 floordiv 2) * 32)>
+
+// CHECK-LABEL: func.func @simple_fill(
+func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+//   CHECK-DAG:   %[[C0_i64:.*]] = arith.constant 0 : i64
+//   CHECK-DAG:   %[[C1_i64:.*]] = arith.constant 1 : i64
+/// 0x2f1 is 753: the #gpu.mask of the inner loop, materialized as an i64 constant.
+//   CHECK-DAG:   %[[C753_i64:.*]] = arith.constant 753 : i64
+
+//       CHECK:   gpu.launch
+  scf.forall (%arg1) in (1) {
+//       CHECK:     %[[BIDX:.*]] = gpu.block_id  x
+//       CHECK:     %[[BLX:.*]] = affine.apply #[[$MAPB]]()[%[[BIDX]]]
+    %0 = affine.apply #map(%arg1)
+    %subview = memref.subview %arg0[%0] [128] [1] : memref<128xf32> to memref<128xf32, strided<[1], offset: ?>>
+
+    // %arg2 and %arg3 map to the 6 warps selected by the mask below and are
+    // turned into expressions involving threadIdx.x/y by map_nested_forall_to_threads.
+    // This results in an `if (mask & (1 << warp_id))` conditional on each warp.
+    scf.forall (%arg2, %arg3) in (2, 3) {
+      //       CHECK:     %[[TIDX:.*]] = gpu.thread_id  x
+      //       CHECK:     %[[TIDY:.*]] = gpu.thread_id  y
+
+      //       CHECK:     %[[LIN_W:.*]] = affine.apply #[[$MAP_LIN_W]]()[%[[TIDX]], %[[TIDY]]]
+      //
+      // Compute the logical warp id: popcount of the mask bits below the physical warp id
+      //       CHECK:     %[[LIN_W_i64:.*]] = arith.index_castui %[[LIN_W]] : index to i64
+      //       CHECK:     %[[TWO_POW_W:.*]] = arith.shli %[[C1_i64]], %[[LIN_W_i64]] : i64
+      //       CHECK:     %[[FILTER_TILL_W:.*]] = arith.subi %[[TWO_POW_W]], %[[C1_i64]] : i64
+      //       CHECK:     %[[ACTIVE_TILL_W:.*]] = arith.andi %[[FILTER_TILL_W]], %[[C753_i64]] : i64
+      //       CHECK:     %[[LOGICAL_ID_W_i64:.*]] = math.ctpop %[[ACTIVE_TILL_W]] : i64
+      //       CHECK:     %[[LOGICAL_ID_W:.*]] = arith.index_castui %[[LOGICAL_ID_W_i64]] : i64 to index
+      //
+      // Dynamically compute whether this warp is active by testing its bit in the mask
+      //       CHECK:     %[[IS_ACTIVE_W_MASK:.*]] = arith.andi %[[TWO_POW_W]], %[[C753_i64]] : i64
+      //       CHECK:     %[[IS_ACTIVE_W:.*]] = arith.cmpi ne, %[[IS_ACTIVE_W_MASK]], %[[C0_i64]] : i64
+      //       CHECK:     scf.if %[[IS_ACTIVE_W]] {
+
+      //       CHECK:       %[[W0:.*]] = affine.apply #[[$MAP_W0]]()[%[[LOGICAL_ID_W]]]
+      //       CHECK:       %[[W1:.*]] = affine.apply #[[$MAP_W1]]()[%[[LOGICAL_ID_W]]]
+      //       CHECK:       memref.subview %{{.*}}[%[[W0]]] [%[[W1]]]
+      %1 = affine.apply #map1(%arg2)
+      %2 = affine.apply #map1(%arg3)
+      %subview_0 = memref.subview %subview[%1] [%2] [1] : memref<128xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+      vector.transfer_write %cst, %subview_0[%c0] {in_bounds = [true]} : vector<32xf32>, memref<?xf32, strided<[1], offset: ?>>
+
+    // This could be obtained e.g. if a previous transformation mapped this loop
+    // to warps. This can also be written by hand as valid IR.
+    // This additionally uses the hex mask 0x2f1 = 0b10_1111_0001, i.e. 6 active warps {0, 4, 5, 6, 7, 9}.
+    } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>, #gpu.mask<0x2f1>]}
+
+    memref.copy %subview, %subview : memref<128xf32, strided<[1], offset: ?>> to memref<128xf32, strided<[1], offset: ?>>
+  } {mapping = [#gpu.block<x>]}
+  return %arg0 : memref<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    %gpu_launch = transform.gpu.map_forall_to_blocks %func generate_gpu_launch  // also creates the enclosing gpu.launch
+      : (!transform.any_op) -> !transform.any_op
+
+    // Map nested scf.forall ivs to GPU ids (laneid, threadid, warpid or
+    // warpgroupid) as selected by each loop's mapping attribute; 73x5x1 threads.
+    transform.gpu.map_nested_forall_to_threads %gpu_launch block_dims = [73, 5, 1]
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index b944852ceba3f..bb7958083e55c 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -684,6 +684,24 @@ func.func @forall_wrong_terminator_op() -> () {
 
 // -----
 
+func.func @at_most_one_masking_attribute(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+  %c1 = arith.constant 1 : index  // NOTE(review): unused in this test
+  %num_threads = arith.constant 100 : index
+
+  // expected-error @below {{"mapping" supports at most one device masking attribute}}
+  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
+      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
+          tensor<1xf32> into tensor<100xf32>
+      }
+  }  { mapping = [#gpu.thread<x>, #gpu.mask<0x1>, #gpu.mask<0x2>] }  // two #gpu.mask entries make this invalid
+
+  return
+}
+
+// -----
+
 func.func @switch_wrong_case_count(%arg0: index) {
   // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}}
   "scf.index_switch"(%arg0) ({

>From 0cb1bc9465fba89e7cf69cb5a0d53fe6be4d97c1 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <Nico.Vasilache at amd.com>
Date: Mon, 7 Jul 2025 15:47:01 +0200
Subject: [PATCH 4/4] Update
 mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td

Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
---
 mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
index 353aaf05bee0c..5a07caabeee34 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -92,7 +92,7 @@ def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the maximal number of pysical ids supported.
+        Return the maximal number of physical ids supported.
         This is to account for temporary implementation limitations (e.g. i64)
         and fail gracefully with actionnable error messages.
       }],



More information about the Mlir-commits mailing list