Skip to content

Commit

Permalink
Add and use applicative population transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
turion committed Jan 2, 2024
1 parent b582f8d commit 6493bed
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 14 deletions.
8 changes: 4 additions & 4 deletions monad-bayes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,21 @@ common deps
, scientific ^>=0.3
, statistics >=0.14.0 && <0.17
, text >=1.2 && <2.1
, transformers ^>=0.5.6
, vector >=0.12.0 && <0.14
, vty ^>=5.38

common test-deps
build-depends:
, abstract-par ^>=0.3
, criterion >=1.5 && <1.7
, criterion >=1.5 && <1.7
, directory ^>=1.3
, hspec ^>=2.11
, monad-bayes
, optparse-applicative >=0.17 && <0.19
, optparse-applicative >=0.17 && <0.19
, process ^>=1.6
, QuickCheck ^>=2.14
, time >=1.9 && <1.13
, transformers ^>=0.5.6
, time >=1.9 && <1.13
, typed-process ^>=0.2

autogen-modules: Paths_monad_bayes
Expand Down
33 changes: 28 additions & 5 deletions src/Control/Applicative/List.hs
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
{-# LANGUAGE StandaloneDeriving #-}

module Control.Applicative.List where

-- base

import Control.Applicative
import Data.Functor.Compose

-- transformers
import Control.Monad.Trans.Writer.Strict

-- log-domain
import Numeric.Log (Log)

-- * Applicative ListT

-- | _Applicative_ transformer adding a list/nondeterminism/choice effect.
-- It is not a valid monad transformer, but it is a valid 'Applicative'.
newtype ListT m a = ListT {getListT :: Compose [] m a}
newtype ListT m a = ListT {getListT :: Compose m [] a}
deriving newtype (Functor, Applicative, Alternative)

lift :: m a -> ListT m a
lift = ListT . Compose . pure
lift :: (Functor m) => m a -> ListT m a
lift = ListT . Compose . fmap pure

runListT :: ListT m a -> [m a]
runListT :: ListT m a -> m [a]
runListT = getCompose . getListT

-- * Applicative Population transformer

-- WriterT has to be used instead of WeightedT,
-- since WeightedT uses StateT under the hood,
-- which requires a Monad (ListT m) constraint.
newtype PopulationT m a = PopulationT {getPopulationT :: WriterT (Log Double) (ListT m) a}
deriving newtype (Functor, Applicative, Alternative)

runPopulationT :: PopulationT m a -> m [(a, Log Double)]
runPopulationT = runListT . runWriterT . getPopulationT

fromWeightedList :: m [(a, Log Double)] -> PopulationT m a
fromWeightedList = PopulationT . WriterT . ListT . Compose
7 changes: 4 additions & 3 deletions src/Control/Monad/Bayes/Inference/RMSMC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import Control.Monad.Bayes.Inference.SMC
import Control.Monad.Bayes.Population
( PopulationT,
flatten,
single,
withParticles,
)
import Control.Monad.Bayes.Sequential.Coroutine as Seq
Expand All @@ -50,7 +51,7 @@ rmsmc ::
PopulationT m a
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
marginal
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoist flatten . mhStep) . TrStat.hoist resampler) numSteps
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps
. S.hoistFirst (TrStat.hoist (withParticles numParticles))

-- | Resample-move Sequential Monte Carlo with a more efficient
Expand All @@ -64,7 +65,7 @@ rmsmcBasic ::
PopulationT m a
rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) =
TrBas.marginal
. S.sequentially (TrBas.hoist flatten . composeCopies numMCMCSteps (TrBas.hoist flatten . TrBas.mhStep) . TrBas.hoist resampler) numSteps
. S.sequentially (TrBas.hoist (single . flatten) . composeCopies numMCMCSteps (TrBas.hoist (single . flatten) . TrBas.mhStep) . TrBas.hoist resampler) numSteps
. S.hoistFirst (TrBas.hoist (withParticles numParticles))

-- | A variant of resample-move Sequential Monte Carlo
Expand All @@ -79,7 +80,7 @@ rmsmcDynamic ::
PopulationT m a
rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) =
TrDyn.marginal
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist flatten . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist (single . flatten) . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps
. S.hoistFirst (TrDyn.hoist (withParticles numParticles))

-- | Apply a function a given number of times.
Expand Down
12 changes: 10 additions & 2 deletions src/Control/Monad/Bayes/Population.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ module Control.Monad.Bayes.Population
popAvg,
withParticles,
flatten,
single,
)
where

import Control.Applicative (Alternative)
import Control.Applicative.List qualified as ApplicativeListT
import Control.Arrow (second)
import Control.Monad (MonadPlus, replicateM)
import Control.Monad.Bayes.Class
Expand Down Expand Up @@ -277,5 +279,11 @@ hoist ::
hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT

-- | Flatten all layers of the free structure
flatten :: (Monad m) => PopulationT m a -> PopulationT m a
flatten = fromWeightedList . runPopulationT
flatten :: (Monad m) => PopulationT m a -> ApplicativeListT.PopulationT m a
flatten = ApplicativeListT.fromWeightedList . runPopulationT

-- | Create a population from a single layer of branching computations.
--
-- Similar to 'fromWeightedListT'.
single :: (Monad m) => ApplicativeListT.PopulationT m a -> PopulationT m a
single = fromWeightedList . ApplicativeListT.runPopulationT

0 comments on commit 6493bed

Please sign in to comment.