[Mlir-commits] [mlir] [mlir][gpu] Add 'cluster_stride' attribute to gpu.subgroup_reduce (PR #107142)

Jakub Kuderski llvmlistbot at llvm.org
Wed Sep 4 10:44:37 PDT 2024


================
@@ -140,34 +137,68 @@ struct ScalarizeSingleElementReduce final
     Location loc = op.getLoc();
     Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+        op.getClusterStride());
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
     return success();
   }
 };
 
+/// Cluster parameters of a gpu.subgroup_reduce op, filled in by
+/// getAndValidateClusterInfo.
+struct ClusterInfo {
+  /// Cluster stride; presumably copied from the op's cluster_stride
+  /// attribute (asserted to be a power of two) — confirm.
+  unsigned clusterStride;
+  /// Effective cluster size: the op's cluster_size if present, otherwise the
+  /// subgroup size.
+  unsigned clusterSize;
+  /// Subgroup size supplied by the caller; asserted to be a power of two.
+  unsigned subgroupSize;
+};
+
+static FailureOr<ClusterInfo>
+getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
+  ClusterInfo res;
+
+  assert(llvm::isPowerOf2_32(subgroupSize));
+  res.subgroupSize = subgroupSize;
+
+  std::optional<uint32_t> clusterSize = op.getClusterSize();
+  assert(!clusterSize ||
+         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this
+  if (clusterSize && *clusterSize > subgroupSize)
+    return op.emitOpError()
+           << "cluster size " << *clusterSize
+           << " is greater than subgroup size " << subgroupSize;
+  res.clusterSize = clusterSize.value_or(subgroupSize); // Effective size
+
+  auto clusterStride = op.getClusterStride();
+  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this
----------------
kuhar wrote:

```suggestion
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
```

https://github.com/llvm/llvm-project/pull/107142


More information about the Mlir-commits mailing list