[Mlir-commits] [mlir] [mlir][gpu] Allow subgroup reductions over 1-d vector types (PR #76015)

Lei Zhang llvmlistbot at llvm.org
Wed Dec 20 15:41:05 PST 2023


================
@@ -591,10 +593,12 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto opType = op.getOp();
-    auto result =
-        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), opType,
-                            /*isGroup*/ false, op.getUniform());
+    if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
+      return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
+
+    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
+                                      adaptor.getOp(),
+                                      /*isGroup*/ false, adaptor.getUniform());
----------------
antiagainst wrote:

Nit: `/*isGroup=*/` (with the trailing `=`, per the LLVM argument-comment style).

https://github.com/llvm/llvm-project/pull/76015

