Skip to content

Commit

Permalink
x/ref/runtime/internal: use aws imds v1 first and then v2 (#174)
Browse files Browse the repository at this point in the history
AWS' instance meta data service cannot be accessed from a docker container running in bridge mode. This PR changes the behaviour to use the v1 service first and only if that fails, to use the v2 service. This fallback is required since it's possible to configure aws instances to allow v2 only.
  • Loading branch information
cosnicolaou authored Nov 20, 2020
1 parent 047b918 commit cb85103
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 34 deletions.
5 changes: 4 additions & 1 deletion x/ref/runtime/internal/cloudvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type asyncChooser struct {
func (ac *asyncChooser) ChooseAddresses(protocol string, candidates []net.Addr) ([]net.Addr, error) {
select {
case <-ac.ch:
if cvmErr != nil {
return nil, cvmErr
}
return cvm.ChooseAddresses(protocol, candidates)
case <-ac.ctx.Done():
return nil, ac.ctx.Err()
Expand Down Expand Up @@ -115,7 +118,7 @@ func newCloudVM(ctx context.Context, logger logging.Logger, fl *flags.Virtualize

switch fl.VirtualizationProvider.Get().(flags.VirtualizationProvider) {
case flags.AWS:
if !cloudvm.OnAWS(ctx, time.Second) {
if !cloudvm.OnAWS(ctx, cvm.logger, time.Second) {
if fl.DissallowNativeFallback {
return nil, fmt.Errorf("this process is not running on AWS even though its command line says it is")
}
Expand Down
64 changes: 44 additions & 20 deletions x/ref/runtime/internal/cloudvm/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"sync"
"time"

"v.io/v23/logging"
"v.io/x/ref/lib/stats"
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
)
Expand Down Expand Up @@ -56,65 +57,88 @@ const (
)

var (
onceAWS sync.Once
onAWS bool
onceAWS sync.Once
onAWS bool
onIMDSv2 bool
)

// OnAWS returns true if this process is running on Amazon Web Services.
// If true, the the stats variables AWSAccountIDStatName and GCPRegionStatName
// are set.
func OnAWS(ctx context.Context, timeout time.Duration) bool {
func OnAWS(ctx context.Context, logger logging.Logger, timeout time.Duration) bool {
onceAWS.Do(func() {
onAWS = awsInit(ctx, timeout)
onAWS, onIMDSv2 = awsInit(ctx, logger, timeout)
logger.VI(1).Infof("OnAWS: onAWS: %v, onIMDSv2: %v", onAWS, onIMDSv2)
})
return onAWS
}

// AWSPublicAddrs returns the current public IP of this AWS instance.
// Must be called after OnAWS.
func AWSPublicAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) {
return awsGetAddr(ctx, awsExternalURL(), timeout)
return awsGetAddr(ctx, onIMDSv2, awsExternalURL(), timeout)
}

// AWSPrivateAddrs returns the current private Addrs of this AWS instance.
// Must be called after OnAWS.
func AWSPrivateAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) {
return awsGetAddr(ctx, awsInternalURL(), timeout)
return awsGetAddr(ctx, onIMDSv2, awsInternalURL(), timeout)
}

func awsGet(ctx context.Context, url string, timeout time.Duration) ([]byte, error) {
func awsGet(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]byte, error) {
client := &http.Client{Timeout: timeout}
token, err := awsSetIMDSv2Token(ctx, awsTokenURL(), timeout)
if err != nil {
return nil, err
var token string
var err error
if imdsv2 {
token, err = awsSetIMDSv2Token(ctx, awsTokenURL(), timeout)
if err != nil {
return nil, err
}
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
req.Header.Add("X-aws-ec2-metadata-token", token)
if err != nil {
return nil, err
}
if len(token) > 0 {
req.Header.Add("X-aws-ec2-metadata-token", token)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, err
return nil, fmt.Errorf("HTTP Error: %v %v", url, resp.StatusCode)
}
if server := resp.Header["Server"]; len(server) != 1 || server[0] != "EC2ws" {
return nil, fmt.Errorf("wrong headers")
}
return ioutil.ReadAll(resp.Body)
}

// awsInit returns true if it can access AWS project metadata. It also
// awsInit returns true if it can access AWS project metadata and the version
// of the metadata service it was able to access. It also
// creates two stats variables with the account ID and zone.
func awsInit(ctx context.Context, timeout time.Duration) bool {
body, err := awsGet(ctx, awsIdentityDocURL(), timeout)
func awsInit(ctx context.Context, logger logging.Logger, timeout time.Duration) (bool, bool) {
v2 := false
// Try the v1 service first since it should always work unless v2
// is specifically configured (and hence v1 is disabled), in which
// case the expectation is that it fails fast with a 4xx HTTP error.
body, err := awsGet(ctx, false, awsIdentityDocURL(), timeout)
if err != nil {
return false
logger.VI(1).Infof("failed to access v1 metadata service: %v", err)
// can't access v1, try v2.
body, err = awsGet(ctx, true, awsIdentityDocURL(), timeout)
if err != nil {
logger.VI(1).Infof("failed to access v2 metadata service: %v", err)
return false, false
}
v2 = true
}
doc := map[string]interface{}{}
if err := json.Unmarshal(body, &doc); err != nil {
return false
logger.VI(1).Infof("failed to unmarshal metadata service response: %s: %v", body, err)
return false, false
}
found := 0
for _, v := range []struct {
Expand All @@ -130,11 +154,11 @@ func awsInit(ctx context.Context, timeout time.Duration) bool {
}
}
}
return found == 2
return found == 2, v2
}

func awsGetAddr(ctx context.Context, url string, timeout time.Duration) ([]net.Addr, error) {
body, err := awsGet(ctx, url, timeout)
func awsGetAddr(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]net.Addr, error) {
body, err := awsGet(ctx, imdsv2, url, timeout)
if err != nil {
return nil, err
}
Expand Down
19 changes: 14 additions & 5 deletions x/ref/runtime/internal/cloudvm/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,30 @@ import (
"testing"
"time"

"v.io/x/ref/internal/logger"
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
"v.io/x/ref/runtime/internal/cloudvm/cloudvmtest"
)

func startAWSMetadataServer(t *testing.T) (string, func()) {
host, close := cloudvmtest.StartAWSMetadataServer(t)
func startAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) {
host, close := cloudvmtest.StartAWSMetadataServer(t, imdsv2Only)
SetAWSMetadataHost(host)
return host, close
}

func TestAWS(t *testing.T) {
testAWSIDMSVersion(t, false)
testAWSIDMSVersion(t, true)
}

func testAWSIDMSVersion(t *testing.T, imdsv2Only bool) {
ctx := context.Background()
host, stop := startAWSMetadataServer(t)
host, stop := startAWSMetadataServer(t, imdsv2Only)
defer stop()

if got, want := OnAWS(ctx, time.Second), true; got != want {
logger := logger.NewLogger("test")

if got, want := OnAWS(ctx, logger, time.Second), true; got != want {
t.Errorf("got %v, want %v", got, want)
}

Expand All @@ -45,8 +53,9 @@ func TestAWS(t *testing.T) {
if got, want := pub[0].String(), cloudvmtest.WellKnownPublicIP; got != want {
t.Errorf("got %v, want %v", got, want)
}

externalURL := host + cloudpaths.AWSPublicIPPath + "/noip"
noip, err := awsGetAddr(ctx, externalURL, time.Second)
noip, err := awsGetAddr(ctx, imdsv2Only, externalURL, time.Second)
if err != nil {
t.Fatal(err)
}
Expand Down
24 changes: 17 additions & 7 deletions x/ref/runtime/internal/cloudvm/cloudvmtest/aws_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ import (
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
)

func StartAWSMetadataServer(t *testing.T) (string, func()) {
func StartAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
var token string
http.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) {
mux := &http.ServeMux{}
mux.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) {
token = time.Now().String()
w.Header().Add("Server", "EC2ws")
fmt.Fprint(w, token)
Expand All @@ -32,7 +33,13 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) {
return requestToken == token
}

http.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) {
if imdsv2Only {
if len(r.Header.Get("X-aws-ec2-metadata-token")) == 0 {
w.WriteHeader(http.StatusUnauthorized)
return
}
}
if !validSession(r) {
w.WriteHeader(http.StatusForbidden)
return
Expand All @@ -58,19 +65,22 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) {
fmt.Fprintf(w, format, args...)
}

http.HandleFunc(cloudpaths.AWSPrivateIPPath,
mux.HandleFunc(cloudpaths.AWSPrivateIPPath,
func(w http.ResponseWriter, r *http.Request) {
respond(w, r, WellKnownPrivateIP)
})
http.HandleFunc(cloudpaths.AWSPublicIPPath,
mux.HandleFunc(cloudpaths.AWSPublicIPPath,
func(w http.ResponseWriter, r *http.Request) {
respond(w, r, WellKnownPublicIP)
})
http.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip",
mux.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip",
func(w http.ResponseWriter, r *http.Request) {
respond(w, r, "")
})

go http.Serve(l, nil)
srv := http.Server{
Handler: mux,
}
go srv.Serve(l)
return "http://" + l.Addr().String(), func() { l.Close() }
}
2 changes: 1 addition & 1 deletion x/ref/runtime/internal/cloudvm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func hasAddr(addrs []net.Addr, host string) bool {
}

func TestCloudVMProviders(t *testing.T) {
awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t)
awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t, true)
defer awsClose()
cloudvm.SetAWSMetadataHost(awsHost)

Expand Down

0 comments on commit cb85103

Please sign in to comment.