Skip to content

Commit

Permalink
Minor refactor of data filter
Browse files Browse the repository at this point in the history
  • Loading branch information
bsdz committed Jan 15, 2024
1 parent fa69506 commit d798d46
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203, E704, E741
extend-ignore = E203, E704, E741, F401
5 changes: 4 additions & 1 deletion yabte/backtest/asset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass
from decimal import Decimal
from typing import TypeAlias, TypeVar, Union, cast
from typing import TYPE_CHECKING, TypeAlias, TypeVar, Union, cast

import pandas as pd
from mypy_extensions import mypyc_attr
Expand All @@ -9,6 +11,7 @@


# use ints until mypyc supports IntFlag
# https://github.com/mypyc/mypyc/issues/1022
AssetDataFieldInfo = int
ADFI_AVAILABLE_AT_CLOSE: int = 1
ADFI_AVAILABLE_AT_OPEN: int = 2
Expand Down
2 changes: 1 addition & 1 deletion yabte/backtest/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _check_data(df, asset_map):

# check and fix data for each asset
dfs = {
asset.data_label: asset.check_and_fix_data(df[asset.data_label])
asset.data_label: asset.check_and_fix_data(asset._filter_data(df))
for asset_name, asset in asset_map.items()
}

Expand Down
6 changes: 5 additions & 1 deletion yabte/backtest/transaction.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from decimal import Decimal
from typing import TYPE_CHECKING

import pandas as pd

# TODO: use explicit imports until mypyc fixes attribute lookups in dataclass
# (https://github.com/mypyc/mypyc/issues/1000)
from pandas import Timestamp # type: ignore

from .asset import AssetName
if TYPE_CHECKING:
from .asset import AssetName

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion yabte/utilities/plot/matplotlib/strategy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def plot_strategy_runner(sr: StrategyRunner, settings: dict[str, Any] | None = N

for book, axs in zip(sr.books, axss.T):
for i, asset in enumerate(traded_assets):
prices = sr.data[asset.data_label]
prices = asset._filter_data(sr.data)

up = prices[prices.Close >= prices.Open]
down = prices[prices.Close < prices.Open]
Expand Down
2 changes: 1 addition & 1 deletion yabte/utilities/plot/plotly/strategy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def plot_strategy_runner(sr: StrategyRunner, settings: dict[str, Any] | None = N

for col, book in enumerate(sr.books, start=1):
for row, asset in enumerate(traded_assets, start=1):
prices = sr.data[asset.data_label]
prices = asset._filter_data(sr.data)

fig.add_trace(
go.Candlestick(
Expand Down

0 comments on commit d798d46

Please sign in to comment.