[Mlir-commits] [mlir] ad3a078 - Fix linalg.dot over boolean tensors.

Johannes Reifferscheid llvmlistbot at llvm.org
Tue Jul 12 00:08:55 PDT 2022


Author: Johannes Reifferscheid
Date: 2022-07-12T09:08:45+02:00
New Revision: ad3a078745d973066cd9ea7b2199c5c666b4cd2a

URL: https://github.com/llvm/llvm-project/commit/ad3a078745d973066cd9ea7b2199c5c666b4cd2a
DIFF: https://github.com/llvm/llvm-project/commit/ad3a078745d973066cd9ea7b2199c5c666b4cd2a.diff

LOG: Fix linalg.dot over boolean tensors.

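`linalg.dot` is currently miscompiled for boolean (`i1`) operands: the accumulation uses `arith.addi`, which wraps on `i1` (behaving like XOR), instead of `arith.ori`.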

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D129292

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index a85b1f0ab4051..85298b0f6b95b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -129,7 +129,9 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
   // TODO: more fields than add/mul.
   if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
       !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
-      !isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
+      !isAddMul<complex::AddOp, complex::MulOp>(
+          linalgOp->getRegion(0).front()) &&
+      !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
     return MatchContractionResult::NotAddMul;
   return MatchContractionResult::Success;
 }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cd71981ca6b25..1ce7d4dab1f13 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -325,6 +325,8 @@ class RegionBuilderHelper {
     bool allComplex = isComplex(arg0) && isComplex(arg1);
     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
     bool allInteger = isInteger(arg0) && isInteger(arg1);
+    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+                   arg1.getType().getIntOrFloatBitWidth() == 1;
     if (!allComplex && !allFloatingPoint && !allInteger)
       llvm_unreachable("unsupported non numeric type");
     OpBuilder builder = getBuilder();
@@ -334,18 +336,24 @@ class RegionBuilderHelper {
         return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::sub:
       if (allComplex)
         return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        llvm_unreachable("unsupported operation: sub with bools");
       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::mul:
       if (allComplex)
         return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_signed:
       assert(!allComplex);

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index f35d76c700534..208664aa03bc6 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -137,6 +137,32 @@ func.func @dot(%arg0: memref<?xi8>, %M: index) {
 //       CHECKPARALLEL:   store %[[res]], %[[C]][] : memref<f32>
 
 
+func.func @dot_int(%arg0: memref<?xi32>, %arg1: memref<?xi32>,
+                   %arg3: memref<i32>) {
+  // Verifies that we use the correct arith operations for integers.
+  linalg.dot ins(%arg0, %arg1 : memref<?xi32>, memref<?xi32>)
+             outs(%arg3 : memref<i32>)
+  return
+}
+// CHECK-LABEL: func @dot_int(
+//       CHECK:   %[[inc:.*]] = arith.muli {{.*}} : i32
+//  CHECK-NEXT:   %[[res:.*]] = arith.addi {{.*}}, %[[inc]] : i32
+//  CHECK-NEXT:   store %[[res]], {{.*}} : memref<i32>
+
+
+func.func @dot_bool(%arg0: memref<?xi1>, %arg1: memref<?xi1>,
+                    %arg3: memref<i1>) {
+  // Verifies that we use the correct (saturating) arith operations for booleans.
+  linalg.dot ins(%arg0, %arg1 : memref<?xi1>, memref<?xi1>)
+             outs(%arg3 : memref<i1>)
+  return
+}
+// CHECK-LABEL: func @dot_bool(
+//       CHECK:   %[[inc:.*]] = arith.andi {{.*}} : i1
+//  CHECK-NEXT:   %[[res:.*]] = arith.ori {{.*}}, %[[inc]] : i1
+//  CHECK-NEXT:   store %[[res]], {{.*}} : memref<i1>
+
+
 func.func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
   linalg.dot ins(%arg0, %arg1 : memref<?xf32, offset: ?, strides: [1]>,
                                 memref<?xf32, offset: ?, strides: [1]>)


        


More information about the Mlir-commits mailing list