[Mlir-commits] [mlir] ad3a078 - Fix linalg.dot over boolean tensors.
Johannes Reifferscheid
llvmlistbot at llvm.org
Tue Jul 12 00:08:55 PDT 2022
Author: Johannes Reifferscheid
Date: 2022-07-12T09:08:45+02:00
New Revision: ad3a078745d973066cd9ea7b2199c5c666b4cd2a
URL: https://github.com/llvm/llvm-project/commit/ad3a078745d973066cd9ea7b2199c5c666b4cd2a
DIFF: https://github.com/llvm/llvm-project/commit/ad3a078745d973066cd9ea7b2199c5c666b4cd2a.diff
LOG: Fix linalg.dot over boolean tensors.
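linalg.dot is currently miscompiled for boolean (i1) operands: the reduction uses arith.addi instead of arith.ori, and addi on i1 wraps (it computes XOR), so e.g. accumulating two true products yields false instead of true.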
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D129292
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/loops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index a85b1f0ab4051..85298b0f6b95b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -129,7 +129,9 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
// TODO: more fields than add/mul.
if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
!isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
- !isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
+ !isAddMul<complex::AddOp, complex::MulOp>(
+ linalgOp->getRegion(0).front()) &&
+ !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
return MatchContractionResult::NotAddMul;
return MatchContractionResult::Success;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cd71981ca6b25..1ce7d4dab1f13 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -325,6 +325,8 @@ class RegionBuilderHelper {
bool allComplex = isComplex(arg0) && isComplex(arg1);
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
bool allInteger = isInteger(arg0) && isInteger(arg1);
+ bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+ arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
OpBuilder builder = getBuilder();
@@ -334,18 +336,24 @@ class RegionBuilderHelper {
return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::sub:
if (allComplex)
return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ llvm_unreachable("unsupported operation: sub with bools");
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allComplex)
return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
assert(!allComplex);
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index f35d76c700534..208664aa03bc6 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -137,6 +137,32 @@ func.func @dot(%arg0: memref<?xi8>, %M: index) {
// CHECKPARALLEL: store %[[res]], %[[C]][] : memref<f32>
+func.func @dot_int(%arg0: memref<?xi32>, %arg1: memref<?xi32>,
+ %arg3: memref<i32>) {
+ // Verifies that we use the correct arith operations for integers.
+ linalg.dot ins(%arg0, %arg1 : memref<?xi32>, memref<?xi32>)
+ outs(%arg3 : memref<i32>)
+ return
+}
+// CHECK-LABEL: func @dot_int(
+// CHECK: %[[inc:.*]] = arith.muli {{.*}} : i32
+// CHECK-NEXT: %[[res:.*]] = arith.addi {{.*}}, %[[inc]] : i32
+// CHECK-NEXT: store %[[res]], {{.*}} : memref<i32>
+
+
+func.func @dot_bool(%arg0: memref<?xi1>, %arg1: memref<?xi1>,
+ %arg3: memref<i1>) {
+ // Verifies that we use the correct (saturating) arith operations for booleans.
+ linalg.dot ins(%arg0, %arg1 : memref<?xi1>, memref<?xi1>)
+ outs(%arg3 : memref<i1>)
+ return
+}
+// CHECK-LABEL: func @dot_bool(
+// CHECK: %[[inc:.*]] = arith.andi {{.*}} : i1
+// CHECK-NEXT: %[[res:.*]] = arith.ori {{.*}}, %[[inc]] : i1
+// CHECK-NEXT: store %[[res]], {{.*}} : memref<i1>
+
+
func.func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
linalg.dot ins(%arg0, %arg1 : memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>)
More information about the Mlir-commits mailing list