[Mlir-commits] [mlir] 9580468 - [mlir][affine] Enforce each result type to match Reduction ops in affine.parallel verifier

Mehdi Amini llvmlistbot at llvm.org
Sun Oct 1 14:24:31 PDT 2023


Author: Zhenyan Zhu
Date: 2023-10-01T14:24:17-07:00
New Revision: 9580468302a1c8f6236a121163c9087ac4e02cfe

URL: https://github.com/llvm/llvm-project/commit/9580468302a1c8f6236a121163c9087ac4e02cfe
DIFF: https://github.com/llvm/llvm-project/commit/9580468302a1c8f6236a121163c9087ac4e02cfe.diff

LOG:  [mlir][affine] Enforce each result type to match Reduction ops in affine.parallel verifier

This patch updates AffineParallelOp::verify() to check that each result type
matches its corresponding reduction op (e.g., the result type must be a
`FloatType` if the reduction attribute is `addf`).

Without this check, `--lower-affine` crashes on an affine.parallel whose result
type does not match its reduction attribute, for example:

```
      %128 = affine.parallel (%arg2, %arg3) = (0, 0) to (8, 7) reduce ("maxf") -> (memref<8x7xf32>) {
        %alloc_33 = memref.alloc() : memref<8x7xf32>
        affine.yield %alloc_33 : memref<8x7xf32>
      }
```
Running `mlir-opt --lower-affine` on this input crashes with a type-cast assertion:

```
Assertion failed: (isa<To>(Val) && "cast<Ty>() argument of incompatible type!"), function cast, file Casting.h, line 572.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: mlir-opt --lower-affine temp.mlir
 #0 0x0000000102a18f18 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/workspacebin/mlir-opt+0x1002f8f18)
 #1 0x0000000102a171b4 llvm::sys::RunSignalHandlers() (/workspacebin/mlir-opt+0x1002f71b4)
 #2 0x0000000102a195c4 SignalHandler(int) (/workspacebin/mlir-opt+0x1002f95c4)
 #3 0x00000001be7894c4 (/usr/lib/system/libsystem_platform.dylib+0x1803414c4)
 #4 0x00000001be771ee0 (/usr/lib/system/libsystem_pthread.dylib+0x180329ee0)
 #5 0x00000001be6ac340 (/usr/lib/system/libsystem_c.dylib+0x180264340)
 #6 0x00000001be6ab754 (/usr/lib/system/libsystem_c.dylib+0x180263754)
 #7 0x0000000106864790 mlir::arith::getIdentityValueAttr(mlir::arith::AtomicRMWKind, mlir::Type, mlir::OpBuilder&, mlir::Location) (.cold.4) (/workspacebin/mlir-opt+0x104144790)
 #8 0x0000000102ba66ac mlir::arith::getIdentityValueAttr(mlir::arith::AtomicRMWKind, mlir::Type, mlir::OpBuilder&, mlir::Location) (/workspacebin/mlir-opt+0x1004866ac)
 #9 0x0000000102ba6910 mlir::arith::getIdentityValue(mlir::arith::AtomicRMWKind, mlir::Type, mlir::OpBuilder&, mlir::Location) (/workspacebin/mlir-opt+0x100486910)
...
```

Fixes #64068

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D157985

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c61bd566c7676f1..113f4cfc31c104b 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3915,6 +3915,49 @@ void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
   setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
 }
 
+/// Returns true if `resultType` is a legal result type for an affine.parallel
+/// reduction of kind `op`: floating-point kinds require a FloatType, integer
+/// kinds an IntegerType, `assign` accepts any type, and any other kind is
+/// rejected.
+///
+/// Lowering affine.parallel (--lower-affine) builds the reduction's identity
+/// value with arith::getIdentityValueAttr, which casts `resultType` to the
+/// type implied by `op` and asserts on a mismatch. Checking here turns that
+/// crash into a verifier error.
+///
+/// Integer kinds deliberately accept any IntegerType, including signless ones
+/// such as `i32`. The signedness of maxs/mins/maxu/minu is carried by the
+/// reduction kind, not the type, and IntegerType::isSigned()/isUnsigned() are
+/// both false for signless types; requiring them would reject every ordinary
+/// integer max/min reduction.
+static bool isResultTypeMatchAtomicRMWKind(Type resultType,
+                                           arith::AtomicRMWKind op) {
+  switch (op) {
+  // Floating-point reductions.
+  case arith::AtomicRMWKind::addf:
+  case arith::AtomicRMWKind::mulf:
+  case arith::AtomicRMWKind::maximumf:
+  case arith::AtomicRMWKind::minimumf:
+    return isa<FloatType>(resultType);
+  // Integer reductions (signed and unsigned variants alike, see above).
+  case arith::AtomicRMWKind::addi:
+  case arith::AtomicRMWKind::muli:
+  case arith::AtomicRMWKind::maxs:
+  case arith::AtomicRMWKind::mins:
+  case arith::AtomicRMWKind::maxu:
+  case arith::AtomicRMWKind::minu:
+  case arith::AtomicRMWKind::ori:
+  case arith::AtomicRMWKind::andi:
+    return isa<IntegerType>(resultType);
+  // `assign` places no constraint on the result type.
+  case arith::AtomicRMWKind::assign:
+    return true;
+  // Any other kind is not supported as an affine.parallel reduction.
+  default:
+    return false;
+  }
+}
+
 LogicalResult AffineParallelOp::verify() {
   auto numDims = getNumDims();
   if (getLowerBoundsGroups().getNumElements() != numDims ||
@@ -3946,11 +3989,16 @@ LogicalResult AffineParallelOp::verify() {
   if (getReductions().size() != getNumResults())
     return emitOpError("a reduction must be specified for each output");
 
-  // Verify reduction  ops are all valid
-  for (Attribute attr : getReductions()) {
+  // Verify reduction ops are all valid and each result type matches reduction
+  // ops
+  for (auto it : llvm::enumerate((getReductions()))) {
+    Attribute attr = it.value();
     auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
     if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
       return emitOpError("invalid reduction attribute");
+    auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
+    if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
+      return emitOpError("result type cannot match reduction attribute");
   }
 
   // Verify that the bound operands are valid dimension/symbols.

diff  --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 1dc3451ed7db87c..1bcb6fc4a365ddf 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -297,6 +297,18 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 // -----
 
+func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = memref.alloc() : memref<100x100xi32>  // i32 data; "minimumf" below requires a float result.
+  // expected-error @+1 {{result type cannot match reduction attribute}}
+  %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minimumf") -> (i32) {
+    %2 = affine.load %0[%i, %j] : memref<100x100xi32>
+    affine.yield %2 : i32
+  }
+  return
+}
+
+// -----
+
 func.func @vector_load_invalid_vector_type() {
   %0 = memref.alloc() : memref<100xf32>
   affine.for %i0 = 0 to 16 step 8 {


        


More information about the Mlir-commits mailing list