Skip to content

Commit

Permalink
Merge pull request #83 from ECCCO-mission/speed-up
Browse files Browse the repository at this point in the history
Adds tracking of num iterations
  • Loading branch information
jmbhughes authored Jul 25, 2024
2 parents 96fe742 + 5dae4cc commit da4572a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 16 deletions.
28 changes: 18 additions & 10 deletions overlappogram/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def unfold(config):

for alpha in config["model"]["alphas"]:
for rho in config["model"]["rhos"]:
print(80*"-")
print(80 * "-")
print(f"Beginning inversion for alpha={alpha}, rho={rho}.")
start = time.time()
em_cube, prediction, scores, unconverged_rows = inversion.invert(
overlappogram,
config["model"],
alpha,
rho,
num_threads=config["execution"]["num_threads"],
mode_switch_thread_count=config["execution"]["mode_switch_thread_count"],
mode=MODE_MAPPING.get(config['execution']['mode'], 'invalid')
em_cube, prediction, scores, unconverged_rows, n_iter = inversion.invert(
overlappogram,
config["model"],
alpha,
rho,
num_threads=config["execution"]["num_threads"],
mode_switch_thread_count=config["execution"]["mode_switch_thread_count"],
mode=MODE_MAPPING.get(config['execution']['mode'], 'invalid')
)
end = time.time()
print(
Expand All @@ -68,8 +68,11 @@ def unfold(config):
f"seconds; {len(unconverged_rows)} unconverged rows",
)

print(f"Unconverged rows: {unconverged_rows}")

postfix = (
"x" + str(config["inversion"]["solution_fov_width"]) + "_" + str(rho * 10) + "_" + str(alpha) + "_wpsf"
"x" + str(config["inversion"]["solution_fov_width"]) + "_" + str(rho * 10) + "_" + str(
alpha) + "_wpsf"
)
save_em_cube(
em_cube,
Expand All @@ -90,6 +93,11 @@ def unfold(config):
with open(scores_path, 'w') as f:
f.write("\n".join(scores.flatten().astype(str).tolist()))

niter_path = os.path.join(config["output"]["directory"],
f"{config['output']['prefix']}_niter_{postfix}.txt")
with open(niter_path, 'w') as f:
f.write("\n".join(n_iter.flatten().astype(str).tolist()))

if config["output"]["make_spectral"]:
spectral_images = create_spectrally_pure_images(
[em_cube], config["paths"]["gnt"], config["inversion"]["response_dependency_list"]
Expand Down
11 changes: 8 additions & 3 deletions overlappogram/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
self._row_scores: np.ndarray | None = None
self._overlappogram_width: int | None = None
self._overlappogram_height: int | None = None
self._n_iter: np.ndarray | None = None

self._thread_count_lock = Lock()

Expand Down Expand Up @@ -104,13 +105,15 @@ def _invert_image_row(self, row_index, chunk_index):
data_out = model.predict(masked_response_function)
em = model.coef_
score_data = model.score(masked_response_function, image_row)
n_iter = model.n_iter_
except ConvergenceWarning:
self._unconverged_rows.append(row_index)
em = np.zeros((self._num_slits * self._num_deps), dtype=np.float32)
data_out = np.zeros(self._overlappogram_width, dtype=np.float32)
score_data = -999
n_iter = -1

return row_index, em, data_out, score_data
return row_index, em, data_out, score_data, n_iter

def _progress_indicator(self, future):
"""used in multithreading to track progress of inversion"""
Expand Down Expand Up @@ -155,7 +158,7 @@ def _switch_to_row_inversion(self, model_config, alpha, rho, num_row_threads=50)

def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
for future in concurrent.futures.as_completed(self.futures):
row_index, em, data_out, score_data = future.result()
row_index, em, data_out, score_data, n_iter = future.result()
for slit_num in range(self._num_slits):
if self._smooth_over == "dependence":
slit_em = em[slit_num * self._num_deps : (slit_num + 1) * self._num_deps]
Expand All @@ -164,6 +167,7 @@ def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
self._em_data[row_index, slit_num, :] = slit_em
self._inversion_prediction[row_index, :] = data_out
self._row_scores[row_index] = score_data
self._n_iter[row_index] = n_iter

rows_remaining = self.total_row_count - self._completed_row_count

Expand All @@ -173,7 +177,6 @@ def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):

def _start_row_inversion(self, model_config, alpha, rho, num_threads):
self.executors = [concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)]

self.futures = {}
self._models = []
for i, row_index in enumerate(range(self._detector_row_range[0], self._detector_row_range[1])):
Expand Down Expand Up @@ -253,6 +256,7 @@ def _initialize_with_overlappogram(self, overlappogram):
self._em_data = np.zeros((self._overlappogram_height, self._num_slits, self._num_deps), dtype=np.float32)
self._inversion_prediction = np.zeros((self._overlappogram_height, self._overlappogram_width), dtype=np.float32)
self._row_scores = np.zeros((self._overlappogram_height, 1), dtype=np.float32)
self._n_iter = np.zeros((self._overlappogram_height, 1), dtype=np.int32)

def invert(
self,
Expand Down Expand Up @@ -301,4 +305,5 @@ def invert(
NDCube(data=self._inversion_prediction, wcs=out_wcs, meta=self._response_meta),
self._row_scores,
self._unconverged_rows,
self._n_iter
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = ["setuptools",

[project]
name = "overlappogram"
version = "0.0.9"
version = "0.0.10"
dependencies = ["numpy<2.0.0",
"astropy",
"scikit-learn",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_inversion_runs(tmp_path, inversion_mode, is_weighted):
detector_row_range=config["inversion"]["detector_row_range"],
)

em_cube, prediction, scores, unconverged_rows = inversion.invert(
em_cube, prediction, scores, unconverged_rows, niter = inversion.invert(
overlappogram,
config["model"],
3E-5,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_create_spectrally_pure_images(tmp_path):
detector_row_range=config["inversion"]["detector_row_range"],
)

em_cube, prediction, scores, unconverged_rows = inversion.invert(
em_cube, prediction, scores, unconverged_rows, _ = inversion.invert(
overlappogram,
config["model"],
3E-5,
Expand Down

0 comments on commit da4572a

Please sign in to comment.