[llvm] 70caa31 - [Matrix] Refactor shape info computation (NFCI).

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 08:03:20 PST 2024


Author: Florian Hahn
Date: 2024-02-12T16:02:59Z
New Revision: 70caa316e955a35904e34961c79d75935b7d906f

URL: https://github.com/llvm/llvm-project/commit/70caa316e955a35904e34961c79d75935b7d906f
DIFF: https://github.com/llvm/llvm-project/commit/70caa316e955a35904e34961c79d75935b7d906f.diff

LOG: [Matrix] Refactor shape info computation (NFCI).

Factor out forward shape computation for a given instruction. This
allows re-use in a follow-up fix.

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index b528762b545659..03e289f7a087ac 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -192,6 +192,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
   return VecStart;
 }
 
+namespace {
+struct ShapeInfo { // Row/column counts of a matrix plus its layout flag.
+  unsigned NumRows;
+  unsigned NumColumns;
+
+  bool IsColumnMajor; // Derived from MatrixLayout at construction time.
+
+  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
+      : NumRows(NumRows), NumColumns(NumColumns),
+        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+
+  ShapeInfo(Value *NumRows, Value *NumColumns) // Both must be ConstantInt.
+      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
+                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
+
+  bool operator==(const ShapeInfo &other) { // Layout is not compared.
+    return NumRows == other.NumRows && NumColumns == other.NumColumns;
+  }
+  bool operator!=(const ShapeInfo &other) { return !(*this == other); }
+
+  /// Returns true if shape information is defined (NumRows != 0); the assert
+  /// guarantees NumColumns != 0 in that case.
+  operator bool() const {
+    assert(NumRows == 0 || NumColumns != 0);
+    return NumRows != 0;
+  }
+
+  unsigned getStride() const { // NumRows if column-major, else NumColumns.
+    if (IsColumnMajor)
+      return NumRows;
+    return NumColumns;
+  }
+
+  unsigned getNumVectors() const { // Columns if column-major, else rows.
+    if (IsColumnMajor)
+      return NumColumns;
+    return NumRows;
+  }
+
+  /// Returns the transposed shape (row and column counts swapped).
+  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
+};
+} // namespace
+
+static bool isUniformShape(Value *V) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return true; // Non-instruction values are treated as uniform.
+
+  switch (I->getOpcode()) {
+  case Instruction::FAdd:
+  case Instruction::FSub:
+  case Instruction::FMul: // Scalar multiply.
+  case Instruction::FNeg:
+  case Instruction::Add:
+  case Instruction::Mul:
+  case Instruction::Sub:
+    return true; // Opcodes whose result takes the shape of an operand.
+  default:
+    return false;
+  }
+}
+
+/// Return the ShapeInfo for the result of \p I, if it can be determined.
+static std::optional<ShapeInfo>
+computeShapeInfoForInst(Instruction *I,
+                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+  Value *M;
+  Value *N;
+  Value *K;
+  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
+                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
+    return ShapeInfo(M, K); // Rows from the 3rd argument, columns from the 5th.
+  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
+                                                        m_Value(N)))) {
+    // Flip dimensions: the result swaps the two dimension arguments.
+    return ShapeInfo(N, M);
+  }
+  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
+                   m_Value(N))))
+    return ShapeInfo(N, M); // NOTE(review): swapped vs. the load below; confirm.
+  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
+    return ShapeInfo(M, N);
+  Value *MatrixA;
+  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
+    auto OpShape = ShapeMap.find(MatrixA); // Shape of the stored value, if known.
+    if (OpShape != ShapeMap.end())
+      return OpShape->second;
+  }
+
+  if (isUniformShape(I)) {
+    // Find the first operand that has a known shape and use that.
+    for (auto &Op : I->operands()) {
+      auto OpShape = ShapeMap.find(Op.get());
+      if (OpShape != ShapeMap.end())
+        return OpShape->second;
+    }
+  }
+  return std::nullopt; // No shape could be determined for I.
+}
+
 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
 ///
 /// Currently, the lowering for each matrix intrinsic is done as follows:
@@ -383,48 +486,6 @@ class LowerMatrixIntrinsics {
     }
   };
 
-  struct ShapeInfo {
-    unsigned NumRows;
-    unsigned NumColumns;
-
-    bool IsColumnMajor;
-
-    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
-        : NumRows(NumRows), NumColumns(NumColumns),
-          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
-
-    ShapeInfo(Value *NumRows, Value *NumColumns)
-        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
-                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}
-
-    bool operator==(const ShapeInfo &other) {
-      return NumRows == other.NumRows && NumColumns == other.NumColumns;
-    }
-    bool operator!=(const ShapeInfo &other) { return !(*this == other); }
-
-    /// Returns true if shape-information is defined, meaning both dimensions
-    /// are != 0.
-    operator bool() const {
-      assert(NumRows == 0 || NumColumns != 0);
-      return NumRows != 0;
-    }
-
-    unsigned getStride() const {
-      if (IsColumnMajor)
-        return NumRows;
-      return NumColumns;
-    }
-
-    unsigned getNumVectors() const {
-      if (IsColumnMajor)
-        return NumColumns;
-      return NumRows;
-    }
-
-    /// Returns the transposed shape.
-    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
-  };
-
   /// Maps instructions to their shape information. The shape information
   /// describes the shape to be used while lowering. This matches the shape of
   /// the result value of the instruction, with the only exceptions being store
@@ -554,25 +615,6 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
-  bool isUniformShape(Value *V) {
-    Instruction *I = dyn_cast<Instruction>(V);
-    if (!I)
-      return true;
-
-    switch (I->getOpcode()) {
-    case Instruction::FAdd:
-    case Instruction::FSub:
-    case Instruction::FMul: // Scalar multiply.
-    case Instruction::FNeg:
-    case Instruction::Add:
-    case Instruction::Mul:
-    case Instruction::Sub:
-      return true;
-    default:
-      return false;
-    }
-  }
-
   /// Returns true if shape information can be used for \p V. The supported
   /// instructions must match the instructions that can be lowered by this pass.
   bool supportsShapeInfo(Value *V) {
@@ -610,43 +652,8 @@ class LowerMatrixIntrinsics {
 
       // New entry, set the value and insert operands
       bool Propagate = false;
-
-      Value *MatrixA;
-      Value *MatrixB;
-      Value *M;
-      Value *N;
-      Value *K;
-      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
-                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
-                          m_Value(N), m_Value(K)))) {
-        Propagate = setShapeInfo(Inst, {M, K});
-      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
-                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
-        // Flip dimensions.
-        Propagate = setShapeInfo(Inst, {N, M});
-      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
-                                 m_Value(MatrixA), m_Value(), m_Value(),
-                                 m_Value(), m_Value(M), m_Value(N)))) {
-        Propagate = setShapeInfo(Inst, {N, M});
-      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                                 m_Value(), m_Value(), m_Value(), m_Value(M),
-                                 m_Value(N)))) {
-        Propagate = setShapeInfo(Inst, {M, N});
-      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
-        auto OpShape = ShapeMap.find(MatrixA);
-        if (OpShape != ShapeMap.end())
-          setShapeInfo(Inst, OpShape->second);
-        continue;
-      } else if (isUniformShape(Inst)) {
-        // Find the first operand that has a known shape and use that.
-        for (auto &Op : Inst->operands()) {
-          auto OpShape = ShapeMap.find(Op.get());
-          if (OpShape != ShapeMap.end()) {
-            Propagate |= setShapeInfo(Inst, OpShape->second);
-            break;
-          }
-        }
-      }
+      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
+        Propagate = setShapeInfo(Inst, *SI);
 
       if (Propagate) {
         NewWorkList.push_back(Inst);


        


More information about the llvm-commits mailing list