[Mlir-commits] [mlir] 7dc7790 - [mlir] Fix support for lowering non-32-bit affine reductions.

Alex Zinenko llvmlistbot at llvm.org
Tue Apr 6 05:00:24 PDT 2021


Author: Alex Zinenko
Date: 2021-04-06T14:00:15+02:00
New Revision: 7dc7790ec52efe799211edde9114ba5e467ccb36

URL: https://github.com/llvm/llvm-project/commit/7dc7790ec52efe799211edde9114ba5e467ccb36
DIFF: https://github.com/llvm/llvm-project/commit/7dc7790ec52efe799211edde9114ba5e467ccb36.diff

LOG: [mlir] Fix support for lowering non-32-bit affine reductions.

The existing implementation was always creating 32-bit constants for
floating-point and integer reductions regardless of the actual type, which
resulted in invalid IR being generated for any types other than f32 and i32
when lowering affine.parallel to SCF. Use the actual type instead.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D99942

Added: 
    

Modified: 
    mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
    mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 1ad07b2f6e068..c9e1754acce94 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -368,17 +368,19 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
 };
 
 /// Returns the identity value associated with an AtomicRMWKind op.
-static Value getIdentityValue(AtomicRMWKind op, OpBuilder &builder,
-                              Location loc) {
+static Value getIdentityValue(AtomicRMWKind op, Type resultType,
+                              OpBuilder &builder, Location loc) {
   switch (op) {
   case AtomicRMWKind::addf:
-    return builder.create<ConstantOp>(loc, builder.getF32FloatAttr(0));
+    return builder.create<ConstantOp>(loc, builder.getFloatAttr(resultType, 0));
   case AtomicRMWKind::addi:
-    return builder.create<ConstantOp>(loc, builder.getI32IntegerAttr(0));
+    return builder.create<ConstantOp>(loc,
+                                      builder.getIntegerAttr(resultType, 0));
   case AtomicRMWKind::mulf:
-    return builder.create<ConstantOp>(loc, builder.getF32FloatAttr(1));
+    return builder.create<ConstantOp>(loc, builder.getFloatAttr(resultType, 1));
   case AtomicRMWKind::muli:
-    return builder.create<ConstantOp>(loc, builder.getI32IntegerAttr(1));
+    return builder.create<ConstantOp>(loc,
+                                      builder.getIntegerAttr(resultType, 1));
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");
@@ -453,15 +455,18 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
     // scf.parallel handles the reduction operation differently unlike
     // affine.parallel.
     ArrayRef<Attribute> reductions = op.reductions().getValue();
-    for (Attribute reduction : reductions) {
+    for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
       // For each of the reduction operations get the identity values for
       // initialization of the result values.
+      Attribute reduction = std::get<0>(pair);
+      Type resultType = std::get<1>(pair);
       Optional<AtomicRMWKind> reductionOp = symbolizeAtomicRMWKind(
           static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
       assert(reductionOp.hasValue() &&
              "Reduction operation cannot be of None Type");
       AtomicRMWKind reductionOpValue = reductionOp.getValue();
-      identityVals.push_back(getIdentityValue(reductionOpValue, rewriter, loc));
+      identityVals.push_back(
+          getIdentityValue(reductionOpValue, resultType, rewriter, loc));
     }
     parOp = rewriter.create<scf::ParallelOp>(
         loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,

diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 4e358ae70134c..b02c96ef718e1 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -826,3 +826,81 @@ func @affine_parallel_with_reductions(%arg0: memref<3x3xf32>, %arg1: memref<3x3x
 // CHECK-NEXT:    }
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
+
+/////////////////////////////////////////////////////////////////////
+
+func @affine_parallel_with_reductions_f64(%arg0: memref<3x3xf64>, %arg1: memref<3x3xf64>) -> (f64, f64) {
+  %0:2 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addf", "mulf") -> (f64, f64) {
+            %1 = affine.load %arg0[%kx, %ky] : memref<3x3xf64>
+            %2 = affine.load %arg1[%kx, %ky] : memref<3x3xf64>
+            %3 = mulf %1, %2 : f64
+            %4 = addf %1, %2 : f64
+            affine.yield %3, %4 : f64, f64
+          }
+  return %0#0, %0#1 : f64, f64
+}
+// CHECK-LABEL: @affine_parallel_with_reductions_f64
+// CHECK:  %[[LOWER_1:.*]] = constant 0 : index
+// CHECK:  %[[LOWER_2:.*]] = constant 0 : index
+// CHECK:  %[[UPPER_1:.*]] = constant 2 : index
+// CHECK:  %[[UPPER_2:.*]] = constant 2 : index
+// CHECK:  %[[STEP_1:.*]] = constant 1 : index
+// CHECK:  %[[STEP_2:.*]] = constant 1 : index
+// CHECK:  %[[INIT_1:.*]] = constant 0.000000e+00 : f64
+// CHECK:  %[[INIT_2:.*]] = constant 1.000000e+00 : f64
+// CHECK:  %[[RES:.*]] = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step (%[[STEP_1]], %[[STEP_2]]) init (%[[INIT_1]], %[[INIT_2]]) -> (f64, f64) {
+// CHECK:    %[[VAL_1:.*]] = memref.load
+// CHECK:    %[[VAL_2:.*]] = memref.load
+// CHECK:    %[[PRODUCT:.*]] = mulf
+// CHECK:    %[[SUM:.*]] = addf
+// CHECK:    scf.reduce(%[[PRODUCT]]) : f64 {
+// CHECK:    ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64):
+// CHECK:      %[[RES:.*]] = addf
+// CHECK:      scf.reduce.return %[[RES]] : f64
+// CHECK:    }
+// CHECK:    scf.reduce(%[[SUM]]) : f64 {
+// CHECK:    ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64):
+// CHECK:      %[[RES:.*]] = mulf
+// CHECK:      scf.reduce.return %[[RES]] : f64
+// CHECK:    }
+// CHECK:    scf.yield
+// CHECK:  }
+
+/////////////////////////////////////////////////////////////////////
+
+func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: memref<3x3xi64>) -> (i64, i64) {
+  %0:2 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addi", "muli") -> (i64, i64) {
+            %1 = affine.load %arg0[%kx, %ky] : memref<3x3xi64>
+            %2 = affine.load %arg1[%kx, %ky] : memref<3x3xi64>
+            %3 = muli %1, %2 : i64
+            %4 = addi %1, %2 : i64
+            affine.yield %3, %4 : i64, i64
+          }
+  return %0#0, %0#1 : i64, i64
+}
+// CHECK-LABEL: @affine_parallel_with_reductions_i64
+// CHECK:  %[[LOWER_1:.*]] = constant 0 : index
+// CHECK:  %[[LOWER_2:.*]] = constant 0 : index
+// CHECK:  %[[UPPER_1:.*]] = constant 2 : index
+// CHECK:  %[[UPPER_2:.*]] = constant 2 : index
+// CHECK:  %[[STEP_1:.*]] = constant 1 : index
+// CHECK:  %[[STEP_2:.*]] = constant 1 : index
+// CHECK:  %[[INIT_1:.*]] = constant 0 : i64
+// CHECK:  %[[INIT_2:.*]] = constant 1 : i64
+// CHECK:  %[[RES:.*]] = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step (%[[STEP_1]], %[[STEP_2]]) init (%[[INIT_1]], %[[INIT_2]]) -> (i64, i64) {
+// CHECK:    %[[VAL_1:.*]] = memref.load
+// CHECK:    %[[VAL_2:.*]] = memref.load
+// CHECK:    %[[PRODUCT:.*]] = muli
+// CHECK:    %[[SUM:.*]] = addi
+// CHECK:    scf.reduce(%[[PRODUCT]]) : i64 {
+// CHECK:    ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64):
+// CHECK:      %[[RES:.*]] = addi
+// CHECK:      scf.reduce.return %[[RES]] : i64
+// CHECK:    }
+// CHECK:    scf.reduce(%[[SUM]]) : i64 {
+// CHECK:    ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64):
+// CHECK:      %[[RES:.*]] = muli
+// CHECK:      scf.reduce.return %[[RES]] : i64
+// CHECK:    }
+// CHECK:    scf.yield
+// CHECK:  }


        


More information about the Mlir-commits mailing list