Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add track associations based on raw hit association, weighted by number of hits #1564

Merged
merged 11 commits into from
Aug 14, 2024
Merged
49 changes: 47 additions & 2 deletions src/algorithms/tracking/ActsToTracks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,23 @@
#include <Acts/EventData/TrackStateType.hpp>
#include <ActsExamples/EventData/IndexSourceLink.hpp>
#include <edm4eic/Cov6f.h>
#include <edm4eic/EDM4eicVersion.h>
#include <edm4eic/RawTrackerHit.h>
#include <edm4eic/TrackerHit.h>
#include <edm4hep/MCParticleCollection.h>
#include <edm4hep/SimTrackerHit.h>
#include <edm4hep/Vector2f.h>
#include <edm4hep/Vector3f.h>
#include <fmt/core.h>
#include <podio/ObjectID.h>
#include <podio/RelationRange.h>
#include <Eigen/Core>
#include <array>
#include <cmath>
#include <cstddef>
#include <gsl/pointers>
#include <map>
#include <numeric>
#include <optional>
#include <utility>

Expand All @@ -42,8 +51,8 @@ void ActsToTracks::init() {
}

void ActsToTracks::process(const Input& input, const Output& output) const {
const auto [meas2Ds, acts_trajectories] = input;
auto [trajectories, track_parameters, tracks] = output;
const auto [meas2Ds, acts_trajectories, raw_hit_assocs] = input;
auto [trajectories, track_parameters, tracks, tracks_assoc] = output;

// Loop over trajectories
for (const auto traj : acts_trajectories) {
Expand Down Expand Up @@ -147,6 +156,9 @@ void ActsToTracks::process(const Input& input, const Output& output) const {
);
track.setTrajectory(trajectory); // Trajectory of this track

// Determine track association with MCParticle, weighted by number of used measurements
std::map<edm4hep::MCParticle,double> mcparticle_weight_by_hit_count;

// save measurement2d to good measurements or outliers according to srclink index
// fix me: ideally, this should be integrated into multitrajectoryhelper
// fix me: should say "OutlierMeasurements" instead of "OutlierHits" etc
Expand All @@ -173,6 +185,23 @@ void ActsToTracks::process(const Input& input, const Output& output) const {
debug("Measurement on geo id={}, index={}, loc={},{}",
geoID, srclink_index, meas2D.getLoc().a, meas2D.getLoc().b);

// Determine track associations if hit associations provided
// FIXME: not able to check whether optional inputs were provided
//if (raw_hit_assocs->has_value()) {
#if EDM4EIC_VERSION_MAJOR >= 7
for (auto& hit : meas2D.getHits()) {
auto raw_hit = hit.getRawHit();
for (const auto raw_hit_assoc : *raw_hit_assocs) {
if (raw_hit_assoc.getRawHit() == raw_hit) {
auto sim_hit = raw_hit_assoc.getSimHit();
auto mc_particle = sim_hit.getMCParticle();
mcparticle_weight_by_hit_count[mc_particle]++;
}
}
}
#endif
//}

}
else if (typeFlags.test(Acts::TrackStateFlag::OutlierFlag)) {
trajectory.addToOutliers_deprecated(meas2D);
Expand All @@ -185,6 +214,22 @@ void ActsToTracks::process(const Input& input, const Output& output) const {

});

// Store track associations if hit associations provided
// FIXME: not able to check whether optional inputs were provided
//if (raw_hit_assocs->has_value()) {
double total_weight = std::accumulate(
mcparticle_weight_by_hit_count.begin(), mcparticle_weight_by_hit_count.end(),
0, [](const double sum, const auto& i) { return sum + i.second; });
for (const auto& [mcparticle, weight] : mcparticle_weight_by_hit_count) {
bschmookler marked this conversation as resolved.
Show resolved Hide resolved
auto track_assoc = tracks_assoc->create();
track_assoc.setRec(track);
track_assoc.setSim(mcparticle);
double normalized_weight = weight / total_weight;
track_assoc.setWeight(normalized_weight);
debug("track {}: mcparticle {} weight {}", track.id().index, mcparticle.id().index, normalized_weight);
}
//}

}
}
}
Expand Down
32 changes: 29 additions & 3 deletions src/algorithms/tracking/ActsToTracks.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

#include <ActsExamples/EventData/Trajectories.hpp>
#include <algorithms/algorithm.h>
#include <edm4eic/MCRecoTrackParticleAssociationCollection.h>
#include <edm4eic/MCRecoTrackerHitAssociationCollection.h>
#include <edm4eic/Measurement2DCollection.h>
#include <edm4eic/TrackCollection.h>
#include <edm4eic/TrackParametersCollection.h>
#include <edm4eic/TrajectoryCollection.h>
#include <optional>
#include <string>
#include <string_view>
#include <vector>
Expand All @@ -17,13 +20,36 @@ namespace eicrecon {

using ActsToTracksAlgorithm =
algorithms::Algorithm<
algorithms::Input<edm4eic::Measurement2DCollection, std::vector<ActsExamples::Trajectories>>,
algorithms::Output<edm4eic::TrajectoryCollection, edm4eic::TrackParametersCollection, edm4eic::TrackCollection>
algorithms::Input<
edm4eic::Measurement2DCollection,
std::vector<ActsExamples::Trajectories>,
std::optional<edm4eic::MCRecoTrackerHitAssociationCollection>
bschmookler marked this conversation as resolved.
Show resolved Hide resolved
>,
algorithms::Output<
edm4eic::TrajectoryCollection,
edm4eic::TrackParametersCollection,
edm4eic::TrackCollection,
std::optional<edm4eic::MCRecoTrackParticleAssociationCollection>
>
>;

class ActsToTracks : public ActsToTracksAlgorithm {
public:
ActsToTracks(std::string_view name) : ActsToTracksAlgorithm{name, {"inputActsTrajectories", "inputActsConstTrackContainer"}, {"outputTrajectoryCollection", "outputTrackParametersCollection", "outputTrackCollection"}, "Converts ACTS trajectories to EDM4eic"} {};
ActsToTracks(std::string_view name)
: ActsToTracksAlgorithm{
name,
{
"inputMeasurements",
"inputActsTrajectories",
"inputRawTrackerHitAssociations",
},
{
"outputTrajectories",
"outputTrackParameters",
"outputTracks",
"outputTrackAssociations",
},
"Converts ACTS trajectories to EDM4eic"} {};

void init() final;
void process(const Input&, const Output&) const final;
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/B0TRK/B0TRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"B0TrackerRawHits",
"B0TrackerHitAssociations"
"B0TrackerRawHitAssociations"
},
{
.threshold = 10.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/BTOF/BTOF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void InitPlugin(JApplication *app) {
},
{
"TOFBarrelRawHit",
"TOFBarrelHitAssociations"
"TOFBarrelRawHitAssociations"
},
{
.threshold = 6.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/BTRK/BTRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"SiBarrelRawHits",
"SiBarrelHitAssociations"
"SiBarrelRawHitAssociations"
},
{
.threshold = 0.54 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/BVTX/BVTX.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"SiBarrelVertexRawHits",
"SiBarrelVertexHitAssociations"
"SiBarrelVertexRawHitAssociations"
},
{
.threshold = 0.54 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/ECTOF/ECTOF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void InitPlugin(JApplication *app) {
"TOFEndcapHits"},
{
"TOFEndcapRawHits",
"TOFEndcapHitAssociations"
"TOFEndcapRawHitAssociations"
},
{
.threshold = 6.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/ECTRK/ECTRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"SiEndcapTrackerRawHits",
"SiEndcapHitAssociations"
"SiEndcapTrackerRawHitAssociations"
},
{
.threshold = 0.54 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/FOFFMTRK/FOFFMTRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void InitPlugin(JApplication *app) {
},
{
"ForwardOffMTrackerRawHits",
"ForwardOffMTrackerHitAssociations"
"ForwardOffMTrackerRawHitAssociations"
},
{
.threshold = 10.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/LOWQ2/LOWQ2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ extern "C" {
},
{
"TaggerTrackerRawHits",
"TaggerTrackerHitAssociations"
"TaggerTrackerRawHitAssociations"
},
{
.threshold = 1.5 * dd4hep::keV,
Expand Down
8 changes: 4 additions & 4 deletions src/detectors/MPGD/MPGD.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"MPGDBarrelRawHits",
"MPGDBarrelHitAssociations"
"MPGDBarrelRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down Expand Up @@ -54,7 +54,7 @@ void InitPlugin(JApplication *app) {
},
{
"OuterMPGDBarrelRawHits",
"OuterMPGDBarrelHitAssociations"
"OuterMPGDBarrelRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down Expand Up @@ -82,7 +82,7 @@ void InitPlugin(JApplication *app) {
},
{
"BackwardMPGDEndcapRawHits",
"BackwardMPGDEndcapAssociations"
"BackwardMPGDEndcapRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down Expand Up @@ -110,7 +110,7 @@ void InitPlugin(JApplication *app) {
},
{
"ForwardMPGDEndcapRawHits",
"ForwardMPGDHitAssociations"
"ForwardMPGDEndcapRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/RPOTS/RPOTS.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void InitPlugin(JApplication *app) {
},
{
"ForwardRomanPotRawHits",
"ForwardRomanPotHitAssociations"
"ForwardRomanPotRawHitAssociations"
},
{
.threshold = 10.0 * dd4hep::keV,
Expand Down
17 changes: 15 additions & 2 deletions src/global/tracking/ActsToTracks_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ class ActsToTracks_factory

PodioInput<edm4eic::Measurement2D> m_measurements_input {this};
Input<ActsExamples::Trajectories> m_acts_trajectories_input {this};
PodioInput<edm4eic::MCRecoTrackerHitAssociation> m_raw_hit_assocs_input {this};
PodioOutput<edm4eic::Trajectory> m_trajectories_output {this};
PodioOutput<edm4eic::TrackParameters> m_parameters_output {this};
PodioOutput<edm4eic::Track> m_tracks_output {this};
PodioOutput<edm4eic::MCRecoTrackParticleAssociation> m_track_assocs_output {this};

public:
void Configure() {
Expand All @@ -38,8 +40,19 @@ class ActsToTracks_factory
for (auto acts_traj : m_acts_trajectories_input()) {
acts_trajectories_input.push_back(acts_traj);
}
m_algo->process({m_measurements_input(), acts_trajectories_input},
{m_trajectories_output().get(), m_parameters_output().get(), m_tracks_output().get()});
m_algo->process(
{
m_measurements_input(),
acts_trajectories_input,
m_raw_hit_assocs_input(),
},
{
m_trajectories_output().get(),
m_parameters_output().get(),
m_tracks_output().get(),
m_track_assocs_output().get(),
}
);
}
};

Expand Down
Loading
Loading