Skip to content

Commit

Permalink
Make trailedAssociatorTask
Browse files Browse the repository at this point in the history
Review update

Review Update
  • Loading branch information
bsmartradio committed Sep 19, 2023
1 parent 5d6d4c1 commit 1af91a5
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 17 deletions.
1 change: 1 addition & 0 deletions python/lsst/ap/association/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from .version import *
from .trailedSourceFilter import *
from .association import *
from .diaForcedSource import *
from .loadDiaCatalogs import *
Expand Down
40 changes: 36 additions & 4 deletions python/lsst/ap/association/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
from .trailedSourceFilter import TrailedSourceFilterTask

# Enforce an error for unsafe column/array value setting in pandas.
pd.options.mode.chained_assignment = 'raise'
Expand All @@ -40,13 +41,27 @@
class AssociationConfig(pexConfig.Config):
"""Config class for AssociationTask.
"""

maxDistArcSeconds = pexConfig.Field(
dtype=float,
doc='Maximum distance in arcseconds to test for a DIASource to be a '
'match to a DIAObject.',
doc="Maximum distance in arcseconds to test for a DIASource to be a "
"match to a DIAObject.",
default=1.0,
)

trailedSourceFilter = pexConfig.ConfigurableField(
target=TrailedSourceFilterTask,
doc="Subtask to remove long trailed sources based on catalog source "
"morphological measurements.",
)

doTrailedSourceFilter = pexConfig.Field(
doc="Run traildeSourceFilter to remove long trailed sources from "
"output catalog.",
dtype=bool,
default=True,
)


class AssociationTask(pipeBase.Task):
"""Associate DIAOSources into existing DIAObjects.
Expand All @@ -60,10 +75,16 @@ class AssociationTask(pipeBase.Task):
ConfigClass = AssociationConfig
_DefaultName = "association"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.config.doTrailedSourceFilter:
self.makeSubtask("trailedSourceFilter")

@timeMethod
def run(self,
diaSources,
diaObjects):
diaObjects,
exposure_time=None):
"""Associate the new DiaSources with existing DiaObjects.
Parameters
Expand All @@ -72,6 +93,8 @@ def run(self,
New DIASources to be associated with existing DIAObjects.
diaObjects : `pandas.DataFrame`
Existing diaObjects from the Apdb.
exposure_time : `float`, optional
Exposure time from difference image.
Returns
-------
Expand All @@ -98,7 +121,16 @@ def run(self,
nUpdatedDiaObjects=0,
nUnassociatedDiaObjects=0)

matchResult = self.associate_sources(diaObjects, diaSources)
if self.config.doTrailedSourceFilter:
diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time)
matchResult = self.associate_sources(diaObjects, diaTrailedResult.diaSources)

self.log.warning("%i DIASources exceed maxTrailLength, dropping "
"from source catalog."
% len(diaTrailedResult.trailedDiaSources))

else:
matchResult = self.associate_sources(diaObjects, diaSources)

mask = matchResult.diaSources["diaObjectId"] != 0

Expand Down
12 changes: 6 additions & 6 deletions python/lsst/ap/association/diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
Currently loads directly from the Apdb rather than pre-loading.
"""

__all__ = ("DiaPipelineConfig",
"DiaPipelineTask",
"DiaPipelineConnections")

import pandas as pd

import lsst.dax.apdb as daxApdb
Expand All @@ -44,10 +48,6 @@
PackageAlertsTask)
from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask

__all__ = ("DiaPipelineConfig",
"DiaPipelineTask",
"DiaPipelineConnections")


class DiaPipelineConnections(
pipeBase.PipelineTaskConnections,
Expand Down Expand Up @@ -367,8 +367,8 @@ def run(self,
loaderResult = self.diaCatalogLoader.run(diffIm, self.apdb)

# Associate new DiaSources with existing DiaObjects.
assocResults = self.associator.run(diaSourceTable,
loaderResult.diaObjects)
assocResults = self.associator.run(diaSourceTable, loaderResult.diaObjects,
exposure_time=diffIm.getInfo().getVisitInfo().getExposureTime())
if self.config.doSolarSystemAssociation:
ssoAssocResult = self.solarSystemAssociator.run(
assocResults.unAssocDiaSources,
Expand Down
112 changes: 112 additions & 0 deletions python/lsst/ap/association/trailedSourceFilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig")

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod


class TrailedSourceFilterConfig(pexConfig.Config):
"""Config class for TrailedSourceFilterTask.
"""

maxTrailLength = pexConfig.Field(
dtype=float,
doc="Length of long trailed sources to remove from the input catalog, "
"in arcseconds per second. Default comes from DMTN-199, which "
"requires removal of sources with trails longer than 10 "
"degrees/day, which is 36000/3600/24arcsec/second, or roughly"
"0.416 arcseconds per second.",
default=36000/3600.0/24.0,
)


class TrailedSourceFilterTask(pipeBase.Task):
"""Find trailed sources in DIASources and filter them as per DMTN-199
guidelines.
This task checks the length of trailLength in the DIASource catalog using
a given arcsecond/second rate from maxTrailLength and the exposure time.
The two values are used to calculate the maximum allowed trail length and
filters out any trail longer than the maximum. The maxTrailLength is
outlined in DMTN-199 and determines the default value.
"""

ConfigClass = TrailedSourceFilterConfig
_DefaultName = "trailedSourceFilter"

@timeMethod
def run(self, dia_sources, exposure_time):
"""Remove trailed sources longer than ``config.maxTrailLength`` from
the input catalog.
Parameters
----------
dia_sources : `pandas.DataFrame`
New DIASources to be checked for trailed sources.
exposure_time : `float`
Exposure time from difference image.
Returns
-------
result : `lsst.pipe.base.Struct`
Results struct with components.
- ``dia_sources`` : DIASource table that is free from unwanted
trailed sources. (`pandas.DataFrame`)
- ``trailed_dia_sources`` : DIASources that have trails which
exceed maxTrailLength/second*exposure_time.
(`pandas.DataFrame`)
"""

trail_mask = self._check_dia_source_trail(dia_sources, exposure_time)

return pipeBase.Struct(
diaSources=dia_sources[~trail_mask].reset_index(drop=True),
trailedDiaSources=dia_sources[trail_mask].reset_index(drop=True))

def _check_dia_source_trail(self, dia_sources, exposure_time):
"""Find DiaSources that have long trails.
Creates a mask for sources with lengths greater than 0.416
arcseconds/second multiplied by the exposure time.
Parameters
----------
dia_sources : `pandas.DataFrame`
Input DIASources to check for trail lengths.
exposure_time : `float`
Exposure time from difference image.
Returns
-------
trail_mask : `pandas.DataFrame`
Boolean mask for DIASources which are greater than the
cutoff length.
"""

trail_mask = (dia_sources.loc[:, "trailLength"].values[:]
>= (self.config.maxTrailLength*exposure_time))

return trail_mask
32 changes: 26 additions & 6 deletions tests/test_association_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import numpy as np
import pandas as pd
import unittest

import lsst.geom as geom
import lsst.utils.tests

Expand All @@ -46,20 +45,23 @@ def setUp(self):
self.diaSources = pd.DataFrame(data=[
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0}
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
for idx in range(self.nSources)])
self.diaSourceZeroScatter = pd.DataFrame(data=[
{"ra": 0.04*idx,
"dec": 0.04*idx,
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0}
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
for idx in range(self.nSources)])
self.exposure_time = 30.0

def test_run(self):
"""Test the full task by associating a set of diaSources to
existing diaObjects.
"""
assocTask = AssociationTask()
results = assocTask.run(self.diaSources, self.diaObjects)
config = AssociationTask.ConfigClass()
config.doTrailedSourceFilter = False
assocTask = AssociationTask(config=config)
results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1)
self.assertEqual(results.nUnassociatedDiaObjects, 1)
Expand All @@ -75,13 +77,31 @@ def test_run(self):
[0]):
self.assertEqual(test_obj_id, expected_obj_id)

def test_run_trailed_sources(self):
"""Test the full task by associating a set of diaSources to
existing diaObjects when trailed sources are filtered.
This should filter out two of the five sources based on trail length,
leaving one unassociated diaSource and two associated diaSources.
"""
assocTask = AssociationTask()
results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3)
self.assertEqual(results.nUnassociatedDiaObjects, 3)
self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3)
self.assertEqual(len(results.unAssocDiaSources), 1)
np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2])
np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])

def test_run_no_existing_objects(self):
"""Test the run method with a completely empty database.
"""
assocTask = AssociationTask()
results = assocTask.run(
self.diaSources,
pd.DataFrame(columns=["ra", "dec", "diaObjectId"]))
pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]),
exposure_time=self.exposure_time)
self.assertEqual(results.nUpdatedDiaObjects, 0)
self.assertEqual(results.nUnassociatedDiaObjects, 0)
self.assertEqual(len(results.matchedDiaSources), 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def solarSystemAssociator_run(self, unAssocDiaSources, solarSystemObjectTable, d
unAssocDiaSources=MagicMock(spec=pd.DataFrame()))

@lsst.utils.timer.timeMethod
def associator_run(self, table, diaObjects):
def associator_run(self, table, diaObjects, exposure_time=None):
return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
matchedDiaSources=MagicMock(spec=pd.DataFrame()),
unAssocDiaSources=MagicMock(spec=pd.DataFrame()))
Expand Down
Loading

0 comments on commit 1af91a5

Please sign in to comment.