-
Notifications
You must be signed in to change notification settings - Fork 12
/
algorithm.py
49 lines (33 loc) · 1.15 KB
/
algorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from abc import ABC, ABCMeta, abstractmethod
# Public API of this module. Every name listed here must exist at module
# level, otherwise ``from <module> import *`` raises AttributeError.
__all__ = [
    "Algorithm",
    "CompressionAlgorithm",
    "VisionAlgorithm",
    "TextGenerationAlgorithm",
    "AdaptionAlgorithm",
]
class Algorithm(ABC):
    """Abstract root of every Nyun algorithm.

    Concrete subclasses must override :meth:`run`; until they do, the class
    cannot be instantiated.
    """

    @abstractmethod
    def run(self):
        """Execute the algorithm; subclasses supply the implementation."""
        ...
class CompressionAlgorithm(Algorithm, metaclass=ABCMeta):
    """Abstract base shared by every Nyun compression algorithm.

    Subclasses implement :meth:`compress_model`; :meth:`run` just calls it
    and returns ``None``.
    """

    @abstractmethod
    def compress_model(self):
        """Carry out the compression; must be overridden by subclasses."""

    def run(self):
        """Execute the algorithm by invoking :meth:`compress_model`."""
        compress = self.compress_model
        compress()
class VisionAlgorithm(CompressionAlgorithm):
    """Common base for Nyun compression algorithms aimed at vision models.

    It adds nothing to :class:`CompressionAlgorithm`; subclasses still have
    to implement ``compress_model``.
    """
class TextGenerationAlgorithm(CompressionAlgorithm):
    """Common base for Nyun compression algorithms aimed at text-generation models.

    It adds nothing to :class:`CompressionAlgorithm`; subclasses still have
    to implement ``compress_model``.
    """
class AdaptionAlgorithm(Algorithm):
    """Base class for all of Nyun's adaptation algorithms.

    Unlike the compression bases, this class derives directly from
    :class:`Algorithm`, not from ``CompressionAlgorithm``. Subclasses
    implement :meth:`adapt_model`, which :meth:`run` invokes.
    """

    def run(self):
        """Run the algorithm by delegating to :meth:`adapt_model`."""
        self.adapt_model()

    @abstractmethod
    def adapt_model(self):
        """Adapt the model; must be overridden by subclasses."""