Skip to content

Commit

Permalink
feat: add track associations based on raw hit association, weighted b…
Browse files Browse the repository at this point in the history
…y number of hits (#1564)

### Briefly, what does this PR introduce?
This PR builds on #1535 to create `edm4eic::Track` associations to
`edm4hep::MCParticles`. Results in weights like this:
```
events->Draw("CentralCKFTrackAssociations.weight")
```

![image](https://github.com/user-attachments/assets/a775165c-f644-49f7-8943-322daf4cc2f7)

```
events->Draw("sqrt(MCParticles[_CentralCKFSeededTrackAssociations_sim.index].momentum.x**2+MCParticles[_CentralCKFSeededTrackAssociations_sim.index].momentum.y**2+MCParticles[_CentralCKFSeededTrackAssociations_sim.index].momentum.z**2)")
```

![image](https://github.com/user-attachments/assets/217ba69b-f288-4b6b-b49b-6dbcdaf7ca2f)


### What kind of change does this PR introduce?
- [ ] Bug fix (issue #__)
- [x] New feature (issue: track associations)
- [ ] Documentation update
- [ ] Other: __

### Please check if this PR fulfills the following:
- [ ] Tests for the changes have been added
- [ ] Documentation has been added / updated
- [ ] Changes have been communicated to collaborators

### Does this PR introduce breaking changes? What changes might users
need to make to their code?
No.

### Does this PR change default behavior?
No.
  • Loading branch information
wdconinc authored Aug 14, 2024
1 parent 5db8eea commit 0d5b8d8
Show file tree
Hide file tree
Showing 15 changed files with 159 additions and 52 deletions.
49 changes: 47 additions & 2 deletions src/algorithms/tracking/ActsToTracks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,23 @@
#include <Acts/EventData/TrackStateType.hpp>
#include <ActsExamples/EventData/IndexSourceLink.hpp>
#include <edm4eic/Cov6f.h>
#include <edm4eic/EDM4eicVersion.h>
#include <edm4eic/RawTrackerHit.h>
#include <edm4eic/TrackerHit.h>
#include <edm4hep/MCParticleCollection.h>
#include <edm4hep/SimTrackerHit.h>
#include <edm4hep/Vector2f.h>
#include <edm4hep/Vector3f.h>
#include <fmt/core.h>
#include <podio/ObjectID.h>
#include <podio/RelationRange.h>
#include <Eigen/Core>
#include <array>
#include <cmath>
#include <cstddef>
#include <gsl/pointers>
#include <map>
#include <numeric>
#include <optional>
#include <utility>

Expand All @@ -42,8 +51,8 @@ void ActsToTracks::init() {
}

void ActsToTracks::process(const Input& input, const Output& output) const {
const auto [meas2Ds, acts_trajectories] = input;
auto [trajectories, track_parameters, tracks] = output;
const auto [meas2Ds, acts_trajectories, raw_hit_assocs] = input;
auto [trajectories, track_parameters, tracks, tracks_assoc] = output;

// Loop over trajectories
for (const auto traj : acts_trajectories) {
Expand Down Expand Up @@ -147,6 +156,9 @@ void ActsToTracks::process(const Input& input, const Output& output) const {
);
track.setTrajectory(trajectory); // Trajectory of this track

// Determine track association with MCParticle, weighted by number of used measurements
std::map<edm4hep::MCParticle,double> mcparticle_weight_by_hit_count;

// save measurement2d to good measurements or outliers according to srclink index
// fix me: ideally, this should be integrated into multitrajectoryhelper
// fix me: should say "OutlierMeasurements" instead of "OutlierHits" etc
Expand All @@ -173,6 +185,23 @@ void ActsToTracks::process(const Input& input, const Output& output) const {
debug("Measurement on geo id={}, index={}, loc={},{}",
geoID, srclink_index, meas2D.getLoc().a, meas2D.getLoc().b);

// Determine track associations if hit associations provided
// FIXME: not able to check whether optional inputs were provided
//if (raw_hit_assocs->has_value()) {
#if EDM4EIC_VERSION_MAJOR >= 7
for (auto& hit : meas2D.getHits()) {
auto raw_hit = hit.getRawHit();
for (const auto raw_hit_assoc : *raw_hit_assocs) {
if (raw_hit_assoc.getRawHit() == raw_hit) {
auto sim_hit = raw_hit_assoc.getSimHit();
auto mc_particle = sim_hit.getMCParticle();
mcparticle_weight_by_hit_count[mc_particle]++;
}
}
}
#endif
//}

}
else if (typeFlags.test(Acts::TrackStateFlag::OutlierFlag)) {
trajectory.addToOutliers_deprecated(meas2D);
Expand All @@ -185,6 +214,22 @@ void ActsToTracks::process(const Input& input, const Output& output) const {

});

// Store track associations if hit associations provided
// FIXME: not able to check whether optional inputs were provided
//if (raw_hit_assocs->has_value()) {
double total_weight = std::accumulate(
mcparticle_weight_by_hit_count.begin(), mcparticle_weight_by_hit_count.end(),
0, [](const double sum, const auto& i) { return sum + i.second; });
for (const auto& [mcparticle, weight] : mcparticle_weight_by_hit_count) {
auto track_assoc = tracks_assoc->create();
track_assoc.setRec(track);
track_assoc.setSim(mcparticle);
double normalized_weight = weight / total_weight;
track_assoc.setWeight(normalized_weight);
debug("track {}: mcparticle {} weight {}", track.id().index, mcparticle.id().index, normalized_weight);
}
//}

}
}
}
Expand Down
32 changes: 29 additions & 3 deletions src/algorithms/tracking/ActsToTracks.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

#include <ActsExamples/EventData/Trajectories.hpp>
#include <algorithms/algorithm.h>
#include <edm4eic/MCRecoTrackParticleAssociationCollection.h>
#include <edm4eic/MCRecoTrackerHitAssociationCollection.h>
#include <edm4eic/Measurement2DCollection.h>
#include <edm4eic/TrackCollection.h>
#include <edm4eic/TrackParametersCollection.h>
#include <edm4eic/TrajectoryCollection.h>
#include <optional>
#include <string>
#include <string_view>
#include <vector>
Expand All @@ -17,13 +20,36 @@ namespace eicrecon {

using ActsToTracksAlgorithm =
algorithms::Algorithm<
algorithms::Input<edm4eic::Measurement2DCollection, std::vector<ActsExamples::Trajectories>>,
algorithms::Output<edm4eic::TrajectoryCollection, edm4eic::TrackParametersCollection, edm4eic::TrackCollection>
algorithms::Input<
edm4eic::Measurement2DCollection,
std::vector<ActsExamples::Trajectories>,
std::optional<edm4eic::MCRecoTrackerHitAssociationCollection>
>,
algorithms::Output<
edm4eic::TrajectoryCollection,
edm4eic::TrackParametersCollection,
edm4eic::TrackCollection,
std::optional<edm4eic::MCRecoTrackParticleAssociationCollection>
>
>;

class ActsToTracks : public ActsToTracksAlgorithm {
public:
ActsToTracks(std::string_view name) : ActsToTracksAlgorithm{name, {"inputActsTrajectories", "inputActsConstTrackContainer"}, {"outputTrajectoryCollection", "outputTrackParametersCollection", "outputTrackCollection"}, "Converts ACTS trajectories to EDM4eic"} {};
ActsToTracks(std::string_view name)
: ActsToTracksAlgorithm{
name,
{
"inputMeasurements",
"inputActsTrajectories",
"inputRawTrackerHitAssociations",
},
{
"outputTrajectories",
"outputTrackParameters",
"outputTracks",
"outputTrackAssociations",
},
"Converts ACTS trajectories to EDM4eic"} {};

void init() final;
void process(const Input&, const Output&) const final;
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/B0TRK/B0TRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"B0TrackerRawHits",
"B0TrackerHitAssociations"
"B0TrackerRawHitAssociations"
},
{
.threshold = 10.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/BTOF/BTOF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void InitPlugin(JApplication *app) {
},
{
"TOFBarrelRawHit",
"TOFBarrelHitAssociations"
"TOFBarrelRawHitAssociations"
},
{
.threshold = 6.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/BTRK/BTRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"SiBarrelRawHits",
"SiBarrelHitAssociations"
"SiBarrelRawHitAssociations"
},
{
.threshold = 0.54 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/BVTX/BVTX.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"SiBarrelVertexRawHits",
"SiBarrelVertexHitAssociations"
"SiBarrelVertexRawHitAssociations"
},
{
.threshold = 0.54 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/ECTOF/ECTOF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void InitPlugin(JApplication *app) {
"TOFEndcapHits"},
{
"TOFEndcapRawHits",
"TOFEndcapHitAssociations"
"TOFEndcapRawHitAssociations"
},
{
.threshold = 6.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/ECTRK/ECTRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"SiEndcapTrackerRawHits",
"SiEndcapHitAssociations"
"SiEndcapTrackerRawHitAssociations"
},
{
.threshold = 0.54 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/FOFFMTRK/FOFFMTRK.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void InitPlugin(JApplication *app) {
},
{
"ForwardOffMTrackerRawHits",
"ForwardOffMTrackerHitAssociations"
"ForwardOffMTrackerRawHitAssociations"
},
{
.threshold = 10.0 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/LOWQ2/LOWQ2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ extern "C" {
},
{
"TaggerTrackerRawHits",
"TaggerTrackerHitAssociations"
"TaggerTrackerRawHitAssociations"
},
{
.threshold = 1.5 * dd4hep::keV,
Expand Down
8 changes: 4 additions & 4 deletions src/detectors/MPGD/MPGD.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void InitPlugin(JApplication *app) {
},
{
"MPGDBarrelRawHits",
"MPGDBarrelHitAssociations"
"MPGDBarrelRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down Expand Up @@ -54,7 +54,7 @@ void InitPlugin(JApplication *app) {
},
{
"OuterMPGDBarrelRawHits",
"OuterMPGDBarrelHitAssociations"
"OuterMPGDBarrelRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down Expand Up @@ -82,7 +82,7 @@ void InitPlugin(JApplication *app) {
},
{
"BackwardMPGDEndcapRawHits",
"BackwardMPGDEndcapAssociations"
"BackwardMPGDEndcapRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down Expand Up @@ -110,7 +110,7 @@ void InitPlugin(JApplication *app) {
},
{
"ForwardMPGDEndcapRawHits",
"ForwardMPGDHitAssociations"
"ForwardMPGDEndcapRawHitAssociations"
},
{
.threshold = 0.25 * dd4hep::keV,
Expand Down
2 changes: 1 addition & 1 deletion src/detectors/RPOTS/RPOTS.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void InitPlugin(JApplication *app) {
},
{
"ForwardRomanPotRawHits",
"ForwardRomanPotHitAssociations"
"ForwardRomanPotRawHitAssociations"
},
{
.threshold = 10.0 * dd4hep::keV,
Expand Down
17 changes: 15 additions & 2 deletions src/global/tracking/ActsToTracks_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ class ActsToTracks_factory

PodioInput<edm4eic::Measurement2D> m_measurements_input {this};
Input<ActsExamples::Trajectories> m_acts_trajectories_input {this};
PodioInput<edm4eic::MCRecoTrackerHitAssociation> m_raw_hit_assocs_input {this};
PodioOutput<edm4eic::Trajectory> m_trajectories_output {this};
PodioOutput<edm4eic::TrackParameters> m_parameters_output {this};
PodioOutput<edm4eic::Track> m_tracks_output {this};
PodioOutput<edm4eic::MCRecoTrackParticleAssociation> m_track_assocs_output {this};

public:
void Configure() {
Expand All @@ -38,8 +40,19 @@ class ActsToTracks_factory
for (auto acts_traj : m_acts_trajectories_input()) {
acts_trajectories_input.push_back(acts_traj);
}
m_algo->process({m_measurements_input(), acts_trajectories_input},
{m_trajectories_output().get(), m_parameters_output().get(), m_tracks_output().get()});
m_algo->process(
{
m_measurements_input(),
acts_trajectories_input,
m_raw_hit_assocs_input(),
},
{
m_trajectories_output().get(),
m_parameters_output().get(),
m_tracks_output().get(),
m_track_assocs_output().get(),
}
);
}
};

Expand Down
Loading

0 comments on commit 0d5b8d8

Please sign in to comment.