Skip to content

Commit

Permalink
refactor: apply black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
rickstaa committed Mar 15, 2024
1 parent 16ff564 commit bdd79c4
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ def estimate_step_learning_rate(
decay_rate = get_exponential_decay_rate(
lr_start, lr_final, adjusted_total_steps
)
lr = float(
Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step))
)
lr = float(Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step)))
else:
supported_schedulers = ["LambdaLR", "ExponentialLR"]
raise ValueError(
Expand Down
4 changes: 1 addition & 3 deletions stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,7 @@ def sac(
type="warning",
)
decay_types[name] = lr_decay_type
lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type = (
decay_types.values()
)
lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type = decay_types.values()

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
Expand Down
4 changes: 3 additions & 1 deletion stable_learning_control/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def plot_data(
if isinstance(data, list):
data = pd.concat(data, ignore_index=True)
sns.set(style=style, font_scale=font_scale)
sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar=errorbar, **kwargs)
sns.lineplot(
data=data, x=xaxis, y=value, hue=condition, errorbar=errorbar, **kwargs
)
plt.legend(loc="best").set_draggable(True)

xscale = np.max(np.asarray(data[xaxis])) > 5e3
Expand Down

0 comments on commit bdd79c4

Please sign in to comment.