[Mlir-commits] [mlir] 2ceedc3 - [mlir][linalg] Add symbolic type conversion to linalg named ops.

Stella Laurenzo llvmlistbot at llvm.org
Sat Feb 27 15:54:30 PST 2021


Author: Stella Laurenzo
Date: 2021-02-27T15:52:35-08:00
New Revision: 2ceedc3a201386c6cbbcea5cec3f5e01d04f6445

URL: https://github.com/llvm/llvm-project/commit/2ceedc3a201386c6cbbcea5cec3f5e01d04f6445
DIFF: https://github.com/llvm/llvm-project/commit/2ceedc3a201386c6cbbcea5cec3f5e01d04f6445.diff

LOG: [mlir][linalg] Add symbolic type conversion to linalg named ops.

This enables the following kind of construct in the DSL, which generates a named op that is polymorphic over numeric type variables (`T1`, `T2` and `U` below) and emits the correct arithmetic casts at construction time:

```
@tc_def_op
def polymorphic_matmul(A=TensorDef(T1, S.M, S.K),
                       B=TensorDef(T2, S.K, S.N),
                       C=TensorDef(U, S.M, S.N, output=True)):
  implements(ContractionOpInterface)
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```

Presently, this only supports type variables that are bound to the element type of one of the arguments, although a further extension that allows binding a type variable to an attribute would allow some more expressiveness and may be useful for some formulations. This is left to a future patch. In addition, this patch does not yet materialize the verifier support that ensures types are bound correctly (for such simple examples, failing to bind them correctly will yield IR that fails verification; it just won't yet fail with a precise error).

Note that the full grid of extension, truncation, and int<->float conversions is supported, but many of them are lossy, and higher-level code needs to be mindful of numerics (that is not the job of this level).

As-is, this should be sufficient for most integer matmul scenarios we work with in typical quantization schemes.

Differential Revision: https://reviews.llvm.org/D97603

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 43e41ab68fd2..93bc5760ed0c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -15,14 +15,17 @@ structured_op: !LinalgStructuredOpConfig
     name: A
     usage: input
     shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+    element_type_var: T1
   - !<LinalgTensorDef>
     name: B
     usage: input
     shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+    element_type_var: T2
   - !<LinalgTensorDef>
     name: C
     usage: output
     shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+    element_type_var: U
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -46,7 +49,15 @@ structured_op: !LinalgStructuredOpConfig
             fn_name: mul
             operands:
             - !ScalarExpression
-              scalar_arg: A
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
             - !ScalarExpression
-              scalar_arg: B
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 77d8e7026b18..acc8ff1807c1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -155,6 +155,45 @@ class RegionBuilderHelper {
 public:
   RegionBuilderHelper(Block &block) : block(block) {}
 
+  // Generates operations to cast the given operand to a specified type.
+  // If the cast cannot be performed, a warning will be issued and the
+  // operand returned as-is (which will presumably yield a verification
+  // issue downstream).
+  // Generates operations to cast the given operand to a specified type.
+  // Supported conversions (integers are always treated as signed):
+  //   - int -> int:     sexti or trunci, depending on the bit widths.
+  //   - float -> int:   fptosi.
+  //   - int -> float:   sitofp.
+  //   - float -> float: fpext or fptrunc, depending on the bit widths.
+  // If the cast cannot be performed (e.g. between distinct types of equal
+  // width such as bf16 and f16, or to/from a type that is neither an integer
+  // nor a float), a warning is issued and the operand is returned as-is
+  // (which will presumably yield a verification issue downstream).
+  Value cast(Type toType, Value operand) {
+    Type fromType = operand.getType();
+    // Nothing to materialize if the operand already has the requested type.
+    if (fromType == toType)
+      return operand;
+
+    OpBuilder builder = getBuilder(operand);
+    Location loc = operand.getLoc();
+
+    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+      // If operand is floating point, cast directly to the int type.
+      if (fromType.isa<FloatType>())
+        return builder.create<FPToSIOp>(loc, toType, operand);
+      if (auto fromIntType = fromType.dyn_cast<IntegerType>()) {
+        // Either sign extend or truncate.
+        if (toIntType.getWidth() > fromIntType.getWidth())
+          return builder.create<SignExtendIOp>(loc, toType, operand);
+        if (toIntType.getWidth() < fromIntType.getWidth())
+          return builder.create<TruncateIOp>(loc, toType, operand);
+      }
+    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+      // If operand is integer, cast directly to the float type.
+      if (fromType.isa<IntegerType>())
+        return builder.create<SIToFPOp>(loc, toFloatType, operand);
+      if (auto fromFloatType = fromType.dyn_cast<FloatType>()) {
+        // Note that it is unclear how to cast between distinct float types of
+        // equal width (e.g. BF16<->FP16); those fall through to the warning.
+        if (toFloatType.getWidth() > fromFloatType.getWidth())
+          return builder.create<FPExtOp>(loc, toFloatType, operand);
+        if (toFloatType.getWidth() < fromFloatType.getWidth())
+          return builder.create<FPTruncOp>(loc, toFloatType, operand);
+      }
+    }
+
+    emitWarning(loc) << "could not cast operand of type " << fromType
+                     << " to " << toType;
+    return operand;
+  }
+
   Value applyfn__add(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder(lhs);
     if (isFloatingPoint(lhs))

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 186fb9627219..be2b77591cd6 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -25,3 +25,107 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
 // CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
 // CHECK-NEXT:   linalg.yield %[[ADD]] : i32
 // CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
+// Verifies floating point to integer cast: both f32 operands are converted
+// to the i16 output element type with fptosi, so muli/addi are used.
+func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+                          outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
+  return %0: tensor<16x32xi16>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16)
+// CHECK-NEXT:   %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16
+// CHECK-NEXT:   %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
+// CHECK-NEXT: -> tensor<16x32xi16>
+
+// -----
+// Verifies sign extension cast: both i8 operands are sign extended to the
+// i32 output element type.
+func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
+// Verifies that different argument types are legal: each operand is sign
+// extended from its own element type to the i32 output element type.
+func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
+                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
+// Somewhat nonsensical, but checks the integer truncation cast: both i32
+// operands are truncated to the i16 output element type before the multiply.
+func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+                          outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
+  return %0: tensor<16x32xi16>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+// CHECK-NEXT:   %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
+// CHECK-NEXT:   %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
+// CHECK-NEXT: -> tensor<16x32xi16>
+
+// -----
+// Verifies integer to floating point cast: both i8 operands are converted to
+// the f32 output element type with sitofp, so mulf/addf are used.
+func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+// Verifies floating point extension cast: both f16 operands are extended to
+// the f32 output element type with fpext.
+func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
+                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+// Verifies floating point truncation cast: both f64 operands are truncated
+// to the f32 output element type with fptrunc.
+func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
+                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index f651b03e72fd..d59be1624851 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -72,6 +72,7 @@ struct LinalgTensorDef {
   std::string name;
   LinalgTensorUsageDef usage;
   SerializedAffineMap shape;
+  std::string elementTypeVar;
 };
 
 enum class LinalgIteratorTypeDef {
@@ -92,9 +93,17 @@ struct ScalarApply {
   std::vector<ScalarExpression> operands;
 };
 
+/// Scalar expression that casts its operand to the type bound to a symbolic
+/// type variable.
+struct ScalarSymbolicCast {
+  /// Name of the type variable to cast to (e.g. `U`). When generating the
+  /// region builder, it is resolved to the element type of the first TensorDef
+  /// argument whose `element_type_var` matches.
+  std::string typeVar;
+  // NOTE: This must be of arity 1, but to break the self-referential cycle,
+  // we use a heap allocated vector.
+  // (ScalarExpression is not yet complete here and itself holds an
+  // Optional<ScalarSymbolicCast>; the arity is checked during generation.)
+  std::vector<ScalarExpression> operands;
+};
+
 struct ScalarExpression {
-  Optional<std::string> scalarArg;
-  Optional<ScalarApply> scalarApply;
+  Optional<std::string> arg;
+  Optional<ScalarApply> apply;
+  Optional<ScalarSymbolicCast> symbolicCast;
 };
 
 struct ScalarAssign {
@@ -163,12 +172,15 @@ struct MappingTraits<LinalgStructuredOpConfig> {
 ///   - `shape`: An AffineMap from all op symbols to the specific shape
 ///     of this argument. Each shape must be normalized over the same list of
 ///     symbols and have no dimension inputs.
+///   - `element_type_var`: The symbolic type variable that binds to the scalar
+///     element type of this TensorDef.
 template <>
 struct MappingTraits<LinalgTensorDef> {
   static void mapping(IO &io, LinalgTensorDef &info) {
     io.mapRequired("name", info.name);
     io.mapRequired("usage", info.usage);
     io.mapRequired("shape", info.shape);
+    // Required: every TensorDef must now declare `element_type_var`; symbolic
+    // casts look up type variables among these (see findTypeVarArgIndex).
+    io.mapRequired("element_type_var", info.elementTypeVar);
   }
 };
 
@@ -230,11 +242,13 @@ struct MappingTraits<ScalarAssign> {
 ///   - `scalar_arg`: Name of an argument to the op.
 ///   - `scalar_apply`: Result of evaluating a named function (see
 ///      `ScalarApply`).
+///   - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere.
 template <>
 struct MappingTraits<ScalarExpression> {
   static void mapping(IO &io, ScalarExpression &info) {
-    io.mapOptional("scalar_arg", info.scalarArg);
-    io.mapOptional("scalar_apply", info.scalarApply);
+    io.mapOptional("scalar_arg", info.arg);
+    io.mapOptional("scalar_apply", info.apply);
+    io.mapOptional("symbolic_cast", info.symbolicCast);
   }
 };
 
@@ -251,6 +265,14 @@ struct MappingTraits<ScalarApply> {
   }
 };
 
+/// A `symbolic_cast` expression casts a single operand to a symbolic type:
+///   - `type_var`: Name of the type variable to cast to. It must match the
+///     `element_type_var` of one of the op's TensorDefs (checked when the
+///     region builder is generated).
+///   - `operands`: A list containing exactly one `ScalarExpression` to cast.
+template <>
+struct MappingTraits<ScalarSymbolicCast> {
+  static void mapping(IO &io, ScalarSymbolicCast &info) {
+    io.mapRequired("type_var", info.typeVar);
+    io.mapRequired("operands", info.operands);
+  }
+};
+
 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
 /// the same.
 template <>
@@ -348,6 +370,15 @@ findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
   return None;
 }
 
+/// Returns the position of the first TensorDef in `args` that binds the type
+/// variable `typeVar` as its element type, or None if no argument does.
+static Optional<int>
+findTypeVarArgIndex(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+  for (unsigned i = 0, e = args.size(); i != e; ++i)
+    if (args[i].elementTypeVar == typeVar)
+      return i;
+  return None;
+}
+
 static ScalarAssign *
 findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) {
   for (auto &assign : assignments) {
@@ -726,9 +757,9 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
       std::function<Optional<std::string>(ScalarExpression &)>
           generateExpression =
               [&](ScalarExpression &expression) -> Optional<std::string> {
-        if (expression.scalarArg) {
-          Optional<int> argIndex =
-              findTensorDefArgIndex(*expression.scalarArg, args);
+        if (expression.arg) {
+          // Argument reference.
+          Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args);
           if (!argIndex) {
             emitError(genContext.getLoc())
                 << "scalar argument not defined on the op: " << arg.name;
@@ -736,10 +767,11 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
           }
           return std::string(
               llvm::formatv("block.getArgument({0})", *argIndex));
-        } else if (expression.scalarApply) {
+        } else if (expression.apply) {
+          // Apply function.
           // Recursively generate operands.
           SmallVector<std::string> operandCppValues;
-          for (ScalarExpression &operand : expression.scalarApply->operands) {
+          for (ScalarExpression &operand : expression.apply->operands) {
             auto operandCppValue = generateExpression(operand);
             if (!operandCppValue)
               return None;
@@ -748,9 +780,41 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
           stmts.push_back(
               llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
-                            expression.scalarApply->fnName,
+                            expression.apply->fnName,
                             interleaveToString(operandCppValues, ", ")));
           return cppIdent;
+        } else if (expression.symbolicCast) {
+          // Symbolic cast.
+          // Operands must be arity 1.
+          if (expression.symbolicCast->operands.size() != 1) {
+            emitError(genContext.getLoc())
+                << "symbolic_cast operand arity must be 1";
+            return None;
+          }
+          Optional<std::string> operandCppValue =
+              generateExpression(expression.symbolicCast->operands[0]);
+          if (!operandCppValue)
+            return None;
+
+          // Try to map the TypeVar to an arg index (which map to block arg
+          // indices), since we can just get that type directly.
+          // TODO: Handle free type variables which do not map to an argument.
+          Optional<int> typeArgIndex =
+              findTypeVarArgIndex(expression.symbolicCast->typeVar, args);
+          if (!typeArgIndex) {
+            emitError(genContext.getLoc())
+                << "type variable " << expression.symbolicCast->typeVar
+                << ", used in a symbolic cast must map to an argument but it "
+                << "does not";
+            return None;
+          }
+          std::string typeCppValue =
+              llvm::formatv("block.getArgument({0}).getType()", *typeArgIndex);
+          std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+          stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});",
+                                        cppIdent, typeCppValue,
+                                        *operandCppValue));
+          return cppIdent;
         } else {
           emitError(genContext.getLoc()) << "unknown ScalarExpression type";
           return None;


        


More information about the Mlir-commits mailing list