[Mlir-commits] [mlir] cf29d0a - [mlir][tosa]Create a check for i64 input in apply_scale lowering in TosaToArith

Rob Suderman llvmlistbot at llvm.org
Thu Sep 7 11:10:36 PDT 2023


Author: Natasha Kononenko
Date: 2023-09-07T11:08:52-07:00
New Revision: cf29d0a73764e81bf2cb41474f49926f5ef4ddd1

URL: https://github.com/llvm/llvm-project/commit/cf29d0a73764e81bf2cb41474f49926f5ef4ddd1
DIFF: https://github.com/llvm/llvm-project/commit/cf29d0a73764e81bf2cb41474f49926f5ef4ddd1.diff

LOG: [mlir][tosa] Create a check for i64 input in apply_scale lowering in TosaToArith

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D159473

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
    mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index c025fb9e1367d9..50e57682a2dc8d 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -81,7 +81,9 @@ class ApplyScaleGenericOpConverter
     Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
 
     // Compute the multiplication in 64-bits then select the high / low parts.
-    Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
+    Value value64 = value;
+    if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
+      value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
     Value multiplier64 =
         rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
     Value multiply64 =

diff  --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 52dca4a48b204d..c4f82d53af9822 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -118,3 +118,40 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
   %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
   return %res : i32
 }
+
+// -----
+
+// CHECK-LABEL: @apply_scale_test_i64
+// SCALE: tosa.apply_scale
+func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
+  // CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32
+
+  // Multiply in 64 bits.
+  // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
+  // CHECK-DAG: %[[MUL:.+]] = arith.muli %arg0, %[[M64]]
+
+  // Round normally.
+  // CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
+  // CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
+  // CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
+  // CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
+  // CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]
+
+  // Apply double rounding.
+  // CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
+  // CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
+  // CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
+  // CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
+  // CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
+  // CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
+  // CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64
+
+  // Shift and truncate final answer.
+  // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
+  // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
+  // CHECK: return %[[TRUNC]]
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+  return %res : i32
+}


        


More information about the Mlir-commits mailing list