fix(mypy): type annotations for linear algebra algorithms (#4317)

* fix(mypy): type annotations for linear algebra algorithms

* refactor: remove linear algebra directory from mypy exclude
This commit is contained in:
Dhruv Manilawala
2021-04-05 19:07:38 +05:30
committed by GitHub
parent 20c7518028
commit 8c2986026b
8 changed files with 100 additions and 74 deletions

View File

@@ -3,10 +3,12 @@ Resources:
- https://en.wikipedia.org/wiki/Conjugate_gradient_method
- https://en.wikipedia.org/wiki/Definite_symmetric_matrix
"""
from typing import Any
import numpy as np
def _is_matrix_spd(matrix: np.array) -> bool:
def _is_matrix_spd(matrix: np.ndarray) -> bool:
"""
Returns True if input matrix is symmetric positive definite.
Returns False otherwise.
@@ -38,10 +40,11 @@ def _is_matrix_spd(matrix: np.array) -> bool:
eigen_values, _ = np.linalg.eigh(matrix)
# Check sign of all eigenvalues.
return np.all(eigen_values > 0)
# np.all returns a value of type np.bool_
return bool(np.all(eigen_values > 0))
def _create_spd_matrix(dimension: np.int64) -> np.array:
def _create_spd_matrix(dimension: int) -> Any:
"""
Returns a symmetric positive definite matrix given a dimension.
@@ -64,11 +67,11 @@ def _create_spd_matrix(dimension: np.int64) -> np.array:
def conjugate_gradient(
spd_matrix: np.array,
load_vector: np.array,
spd_matrix: np.ndarray,
load_vector: np.ndarray,
max_iterations: int = 1000,
tol: float = 1e-8,
) -> np.array:
) -> Any:
"""
Returns solution to the linear system np.dot(spd_matrix, x) = b.
@@ -141,6 +144,8 @@ def conjugate_gradient(
# Update number of iterations.
iterations += 1
if iterations > max_iterations:
break
return x