Skip to content

Commit

Permalink
Improve typehint (#92)
Browse files Browse the repository at this point in the history
* Add pycln

* Use typevar.

* Embedding traits should not inherit from Baserecommender.
  • Loading branch information
tohtsky authored Mar 11, 2022
1 parent a1893be commit 9f97cb3
Show file tree
Hide file tree
Showing 15 changed files with 37 additions and 21 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ repos:
rev: 20.8b1
hooks:
- id: black
- repo: https://github.com/hadialqattan/pycln
rev: v1.2.4 # Possible releases: https://github.com/hadialqattan/pycln/releases
hooks:
- id: pycln
args: [--config=pyproject.toml]
2 changes: 0 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
from typing import List

# Configuration file for the Sphinx documentation builder.
Expand Down
2 changes: 0 additions & 2 deletions examples/movielens/movielens_1m.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import logging
import os
from typing import List, Tuple, Type

from scipy import sparse as sps
Expand Down
1 change: 0 additions & 1 deletion irspack/dataset/downloader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import urllib.request
from abc import ABCMeta, abstractmethod
from io import BytesIO
Expand Down
2 changes: 1 addition & 1 deletion irspack/dataset/movielens/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABCMeta, abstractmethod
from abc import abstractmethod
from io import BytesIO

import pandas as pd
Expand Down
2 changes: 1 addition & 1 deletion irspack/definitions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Union
from typing import Union

import numpy as np
from scipy import sparse as sps
Expand Down
2 changes: 1 addition & 1 deletion irspack/optimizers/_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional

from irspack.definitions import InteractionMatrix

Expand Down
4 changes: 3 additions & 1 deletion irspack/recommenders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
)
from irspack.recommenders.base_earlystop import BaseRecommenderWithEarlyStopping
from irspack.recommenders.dense_slim import DenseSLIMConfig, DenseSLIMRecommender
from irspack.recommenders.edlae import EDLAERecommender
from irspack.recommenders.edlae import EDLAEConfig, EDLAERecommender
from irspack.recommenders.ials import IALSConfig, IALSRecommender
from irspack.recommenders.knn import (
AsymmetricCosineKNNConfig,
Expand Down Expand Up @@ -41,6 +41,8 @@
"RP3betaRecommender",
"DenseSLIMConfig",
"DenseSLIMRecommender",
"EDLAERecommender",
"EDLAEConfig",
"SLIMConfig",
"SLIMRecommender",
"IALSConfig",
Expand Down
25 changes: 19 additions & 6 deletions irspack/recommenders/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union, no_type_check
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Type,
TypeVar,
Union,
no_type_check,
)

import numpy as np
from optuna.trial import Trial
Expand All @@ -16,6 +25,8 @@
UserIndexArray,
)

R = TypeVar("R", bound="BaseRecommender")


def _sparse_to_array(U: Any) -> np.ndarray:
if sps.issparse(U):
Expand Down Expand Up @@ -78,15 +89,17 @@ def __init__(self, X_train_all: InteractionMatrix, **kwargs: Any) -> None:

@classmethod
def from_config(
cls, X_train_all: InteractionMatrix, config: RecommenderConfig
) -> "BaseRecommender":
cls: Type[R],
X_train_all: InteractionMatrix,
config: RecommenderConfig,
) -> R:
if not isinstance(config, cls.config_class):
raise ValueError(
f"Different config has been given. config must be {cls.config_class}"
)
return cls(X_train_all, **config.dict())

def learn(self) -> "BaseRecommender":
def learn(self: R) -> R:
"""Learns and returns itself.
Returns:
Expand Down Expand Up @@ -245,7 +258,7 @@ def get_score_block(self, begin: int, end: int) -> DenseScoreArray:
return _sparse_to_array(self.U[begin:end].dot(self._X_csc))


class BaseRecommenderWithUserEmbedding(BaseRecommender):
class BaseRecommenderWithUserEmbedding:
"""Defines a recommender with user embedding (e.g., matrix factorization.).
These class can be a base CF estimator for CB2CF (with user profile -> user embedding NN).
"""
Expand Down Expand Up @@ -276,7 +289,7 @@ def get_score_from_user_embedding(
raise NotImplementedError("get_score_from_item_embedding must be implemtented.")


class BaseRecommenderWithItemEmbedding(BaseRecommender):
class BaseRecommenderWithItemEmbedding:
"""Defines a recommender with item embedding (e.g., matrix factorization.).
These class can be a base CF estimator for CB2CF (with item profile -> item embedding NN).
"""
Expand Down
2 changes: 1 addition & 1 deletion irspack/recommenders/edlae.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gc

import numpy as np
from scipy import linalg, sparse
from scipy import linalg

from ..definitions import InteractionMatrix
from .base import BaseSimilarityRecommender, RecommenderConfig
Expand Down
3 changes: 2 additions & 1 deletion irspack/recommenders/truncsvd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings
from typing import Optional

from numpy import random
from sklearn.decomposition import TruncatedSVD

from ..definitions import (
Expand All @@ -11,6 +10,7 @@
UserIndexArray,
)
from .base import (
BaseRecommender,
BaseRecommenderWithItemEmbedding,
BaseRecommenderWithUserEmbedding,
RecommenderConfig,
Expand All @@ -23,6 +23,7 @@ class TruncatedSVDConfig(RecommenderConfig):


class TruncatedSVDRecommender(
BaseRecommender,
BaseRecommenderWithUserEmbedding,
BaseRecommenderWithItemEmbedding,
):
Expand Down
1 change: 0 additions & 1 deletion irspack/split/userwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pandas as pd
from numpy.lib.arraysetops import unique
from scipy import sparse as sps

from irspack.definitions import InteractionMatrix, OptionalRandomState
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ include_trailing_comma = true
line_length = 88
multi_line_output = 3
use_parentheses = true

[tool.pycln]
all = true
2 changes: 1 addition & 1 deletion tests/autopilot/mock_classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
import time
from typing import IO, Any, Dict, List, Optional
from typing import IO, Any, Dict, List

import numpy as np
import scipy.sparse as sps
Expand Down
2 changes: 0 additions & 2 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import numpy as np
import pytest
import scipy.sparse as sps
Expand Down

0 comments on commit 9f97cb3

Please sign in to comment.