[Mlir-commits] [mlir] 993a38f - [MLIR][Affine] Extend getVectorReductionOp to support xor/maxnumf/minnumf (#163310)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 3 20:52:55 PST 2025


Author: Lee Wei
Date: 2025-11-04T12:52:51+08:00
New Revision: 993a38fa539d23f83711a0e07d3cc40a0947ec7e

URL: https://github.com/llvm/llvm-project/commit/993a38fa539d23f83711a0e07d3cc40a0947ec7e
DIFF: https://github.com/llvm/llvm-project/commit/993a38fa539d23f83711a0e07d3cc40a0947ec7e.diff

LOG: [MLIR][Affine] Extend getVectorReductionOp to support xor/maxnumf/minnumf (#163310)

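This PR extends the `getVectorReductionOp` function, which is used by
the affine vectorizer, to support `xor/maxnumf/minnumf` reduction
operations, and extends `getSupportedReduction` in the affine analysis
to recognize the corresponding `arith.xori`, `arith.maxnumf` and
`arith.minnumf` ops.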

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/ConvertToSPIRV/vector.mlir
    mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 4d2d8738aa4ad..3d1a73417d1ea 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -66,9 +66,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
           .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
           .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
           .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
+          .Case([](arith::XOrIOp) { return arith::AtomicRMWKind::xori; })
+          .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
+          .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
           .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
-            // TODO: AtomicRMW supports other kinds of reductions this is
-            // currently not detecting, add those when the need arises.
             return std::nullopt;
           });
   if (!maybeKind)

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae3423c40040d..daef0ba02100a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -717,7 +717,15 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   case arith::AtomicRMWKind::ori:
     return vector::ReductionOp::create(builder, vector.getLoc(),
                                        CombiningKind::OR, vector);
-  // TODO: Add remaining reduction operations.
+  case arith::AtomicRMWKind::minnumf:
+    return vector::ReductionOp::create(builder, vector.getLoc(),
+                                       CombiningKind::MINNUMF, vector);
+  case arith::AtomicRMWKind::maxnumf:
+    return vector::ReductionOp::create(builder, vector.getLoc(),
+                                       CombiningKind::MAXNUMF, vector);
+  case arith::AtomicRMWKind::xori:
+    return vector::ReductionOp::create(builder, vector.getLoc(),
+                                       CombiningKind::XOR, vector);
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");
     break;

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index a75f30d57fa74..cd8cfc8736915 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -275,6 +275,42 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL:   spirv.func @reduction_minnumf(
+// CHECK-SAME:      %[[V:.*]]: vector<3xf32>,
+// CHECK-SAME:      %[[S:.*]]: f32) -> f32 "None" {
+// CHECK:           %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK:           %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK:           %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK:           %[[MIN0:.*]] = spirv.GL.FMin %[[S0]], %[[S1]] : f32
+// CHECK:           %[[MIN1:.*]] = spirv.GL.FMin %[[MIN0]], %[[S2]] : f32
+// CHECK:           %[[MIN2:.*]] = spirv.GL.FMin %[[MIN1]], %[[S]] : f32
+// CHECK:           spirv.ReturnValue %[[MIN2]] : f32
+// CHECK:         }
+func.func @reduction_minnumf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minnumf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
+// CHECK-LABEL:   spirv.func @reduction_maxnumf(
+// CHECK-SAME:      %[[V:.*]]: vector<3xf32>,
+// CHECK-SAME:      %[[S:.*]]: f32) -> f32 "None" {
+// CHECK:           %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK:           %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK:           %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK:           %[[MAX0:.*]] = spirv.GL.FMax %[[S0]], %[[S1]] : f32
+// CHECK:           %[[MAX1:.*]] = spirv.GL.FMax %[[MAX0]], %[[S2]] : f32
+// CHECK:           %[[MAX2:.*]] = spirv.GL.FMax %[[MAX1]], %[[S]] : f32
+// CHECK:           spirv.ReturnValue %[[MAX2]] : f32
+// CHECK:         }
+func.func @reduction_maxnumf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxnumf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_maxsi
 //  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>

diff  --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
index b616632a6fe24..b062736575ad7 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
@@ -243,6 +243,106 @@ func.func @vecdim_reduction_ori(%in: memref<256x512xi32>, %out: memref<256xi32>)
 // CHECK:         affine.store %[[final_red]], %{{.*}} : memref<256xi32>
 // CHECK:       }
 
+// -----
+
+func.func @vecdim_reduction_xori(%in: memref<256x512xi32>, %out: memref<256xi32>) {
+ %cst = arith.constant 0 : i32
+ affine.for %i = 0 to 256 {
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xi32>
+     %xor = arith.xori %red_iter, %ld : i32
+     affine.yield %xor : i32
+   }
+   affine.store %final_red, %out[%i] : memref<256xi32>
+ }
+ return
+}
+
+// CHECK-LABEL:   func.func @vecdim_reduction_xori(
+// CHECK-SAME:      %[[input:.*]]: memref<256x512xi32>,
+// CHECK-SAME:      %[[output:.*]]: memref<256xi32>) {
+// CHECK:           %[[cst:.*]] = arith.constant 0 : i32
+// CHECK:           affine.for %{{.*}} = 0 to 256 {
+// CHECK:             %[[vzero:.*]] = arith.constant dense<0> : vector<128xi32>
+// CHECK:             %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xi32>) {
+// CHECK:               %[[poison:.*]] = ub.poison : i32
+// CHECK:               %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xi32>, vector<128xi32>
+// CHECK:               %[[xor:.*]] = arith.xori %[[red_iter]], %[[ld]] : vector<128xi32>
+// CHECK:               affine.yield %[[xor]] : vector<128xi32>
+// CHECK:             }
+// CHECK:             %[[final_red:.*]] = vector.reduction <xor>, %[[vred]] : vector<128xi32> into i32
+// CHECK:             affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xi32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @vecdim_reduction_minnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0xFF800000 : f32
+ affine.for %i = 0 to 256 {
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+     %min = arith.minnumf %red_iter, %ld : f32
+     affine.yield %min : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL:   func.func @vecdim_reduction_minnumf(
+// CHECK-SAME:      %[[input:.*]]: memref<256x512xf32>,
+// CHECK-SAME:      %[[output:.*]]: memref<256xf32>) {
+// CHECK:           %[[cst:.*]] = arith.constant 0xFF800000 : f32
+// CHECK:           affine.for %{{.*}} = 0 to 256 {
+// CHECK:             %[[vzero:.*]] = arith.constant dense<0x7FC00000> : vector<128xf32>
+// CHECK:             %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
+// CHECK:               %[[poison:.*]] = ub.poison : f32
+// CHECK:               %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
+// CHECK:               %[[min:.*]] = arith.minnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK:               affine.yield %[[min]] : vector<128xf32>
+// CHECK:             }
+// CHECK:             %[[red_scalar:.*]] = vector.reduction <minnumf>, %[[vred]] : vector<128xf32> into f32
+// CHECK:             %[[final_red:.*]] = arith.minnumf %[[red_scalar]], %[[cst]] : f32
+// CHECK:             affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @vecdim_reduction_maxnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0xFF800000 : f32
+ affine.for %i = 0 to 256 {
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+     %max = arith.maxnumf %red_iter, %ld : f32
+     affine.yield %max : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL:   func.func @vecdim_reduction_maxnumf(
+// CHECK-SAME:      %[[input:.*]]: memref<256x512xf32>,
+// CHECK-SAME:      %[[output:.*]]: memref<256xf32>) {
+// CHECK:           %[[cst:.*]] = arith.constant 0xFF800000 : f32
+// CHECK:           affine.for %{{.*}} = 0 to 256 {
+// CHECK:             %[[vzero:.*]] = arith.constant dense<0xFFC00000> : vector<128xf32>
+// CHECK:             %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
+// CHECK:               %[[poison:.*]] = ub.poison : f32
+// CHECK:               %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
+// CHECK:               %[[max:.*]] = arith.maxnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK:               affine.yield %[[max]] : vector<128xf32>
+// CHECK:             }
+// CHECK:             %[[red_scalar:.*]] = vector.reduction <maxnumf>, %[[vred]] : vector<128xf32> into f32
+// CHECK:             %[[final_red:.*]] = arith.maxnumf %[[red_scalar]], %[[cst]] : f32
+// CHECK:             affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
 
 // -----
 


        


More information about the Mlir-commits mailing list