Skip to content

Commit

Permalink
Adding example to README and new test to ensure logging
Browse files Browse the repository at this point in the history
  • Loading branch information
pebeto committed Aug 16, 2023
1 parent b2631d0 commit c3c2ac1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,22 @@ The entire workload is divided into three different repositories:
- [x] Polished compatibility with composed models
- [ ] Polished compatibility with tuned models
- [ ] Polished compatibility with iterative models

## Example
```julia
# We first define a logger instance, providing the mlflow server address.
# The experiment name and artifact location are optional.
logger = MLFlowLogger("http://localhost:5000";
experiment_name="MLJFlow tests",
artifact_location="./mlj-test")

X, y = make_moons(100) # X is a 100x2 matrix, y is a 100-element vector

# Writing a normal MLJ workflow
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
dtc_machine = machine(dtc, X, y)

# Passing the logger to the machine is enough to enable mlflow logging
e1 = evaluate!(dtc_machine, resampling=CV(),
measures=[LogLoss(), Accuracy()], verbosity=1, logger=logger)
```
1 change: 0 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ multiple methods in MLJBase.
To use this logger, you need to have a MLFlow server running. For more
information, see [MLFlow documentation](https://www.mlflow.org/docs/latest/quickstart.html).
If it is not running, an informative error will be thrown.
Depending on the MLFlow server configuration, the `baseuri` can be a local
server or a remote server. The `experiment_name` is used to identify the
Expand Down
15 changes: 11 additions & 4 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
X, y = make_moons(100)
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree

dtc = DecisionTreeClassifier()
dtc_machine = machine(dtc, X, y)
e1 = evaluate!(dtc_machine, resampling=CV(),
pipe = Standardizer() |> DecisionTreeClassifier()
mach = machine(pipe, X, y)
e1 = evaluate!(mach, resampling=CV(),
measures=[LogLoss(), Accuracy()], verbosity=1, logger=logger)

@testset "log_evaluation" begin
Expand All @@ -17,8 +17,15 @@
@test typeof(runs[1]) == MLFlowRun
end

@testset "ensuring logging" begin
runs = searchruns(logger.client,
getexperiment(logger.client, logger.experiment_name))
@test issetequal(keys(runs[1].data.params),
String.([keys(MLJModelInterface.flat_params(pipe))...]))
end

@testset "save" begin
run = MLJBase.save(logger, dtc_machine)
run = MLJBase.save(logger, mach)
@test typeof(run) == MLFlowRun
@test listartifacts(logger.client, run) |> length == 1
end
Expand Down

0 comments on commit c3c2ac1

Please sign in to comment.