[Mlir-commits] [mlir] 8450bbf - [mlir][arith] Add expansion pattern for ext/trunc of bf16

Robert Suderman llvmlistbot at llvm.org
Thu Apr 6 11:33:06 PDT 2023


Author: Robert Suderman
Date: 2023-04-06T18:24:02Z
New Revision: 8450bbf7f98d0de4b8df7b21835b3f15d25dfa2f

URL: https://github.com/llvm/llvm-project/commit/8450bbf7f98d0de4b8df7b21835b3f15d25dfa2f
DIFF: https://github.com/llvm/llvm-project/commit/8450bbf7f98d0de4b8df7b21835b3f15d25dfa2f.diff

LOG: [mlir][arith] Add expansion pattern for ext/trunc of bf16

bf16 has a trivial truncation/extension relationship with f32 that
can be described with elementary arith operations. Add expansions
that perform these conversions efficiently; f32 to bf16 truncation
rounds to nearest, breaking ties towards positive infinity.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D147585

Added: 
    mlir/test/mlir-cpu-runner/expand-arith-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arith/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 257a62aa39f78..6d60f8aefd63c 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -38,6 +38,9 @@ void populateArithWideIntEmulationPatterns(
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
+/// Add patterns that expand bf16 <-> f32 conversions (`arith.extf` from bf16
+/// to f32 and `arith.truncf` from f32 to bf16) into lower level integer
+/// bitcasts/shifts/masks.
+void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c5b80346bd52f..29b5e6f1dee86 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -32,6 +32,10 @@ def ArithExpandOps : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let constructor = "mlir::arith::createArithExpandOpsPass()";
   let dependentDialects = ["vector::VectorDialect"];
+  let options = [
+    // When enabled, `arith.extf` (bf16 -> f32) and `arith.truncf`
+    // (f32 -> bf16) are also expanded into integer bit manipulation.
+    Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+           "Enable the BF16 expansion patterns">,
+  ];
 }
 
 def ArithUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {

diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 8f34531937c5c..78fc8a51080c4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -25,15 +26,13 @@ using namespace mlir;
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
-
-  auto elTy = getElementTypeOrSelf(type);
-  auto constantAttr = rewriter.getIntegerAttr(elTy, value);
-
-  if (auto vecTy = llvm::dyn_cast<ShapedType>(type))
+  // The attribute is built for the element type, so `type` may be a scalar
+  // integer/index type or a shaped type with such an element type.
+  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
+  // Shaped types get a splat DenseElementsAttr holding `value` everywhere.
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     return rewriter.create<arith::ConstantOp>(
-        loc, vecTy, DenseElementsAttr::get(vecTy, constantAttr));
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
 
-  return rewriter.create<arith::ConstantOp>(loc, constantAttr);
+  return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
 namespace {
@@ -187,6 +186,122 @@ struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Expands `arith.extf` from bf16 to f32 (scalar or shaped) into integer ops.
+/// A bf16 value is exactly the upper half of an f32 bit pattern, so widening
+/// is a zero-extension followed by a 16-bit left shift.
+struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    Value input = op.getOperand();
+    Type srcTy = input.getType();
+    Type dstTy = op.getType();
+
+    // Only the element-wise bf16 -> f32 form is handled here.
+    bool isBf16ToF32 = getElementTypeOrSelf(srcTy).isBF16() &&
+                       getElementTypeOrSelf(dstTy).isF32();
+    if (!isBf16ToF32)
+      return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
+
+    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+    // 16- and 32-bit integer types shaped like the operand, if it is shaped.
+    Type narrowIntTy = builder.getI16Type();
+    Type wideIntTy = builder.getI32Type();
+    if (auto shaped = dyn_cast<ShapedType>(srcTy)) {
+      narrowIntTy = shaped.clone(narrowIntTy);
+      wideIntTy = shaped.clone(wideIntTy);
+    }
+
+    Value bits16 = builder.create<arith::BitcastOp>(narrowIntTy, input);
+    Value bits32 = builder.create<arith::ExtUIOp>(wideIntTy, bits16);
+    Value shiftBy16 = createConst(op.getLoc(), wideIntTy, 16, rewriter);
+    Value f32Bits = builder.create<arith::ShLIOp>(bits32, shiftBy16);
+    Value widened = builder.create<arith::BitcastOp>(dstTy, f32Bits);
+
+    rewriter.replaceOp(op, widened);
+    return success();
+  }
+};
+
+/// Expands `arith.truncf` from f32 to bf16 (scalar or shaped) into integer
+/// arithmetic on the f32 bit pattern.
+///
+/// Rounding: the 23-bit f32 mantissa is rounded to the 7 bits kept by bf16
+/// using round-to-nearest with ties towards positive infinity (the rounding
+/// bias is 0x8000 for positive inputs and 0x7fff for negative ones). A carry
+/// out of the mantissa increments the exponent and resets the mantissa.
+/// Inputs in the largest finite binade (exponent field 0xfe) whose rounding
+/// would carry keep their original mantissa, so they saturate to the largest
+/// finite bf16 magnitude instead of overflowing to infinity. NaN inputs skip
+/// rounding and get the quiet bit set, so they always stay NaN.
+struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF32() || !resultETy.isBF16()) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
+    }
+
+    Type i1Ty = b.getI1Type();
+    Type i16Ty = b.getI16Type();
+    Type i32Ty = b.getI32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i1Ty = shapedTy.clone(i1Ty);
+      i16Ty = shapedTy.clone(i16Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+
+    Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
+    Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
+    // Mask of the 23 f32 mantissa bits (0x007fffff).
+    Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
+    // Mask of the 8 f32 exponent bits (0x7f800000).
+    Value expMask =
+        createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
+    // Exponent field 0xfe (0x7f000000): the largest binade of finite values.
+    Value expMax =
+        createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
+    // Most significant mantissa bit (the quiet-NaN bit); it is one of the
+    // bits kept by the final 16-bit truncation.
+    Value quietBit = createConst(op.getLoc(), i32Ty, 1 << 22, rewriter);
+
+    // Grab the sign bit.
+    Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
+
+    // Rounding bias: 0x8000 for positive inputs and 0x7fff for negative
+    // ones, i.e. round to nearest with ties broken towards +infinity.
+    Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
+    cManRound = b.create<arith::SubIOp>(cManRound, sign);
+
+    // Grab out the mantissa and directly apply rounding.
+    Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
+    Value manRound = b.create<arith::AddIOp>(man, cManRound);
+
+    // Bit 23 of the rounded mantissa is the carry out of the mantissa; the
+    // bias is at most 0x8000, so it is always 0 or 1.
+    Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
+    // On a carry the mantissa must wrap to its low bits (the exponent is
+    // incremented below). Masking, not shifting right by the carry, keeps the
+    // kept mantissa bits zero: 0x3fffffff must round to 2.0, not 3.0.
+    Value manNew = b.create<arith::AndIOp>(manRound, c23Mask);
+
+    // Grab the exponent and round using the mantissa's carry bit.
+    Value exp = b.create<arith::AndIOp>(bitcast, expMask);
+    Value expCarry = b.create<arith::AddIOp>(exp, manRound);
+    expCarry = b.create<arith::AndIOp>(expCarry, expMask);
+
+    // In the largest finite binade (or for inf/NaN) never bump the exponent.
+    Value expCmp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
+    exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
+
+    // If the exponent is max and we rolled over, keep the old mantissa so the
+    // value truncates (saturates) instead of overflowing.
+    Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
+    Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
+    man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
+
+    // Assemble the now rounded f32 value (as an i32).
+    Value rounded = b.create<arith::ShLIOp>(sign, c31);
+    rounded = b.create<arith::OrIOp>(rounded, exp);
+    rounded = b.create<arith::OrIOp>(rounded, man);
+
+    // NaNs skip rounding: a payload living only in the low 16 bits would
+    // otherwise be truncated away and turn the NaN into an infinity. Setting
+    // the quiet bit keeps it a NaN and preserves the sign and upper payload.
+    Value isNan =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, operand, operand);
+    Value quietNan = b.create<arith::OrIOp>(bitcast, quietBit);
+    rounded = b.create<arith::SelectOp>(isNan, quietNan, rounded);
+
+    // bf16 is the upper 16 bits of the rounded f32 pattern.
+    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
+    Value shr = b.create<arith::ShRUIOp>(rounded, c16);
+    Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
+    Value result = b.create<arith::BitcastOp>(resultTy, trunc);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
   void runOnOperation() override {
@@ -204,6 +319,24 @@ struct ArithExpandOpsPass
       arith::MaxFOp,
       arith::MinFOp
     >();
+
+    // Optional bf16 expansion (the `include-bf16` pass option). Only the
+    // exact element-wise bf16 -> f32 extf and f32 -> bf16 truncf forms, the
+    // same ones the bf16 converters match, are marked illegal; every other
+    // extf/truncf stays legal and is left untouched.
+    if (includeBf16) {
+      arith::populateExpandBFloat16Patterns(patterns);
+      target.addDynamicallyLegalOp<arith::ExtFOp>(
+        [](arith::ExtFOp op) {
+          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+          Type outETy = getElementTypeOrSelf(op.getType());
+          return !(inETy.isBF16() && outETy.isF32());
+        });
+
+      target.addDynamicallyLegalOp<arith::TruncFOp>(
+        [](arith::TruncFOp op)  {
+          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+          Type outETy = getElementTypeOrSelf(op.getType());
+          return !(inETy.isF32() && outETy.isBF16());
+        });
+    }
+
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -220,6 +353,11 @@ void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
           patterns.getContext());
 }
 
+/// Registers the bf16 <-> f32 expansion patterns: extf (bf16 -> f32) first,
+/// then truncf (f32 -> bf16).
+void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<BFloat16ExtFOpConverter>(context);
+  patterns.add<BFloat16TruncFOpConverter>(context);
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off

diff  --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 7b7eb4003956a..6f28553116794 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,217 +1,48 @@
-// RUN: mlir-opt %s -arith-expand -split-input-file | FileCheck %s
-
-// Test ceil divide with signed integer
-// CHECK-LABEL:       func @ceildivi
-// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
-func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
-  %res = arith.ceildivsi %arg0, %arg1 : i32
-  return %res : i32
-
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK:           [[MINONE:%.+]] = arith.constant -1 : i32
-// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32
-// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
-}
-
-// -----
-
-// Test ceil divide with index type
-// CHECK-LABEL:       func @ceildivi_index
-// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
-func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
-  %res = arith.ceildivsi %arg0, %arg1 : index
-  return %res : index
-
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK:           [[MINONE:%.+]] = arith.constant -1 : index
-// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index
-// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
-}
-
-// -----
-
-// Test floor divide with signed integer
-// CHECK-LABEL:       func @floordivi
-// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
-func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
-  %res = arith.floordivsi %arg0, %arg1 : i32
-  return %res : i32
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK:           [[MIN1:%.+]] = arith.constant -1 : i32
-// CHECK:           [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : i32
-// CHECK:           [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : i32
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK:           [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : i32
-// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
-}
-
-// -----
-
-// Test floor divide with index type
-// CHECK-LABEL:       func @floordivi_index
-// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
-func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
-  %res = arith.floordivsi %arg0, %arg1 : index
-  return %res : index
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK:           [[MIN1:%.+]] = arith.constant -1 : index
-// CHECK:           [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : index
-// CHECK:           [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK:           [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
-// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
-}
-
-// -----
-
-// Test floor divide with vector
-// CHECK-LABEL:   func.func @floordivi_vec(
-// CHECK-SAME:                             %[[VAL_0:.*]]: vector<4xi32>,
-// CHECK-SAME:                             %[[VAL_1:.*]]: vector<4xi32>) -> vector<4xi32> {
-func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
-  %res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
-  return %res : vector<4xi32>
-// CHECK:           %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32>
-// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32>
-// CHECK:           %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32>
-// CHECK:           %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32>
-// CHECK:           %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32>
-// CHECK:           %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32>
-// CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32>
-// CHECK:           %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
-// CHECK:           %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1>
-// CHECK:           %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1>
-// CHECK:           %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1>
-// CHECK:           %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32>
-}
-
-// -----
-
-// Test ceil divide with unsigned integer
-// CHECK-LABEL:       func @ceildivui
-// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
-func.func @ceildivui(%arg0: i32, %arg1: i32) -> (i32) {
-  %res = arith.ceildivui %arg0, %arg1 : i32
-  return %res : i32
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK:           [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : i32
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
-// CHECK:           [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : i32
-// CHECK:           [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : i32
-// CHECK:           [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
-// CHECK:           [[RES:%.+]] = arith.select [[ISZERO]], [[ZERO]], [[REM]] : i32
-}
-
-// -----
-
-// Test unsigned ceil divide with index
-// CHECK-LABEL:       func @ceildivui_index
-// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
-func.func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
-  %res = arith.ceildivui %arg0, %arg1 : index
-  return %res : index
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK:           [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : index
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
-// CHECK:           [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : index
-// CHECK:           [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : index
-// CHECK:           [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index
-// CHECK:           [[RES:%.+]] = arith.select [[ISZERO]], [[ZERO]], [[REM]] : index
-}
-
-// -----
-
-// CHECK-LABEL: func @maxf
-func.func @maxf(%a: f32, %b: f32) -> f32 {
-  %result = arith.maxf %a, %b : f32
-  return %result : f32
-}
-// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: func @maxf_vector
-func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
-  %result = arith.maxf %a, %b : vector<4xf16>
-  return %result : vector<4xf16>
-}
-// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16>
-// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]]
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16>
-// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]]
-// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
-
-// -----
-
-// CHECK-LABEL: func @minf
-func.func @minf(%a: f32, %b: f32) -> f32 {
-  %result = arith.minf %a, %b : f32
-  return %result : f32
-}
-// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
-// CHECK-NEXT: return %[[RESULT]] : f32
+// RUN: mlir-opt %s -arith-expand="include-bf16=true" --canonicalize -split-input-file | FileCheck %s
+
+// Scalar f32 -> bf16: the truncf is expanded into integer operations on the
+// f32 bit pattern (sign extraction, mantissa rounding, exponent carry and
+// saturation, then a 16-bit right shift and truncation).
+func.func @truncf_f32(%arg0 : f32) -> bf16 {
+    %0 = arith.truncf %arg0 : f32 to bf16
+    return %0 : bf16
+}
+
+// CHECK-LABEL: @truncf_f32
+
+// Constants: 32768 = 0x8000 (rounding bias), 2130706432 = 0x7f000000
+// (exponent field 0xfe), 2139095040 = 0x7f800000 (exponent mask) and
+// 8388607 = 0x7fffff (mantissa mask).
+// CHECK: %[[C16:.+]] = arith.constant 16
+// CHECK: %[[C32768:.+]] = arith.constant 32768
+// CHECK: %[[C2130706432:.+]] = arith.constant 2130706432
+// CHECK: %[[C2139095040:.+]] = arith.constant 2139095040
+// CHECK: %[[C8388607:.+]] = arith.constant 8388607
+// CHECK: %[[C31:.+]] = arith.constant 31
+// CHECK: %[[C23:.+]] = arith.constant 23
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0
+// CHECK: %[[SIGN:.+]] = arith.shrui %[[BITCAST:.+]], %[[C31]]
+// CHECK: %[[ROUND:.+]] = arith.subi %[[C32768]], %[[SIGN]]
+// CHECK: %[[MANTISSA:.+]] = arith.andi %[[BITCAST]], %[[C8388607]]
+// CHECK: %[[ROUNDED:.+]] = arith.addi %[[MANTISSA]], %[[ROUND]]
+// CHECK: %[[ROLL:.+]] = arith.shrui %[[ROUNDED]], %[[C23]]
+// CHECK: %[[SHR:.+]] = arith.shrui %[[ROUNDED]], %[[ROLL]]
+// CHECK: %[[EXP:.+]] = arith.andi %0, %[[C2139095040]]
+// CHECK: %[[EXPROUND:.+]] = arith.addi %[[EXP]], %[[ROUNDED]]
+// CHECK: %[[EXPROLL:.+]] = arith.andi %[[EXPROUND]], %[[C2139095040]]
+// CHECK: %[[EXPMAX:.+]] = arith.cmpi uge, %[[EXP]], %[[C2130706432]]
+// CHECK: %[[EXPNEW:.+]] = arith.select %[[EXPMAX]], %[[EXP]], %[[EXPROLL]]
+// CHECK: %[[OVERFLOW_B:.+]] = arith.trunci %[[ROLL]]
+// CHECK: %[[KEEP_MAN:.+]] = arith.andi %[[EXPMAX]], %[[OVERFLOW_B]]
+// CHECK: %[[MANNEW:.+]] = arith.select %[[KEEP_MAN]], %[[MANTISSA]], %[[SHR]]
+// CHECK: %[[NEWSIGN:.+]] = arith.shli %[[SIGN]], %[[C31]]
+// CHECK: %[[WITHEXP:.+]] = arith.ori %[[NEWSIGN]], %[[EXPNEW]]
+// CHECK: %[[WITHMAN:.+]] = arith.ori %[[WITHEXP]], %[[MANNEW]]
+// CHECK: %[[SHIFT:.+]] = arith.shrui %[[WITHMAN]], %[[C16]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFT]]
+// CHECK: %[[RES:.+]] = arith.bitcast %[[TRUNC]]
+// CHECK: return %[[RES]]
+
+// -----
+
+// Vector f32 -> bf16: shaped operands are expanded too; this case only
+// verifies that no truncf survives.
+func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
+    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @truncf_vector_f32
+// CHECK-NOT: arith.truncf

diff  --git a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
new file mode 100644
index 0000000000000..44141cc4eeaf4
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
@@ -0,0 +1,62 @@
+// RUN:   mlir-opt %s -pass-pipeline="builtin.module(func.func(arith-expand{include-bf16=true},convert-arith-to-llvm),convert-vector-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%mlir_c_runner_utils  \
+// RUN:     -shared-libs=%mlir_runner_utils    \
+// RUN: | FileCheck %s
+
+// Round-trips %a through bf16 (truncf then extf) and prints the f32 result.
+func.func @trunc_bf16(%a : f32) {
+  %b = arith.truncf %a : f32 to bf16
+  %c = arith.extf %b : bf16 to f32
+  vector.print %c : f32
+  return
+}
+
+func.func @main() {
+  // 0x3f808000 = 1 + 2^-8 lies exactly halfway between the bf16 values 1.0
+  // and 1.0078125; positive ties round up (towards +inf).
+  // CHECK: 1.00781
+  %roundOneI = arith.constant 0x3f808000 : i32
+  %roundOneF = arith.bitcast %roundOneI : i32 to f32
+  call @trunc_bf16(%roundOneF): (f32) -> ()
+
+  // 0xbf808000 is the negative tie; negative ties round towards zero
+  // (again towards +inf).
+  // CHECK-NEXT: -1
+  %noRoundNegOneI = arith.constant 0xbf808000 : i32
+  %noRoundNegOneF = arith.bitcast %noRoundNegOneI : i32 to f32
+  call @trunc_bf16(%noRoundNegOneF): (f32) -> ()
+
+  // 0xbf808001 is just past the tie, so it rounds away from zero.
+  // CHECK-NEXT: -1.00781
+  %roundNegOneI = arith.constant 0xbf808001 : i32
+  %roundNegOneF = arith.bitcast %roundNegOneI : i32 to f32
+  call @trunc_bf16(%roundNegOneF): (f32) -> ()
+
+  // Infinities are preserved.
+  // CHECK-NEXT: inf
+  %infi = arith.constant 0x7f800000 : i32
+  %inff = arith.bitcast %infi : i32 to f32
+  call @trunc_bf16(%inff): (f32) -> ()
+
+  // CHECK-NEXT: -inf
+  %neginfi = arith.constant 0xff800000 : i32
+  %neginff = arith.bitcast %neginfi : i32 to f32
+  call @trunc_bf16(%neginff): (f32) -> ()
+
+  // The largest finite f32 saturates to the largest finite bf16 (0x7f7f)
+  // instead of overflowing to infinity.
+  // CHECK-NEXT: 3.38953e+38
+  %bigi = arith.constant 0x7f7fffff : i32
+  %bigf = arith.bitcast %bigi : i32 to f32
+  call @trunc_bf16(%bigf): (f32) -> ()
+
+  // CHECK-NEXT: -3.38953e+38
+  %negbigi = arith.constant 0xff7fffff : i32
+  %negbigf = arith.bitcast %negbigi : i32 to f32
+  call @trunc_bf16(%negbigf): (f32) -> ()
+
+  // Rounding carries out of the discarded low 16 bits into the kept
+  // mantissa bits (0x3fcf -> 0x3fd0), giving 1.625.
+  // CHECK-NEXT: 1.625
+  %exprolli = arith.constant 0x3fcfffff : i32
+  %exprollf = arith.bitcast %exprolli : i32 to f32
+  call @trunc_bf16(%exprollf): (f32) -> ()
+
+  // CHECK-NEXT: -1.625
+  %exprollnegi = arith.constant 0xbfcfffff : i32
+  %exprollnegf = arith.bitcast %exprollnegi : i32 to f32
+  call @trunc_bf16(%exprollnegf): (f32) -> ()
+
+  // NOTE(review): no case above exercises a mantissa carry into the exponent
+  // outside the top binade (e.g. 0x3fffffff, which should round to 2.0) or a
+  // NaN input; consider adding both.
+  return
+}


        


More information about the Mlir-commits mailing list