diff --git a/main.go b/main.go index 447a777..630da57 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,7 @@ var ( keyFile = flag.String("key", "", "path to a private key file. If not provided, ssl-proxy will generate one for you in ~/.ssl-proxy/") domain = flag.String("domain", "", "domain to mint letsencrypt certificates for. Usage of this parameter implies acceptance of the LetsEncrypt terms of service.") redirectHTTP = flag.Bool("redirectHTTP", false, "if true, redirects http requests from port 80 to https at your fromURL") + ipFilter = flag.String("ipfilter", "", "source IP address to filter incoming requests on. If not provided allow all IP") ) const ( @@ -81,7 +82,7 @@ func main() { } // Setup reverse proxy ServeMux - p := reverseproxy.Build(toURL) + p := reverseproxy.Build(toURL, *ipFilter) mux := http.NewServeMux() mux.Handle("/", p) diff --git a/reverseproxy/reverseproxy.go b/reverseproxy/reverseproxy.go index fcd0fc8..5a42cbb 100644 --- a/reverseproxy/reverseproxy.go +++ b/reverseproxy/reverseproxy.go @@ -1,6 +1,7 @@ package reverseproxy import ( + "net" "net/http" "net/http/httputil" "net/url" @@ -8,37 +9,44 @@ import ( ) // Build initializes and returns a new ReverseProxy instance suitable for SSL proxying -func Build(toURL *url.URL) *httputil.ReverseProxy { +func Build(toURL *url.URL, ipFilter string) *httputil.ReverseProxy { localProxy := &httputil.ReverseProxy{} addProxyHeaders := func(req *http.Request) { req.Header.Set(http.CanonicalHeaderKey("X-Forwarded-Proto"), "https") req.Header.Set(http.CanonicalHeaderKey("X-Forwarded-Port"), "443") // TODO: inherit another port if needed } - localProxy.Director = newDirector(toURL, addProxyHeaders) + localProxy.Director = newDirector(toURL, ipFilter, addProxyHeaders) return localProxy } // newDirector creates a base director that should be exactly what http.NewSingleHostReverseProxy() creates, but allows // for the caller to supply and extraDirector function to decorate to request to the downstream server -func newDirector(target *url.URL, extraDirector func(*http.Request)) func(*http.Request) { +func newDirector(target *url.URL, ipFilter string, extraDirector func(*http.Request)) func(*http.Request) { targetQuery := target.RawQuery return func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery - } - if _, ok := req.Header["User-Agent"]; !ok { - // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") - } + remoteIp, _, _ := net.SplitHostPort(req.RemoteAddr) + if ipFilter == "" || remoteIp == ipFilter { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } - if extraDirector != nil { - extraDirector(req) + if extraDirector != nil { + extraDirector(req) + } + } else { + // send to black hole + req.URL.Host = "127.0.0.1:0" + req.URL.Scheme = "http" } } }