diff --git a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java index 85c60349a..392aa29a8 100644 --- a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java +++ b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java @@ -44,6 +44,7 @@ public final class GoBackend implements Backend { @Nullable private static AlwaysOnCallback alwaysOnCallback; private static GhettoCompletableFuture vpnService = new GhettoCompletableFuture<>(); private final Context context; + private final TunnelActionHandler tunnelActionHandler; @Nullable private Config currentConfig; @Nullable private Tunnel currentTunnel; private int currentTunnelHandle = -1; @@ -54,8 +55,20 @@ public final class GoBackend implements Backend { * @param context An Android {@link Context} */ public GoBackend(final Context context) { + this(context, new NoopTunnelActionHandler()); + } + + /** + * Public constructor for GoBackend + * + * @param context An Android {@link Context} + * @param tunnelActionHandler A {@link TunnelActionHandler} for executing Pre/Post Up/Down + * scripts when a tunnel's state changes + */ + public GoBackend(final Context context, final TunnelActionHandler tunnelActionHandler) { SharedLibraryLoader.loadSharedLibrary(context, "wg-go"); this.context = context; + this.tunnelActionHandler = tunnelActionHandler; } /** @@ -279,7 +292,9 @@ private void setStateInternal(final Tunnel tunnel, @Nullable final Config config if (tun == null) throw new BackendException(Reason.TUN_CREATION_ERROR); Log.d(TAG, "Go backend v" + wgVersion()); + tunnelActionHandler.runPreUp(config.getInterface().getPreUp()); currentTunnelHandle = wgTurnOn(tunnel.getName(), tun.detachFd(), goConfig); + tunnelActionHandler.runPostUp(config.getInterface().getPostUp()); } if (currentTunnelHandle < 0) throw new BackendException(Reason.GO_ACTIVATION_ERROR_CODE, currentTunnelHandle); @@ -295,7 +310,11 @@ private void setStateInternal(final Tunnel tunnel, @Nullable final Config config return; } + if (currentConfig != null) + tunnelActionHandler.runPreDown(currentConfig.getInterface().getPreDown()); wgTurnOff(currentTunnelHandle); + if (currentConfig != null) + tunnelActionHandler.runPostDown(currentConfig.getInterface().getPostDown()); currentTunnel = null; currentTunnelHandle = -1; currentConfig = null; diff --git a/tunnel/src/main/java/com/wireguard/android/backend/NoopTunnelActionHandler.java b/tunnel/src/main/java/com/wireguard/android/backend/NoopTunnelActionHandler.java new file mode 100644 index 000000000..2d11878f4 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/NoopTunnelActionHandler.java @@ -0,0 +1,34 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import java.util.Collection; + +/** + * A {@link TunnelActionHandler} implementation that does not execute any scripts. + */ +public final class NoopTunnelActionHandler implements TunnelActionHandler { + + @Override + public void runPreUp(final Collection scripts) { + + } + + @Override + public void runPostUp(final Collection scripts) { + + } + + @Override + public void runPreDown(final Collection scripts) { + + } + + @Override + public void runPostDown(final Collection scripts) { + + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/RootTunnelActionHandler.java b/tunnel/src/main/java/com/wireguard/android/backend/RootTunnelActionHandler.java new file mode 100644 index 000000000..1ed548588 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/RootTunnelActionHandler.java @@ -0,0 +1,76 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.util.Log; + +import com.wireguard.android.util.RootShell; +import com.wireguard.android.util.RootShell.RootShellException; +import com.wireguard.util.NonNullForAll; + +import java.io.IOException; +import java.util.Collection; + +/** + * A {@link TunnelActionHandler} implementation that executes scripts using a root shell. + * Scripts are executed sequentially. If there is an error executing a script for a given step + * the remaining scripts in that step are skipped. + */ +@NonNullForAll +public final class RootTunnelActionHandler implements TunnelActionHandler { + + private static final String TAG = "WireGuard/TunnelAction"; + private final RootShell rootShell; + + public RootTunnelActionHandler(final RootShell rootShell) { + this.rootShell = rootShell; + } + + @Override + public void runPreDown(final Collection scripts) { + if (scripts.isEmpty()) return; + Log.d(TAG, "Running PreDown scripts"); + runTunnelScripts(scripts); + } + + @Override + public void runPostDown(final Collection scripts) { + if (scripts.isEmpty()) return; + Log.d(TAG, "Running PostDown scripts"); + runTunnelScripts(scripts); + } + + @Override + public void runPreUp(final Collection scripts) { + if (scripts.isEmpty()) return; + Log.d(TAG, "Running PreUp scripts"); + runTunnelScripts(scripts); + } + + @Override + public void runPostUp(final Collection scripts) { + if (scripts.isEmpty()) return; + Log.d(TAG, "Running PostUp scripts"); + runTunnelScripts(scripts); + } + + private void runTunnelScripts(final Iterable scripts) { + for (final String script : scripts) { + if (script.contains("%i")) { + Log.e(TAG, "'%i' syntax is not supported with the GoBackend. Aborting"); + return; + } + + try { + rootShell.run(null, script); + } catch (final IOException | RootShellException e) { + Log.e(TAG, "Failed to execute script.", e); + return; + } + } + + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/TunnelActionHandler.java b/tunnel/src/main/java/com/wireguard/android/backend/TunnelActionHandler.java new file mode 100644 index 000000000..7463d9622 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/TunnelActionHandler.java @@ -0,0 +1,42 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import java.util.Collection; + +/** + * Handles executing Pre/Post Up/Down scripts when the state of the WireGuard tunnel changes + */ +public interface TunnelActionHandler { + + /** + * Execute scripts before bringing up the tunnel + * + * @param scripts Collection of scripts to execute + */ + void runPreUp(Collection scripts); + + /** + * Execute scripts after bringing up the tunnel + * + * @param scripts Collection of scripts to execute + */ + void runPostUp(Collection scripts); + + /** + * Execute scripts before bringing down the tunnel + * + * @param scripts Collection of scripts to execute + */ + void runPreDown(Collection scripts); + + /** + * Execute scripts after bringing down the tunnel + * + * @param scripts Collection of scripts to execute + */ + void runPostDown(Collection scripts); +} diff --git a/tunnel/src/main/java/com/wireguard/config/Interface.java b/tunnel/src/main/java/com/wireguard/config/Interface.java index 01bb3699b..875f2457e 100644 --- a/tunnel/src/main/java/com/wireguard/config/Interface.java +++ b/tunnel/src/main/java/com/wireguard/config/Interface.java @@ -14,6 +14,7 @@ import com.wireguard.util.NonNullForAll; import java.net.InetAddress; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; @@ -45,6 +46,10 @@ public final class Interface { private final KeyPair keyPair; private final Optional listenPort; private final Optional mtu; + private final List preUp; + private final List postUp; + private final List preDown; + private final List postDown; private Interface(final Builder builder) { // Defensively copy to ensure immutability even if the Builder is reused. @@ -55,6 +60,10 @@ private Interface(final Builder builder) { keyPair = Objects.requireNonNull(builder.keyPair, "Interfaces must have a private key"); listenPort = builder.listenPort; mtu = builder.mtu; + preUp = Collections.unmodifiableList(new ArrayList<>(builder.preUp)); + postUp = Collections.unmodifiableList(new ArrayList<>(builder.postUp)); + preDown = Collections.unmodifiableList(new ArrayList<>(builder.preDown)); + postDown = Collections.unmodifiableList(new ArrayList<>(builder.postDown)); } /** @@ -93,6 +102,18 @@ public static Interface parse(final Iterable lines) case "privatekey": builder.parsePrivateKey(attribute.getValue()); break; + case "preup": + builder.parsePreUp(attribute.getValue()); + break; + case "postup": + builder.parsePostUp(attribute.getValue()); + break; + case "predown": + builder.parsePreDown(attribute.getValue()); + break; + case "postdown": + builder.parsePostDown(attribute.getValue()); + break; default: throw new BadConfigException(Section.INTERFACE, Location.TOP_LEVEL, Reason.UNKNOWN_ATTRIBUTE, attribute.getKey()); @@ -112,7 +133,12 @@ public boolean equals(final Object obj) { && includedApplications.equals(other.includedApplications) && keyPair.equals(other.keyPair) && listenPort.equals(other.listenPort) - && mtu.equals(other.mtu); + && mtu.equals(other.mtu) + && preUp.equals(other.preUp) + && postUp.equals(other.postUp) + && preDown.equals(other.preDown) + && postDown.equals(other.postDown); + } /** @@ -182,6 +208,22 @@ public Optional getMtu() { return mtu; } + public List getPreUp() { + return preUp; + } + + public List getPostUp() { + return postUp; + } + + public List getPreDown() { + return preDown; + } + + public List getPostDown() { + return postDown; + } + @Override public int hashCode() { int hash = 1; @@ -192,6 +234,10 @@ public int hashCode() { hash = 31 * hash + keyPair.hashCode(); hash = 31 * hash + listenPort.hashCode(); hash = 31 * hash + mtu.hashCode(); + hash = 31 * hash + preUp.hashCode(); + hash = 31 * hash + postUp.hashCode(); + hash = 31 * hash + preDown.hashCode(); + hash = 31 * hash + postDown.hashCode(); return hash; } @@ -231,6 +277,14 @@ public String toWgQuickString() { listenPort.ifPresent(lp -> sb.append("ListenPort = ").append(lp).append('\n')); mtu.ifPresent(m -> sb.append("MTU = ").append(m).append('\n')); sb.append("PrivateKey = ").append(keyPair.getPrivateKey().toBase64()).append('\n'); + for (final String script : preUp) + sb.append("PreUp = ").append(script).append('\n'); + for (final String script : postUp) + sb.append("PostUp = ").append(script).append('\n'); + for (final String script : preDown) + sb.append("PreDown = ").append(script).append('\n'); + for (final String script : postDown) + sb.append("PostDown = ").append(script).append('\n'); return sb.toString(); } @@ -263,6 +317,14 @@ public static final class Builder { private Optional listenPort = Optional.empty(); // Defaults to not present. private Optional mtu = Optional.empty(); + // Defaults to empty list + private List preUp = new ArrayList<>(); + // Defaults to empty list + private List postUp = new ArrayList<>(); + // Defaults to empty list + private List preDown = new ArrayList<>(); + // Defaults to empty list + private List postDown = new ArrayList<>(); public Builder addAddress(final InetNetwork address) { addresses.add(address); @@ -366,6 +428,26 @@ public Builder parsePrivateKey(final String privateKey) throws BadConfigExceptio } } + public Builder parsePreUp(final String script) { + preUp.add(script); + return this; + } + + public Builder parsePostUp(final String script) { + postUp.add(script); + return this; + } + + public Builder parsePreDown(final String script) { + preDown.add(script); + return this; + } + + public Builder parsePostDown(final String script) { + postDown.add(script); + return this; + } + public Builder setKeyPair(final KeyPair keyPair) { this.keyPair = keyPair; return this; diff --git a/tunnel/src/test/java/com/wireguard/config/ConfigTest.java b/tunnel/src/test/java/com/wireguard/config/ConfigTest.java index 693a37ea7..b80093cb5 100644 --- a/tunnel/src/test/java/com/wireguard/config/ConfigTest.java +++ b/tunnel/src/test/java/com/wireguard/config/ConfigTest.java @@ -11,6 +11,7 @@ import java.io.InputStream; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.Objects; @@ -45,5 +46,9 @@ public void valid_config_parses_correctly() throws IOException, ParseException { assertEquals("Test config has exactly one peer", 1, config.getPeers().size()); assertEquals("Test config's allowed IPs are 0.0.0.0/0 and ::0/0", config.getPeers().get(0).getAllowedIps(), expectedAllowedIps); assertEquals("Test config has one DNS server", 1, config.getInterface().getDnsServers().size()); + assertEquals("Test config loads multiple pre up scripts", Arrays.asList("echo \"pre up 1\"", "echo \"pre up 2\""), config.getInterface().getPreUp()); + assertEquals("Test config loads single post up script", Collections.singletonList("echo \"post up 1\""), config.getInterface().getPostUp()); + assertEquals("Test config loads single pre down script", Collections.singletonList("echo \"pre down 1\""), config.getInterface().getPreDown()); + assertEquals("Test config loads single post down scripts", Collections.singletonList("echo \"post down 1\""), config.getInterface().getPostDown()); } } diff --git a/tunnel/src/test/resources/working.conf b/tunnel/src/test/resources/working.conf index 3f9665c39..290c05061 100644 --- a/tunnel/src/test/resources/working.conf +++ b/tunnel/src/test/resources/working.conf @@ -2,6 +2,12 @@ Address = 192.0.2.2/32,2001:db8:ffff:ffff:ffff:ffff:ffff:ffff/128 DNS = 192.0.2.0 PrivateKey = TFlmmEUC7V7VtiDYLKsbP5rySTKLIZq1yn8lMqK83wo= +PreUp = echo "pre up 1" +PreUp = echo "pre up 2" +PostUp = echo "post up 1" +PreDown = echo "pre down 1" +PostDown = echo "post down 1" + [Peer] AllowedIPs = 0.0.0.0/0, ::0/0 Endpoint = 192.0.2.1:51820 diff --git a/ui/src/main/java/com/wireguard/android/Application.kt b/ui/src/main/java/com/wireguard/android/Application.kt index fe98d0d25..9738cae2c 100644 --- a/ui/src/main/java/com/wireguard/android/Application.kt +++ b/ui/src/main/java/com/wireguard/android/Application.kt @@ -17,6 +17,9 @@ import androidx.datastore.preferences.Preferences import androidx.datastore.preferences.createDataStore import com.wireguard.android.backend.Backend import com.wireguard.android.backend.GoBackend +import com.wireguard.android.backend.NoopTunnelActionHandler +import com.wireguard.android.backend.RootTunnelActionHandler +import com.wireguard.android.backend.TunnelActionHandler import com.wireguard.android.backend.WgQuickBackend import com.wireguard.android.configStore.FileConfigStore import com.wireguard.android.model.TunnelManager @@ -76,8 +79,10 @@ class Application : android.app.Application() { } if (!UserKnobs.disableKernelModule.first() && ModuleLoader.isModuleLoaded()) { try { - if (!didStartRootShell) + if (!didStartRootShell) { rootShell.start() + didStartRootShell = true + } val wgQuickBackend = WgQuickBackend(applicationContext, rootShell, toolsInstaller) wgQuickBackend.setMultipleTunnels(UserKnobs.multipleTunnels.first()) backend = wgQuickBackend @@ -88,7 +93,18 @@ class Application : android.app.Application() { } } if (backend == null) { - backend = GoBackend(applicationContext) + var tunnelActionHandler: TunnelActionHandler + try { + if (!didStartRootShell) + rootShell.start() + tunnelActionHandler = RootTunnelActionHandler(rootShell) + Log.d(TAG, "Using root action handler") + } catch (ignored: Exception) { + tunnelActionHandler = NoopTunnelActionHandler() + Log.d(TAG, "Using NOOP action handler") + } + + backend = GoBackend(applicationContext, tunnelActionHandler); GoBackend.setAlwaysOnCallback { get().applicationScope.launch { get().tunnelManager.restoreState(true) } } } return backend