[Mlir-commits] [mlir] cf29d0a - [mlir][tosa]Create a check for i64 input in apply_scale lowering in TosaToArith
Rob Suderman
llvmlistbot at llvm.org
Thu Sep 7 11:10:36 PDT 2023
Author: Natasha Kononenko
Date: 2023-09-07T11:08:52-07:00
New Revision: cf29d0a73764e81bf2cb41474f49926f5ef4ddd1
URL: https://github.com/llvm/llvm-project/commit/cf29d0a73764e81bf2cb41474f49926f5ef4ddd1
DIFF: https://github.com/llvm/llvm-project/commit/cf29d0a73764e81bf2cb41474f49926f5ef4ddd1.diff
LOG: [mlir][tosa] Create a check for i64 input in apply_scale lowering in TosaToArith
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D159473
Added:
Modified:
mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index c025fb9e1367d9..50e57682a2dc8d 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -81,7 +81,9 @@ class ApplyScaleGenericOpConverter
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
// Compute the multiplication in 64-bits then select the high / low parts.
- Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
+ Value value64 = value;
+ if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
+ value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
Value multiplier64 =
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
Value multiply64 =
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 52dca4a48b204d..c4f82d53af9822 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -118,3 +118,40 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
return %res : i32
}
+
+// -----
+
+// CHECK-LABEL: @apply_scale_test_i64
+// SCALE: tosa.apply_scale
+func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
+ // CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32
+
+ // Multiply in 64 bits.
+ // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
+ // CHECK-DAG: %[[MUL:.+]] = arith.muli %arg0, %[[M64]]
+
+ // Round normally.
+ // CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
+ // CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
+ // CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
+ // CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
+ // CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]
+
+ // Apply double rounding.
+ // CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
+ // CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
+ // CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
+ // CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
+ // CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
+ // CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
+ // CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64
+
+ // Shift and truncate final answer.
+ // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
+ // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
+ // CHECK: return %[[TRUNC]]
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+ return %res : i32
+}
More information about the Mlir-commits
mailing list