diff --git a/handlers/shawty.go b/handlers/shawty.go index afd17e2..1f15dca 100644 --- a/handlers/shawty.go +++ b/handlers/shawty.go @@ -21,18 +21,17 @@ func Shawty(w http.ResponseWriter, r *http.Request) { inputUrl := r.FormValue("url") Logger.Info("Shorten the URL", "url", inputUrl, "user-agent", r.UserAgent()) - // Validate the URL - err := validate.ValidateUrl(inputUrl) + parsedUrl, err := validate.ValidateUrl(inputUrl) if err != nil { - Logger.Warn("failed to validate URL", "url", inputUrl, "user-agent", r.UserAgent(), "error", err) + Logger.Warn("failed to validate URL", "url", parsedUrl, "user-agent", r.UserAgent(), "error", err) errorTempl := template.Must(template.ParseFiles("./partial-html/short-link-error.html")) asserts.NoErr(errorTempl.Execute(w, err), "Failed to execute template short-link-error.html") return } hasher := sha256.New() - hasher.Write([]byte(inputUrl)) + hasher.Write([]byte(parsedUrl)) checksum := hasher.Sum(nil) db := utils.ConnectDB() @@ -47,22 +46,22 @@ func Shawty(w http.ResponseWriter, r *http.Request) { // Check if the url exists in the database code, err := queries.GetCode(r.Context(), hashUrl) if err != nil { - Logger.Info("URL doesn't exists in the database", "url", inputUrl, "user-agent", r.UserAgent()) + Logger.Info("URL doesn't exists in the database", "url", parsedUrl, "user-agent", r.UserAgent()) // Check if err doesn't equal to `sql.ErrNoRows` // And if true then log the error and return if err != sql.ErrNoRows { - Logger.Error("failed to query the code for the URL", "error", err, "code", hashUrl, "input-url", inputUrl, "user-agent", r.UserAgent()) + Logger.Error("failed to query the code for the URL", "error", err, "code", hashUrl, "input-url", parsedUrl, "user-agent", r.UserAgent()) utils.ServerErrTempl(w, "An error occur when querying the database") return } // Insert the URL in the database if doesn't exists _, err = queries.CreateShortLink(r.Context(), database.CreateShortLinkParams{ - OriginalUrl: inputUrl, + OriginalUrl: parsedUrl, Code: hashUrl, }) if err != nil { - Logger.Error("failed to query to create short link", "original_url", inputUrl, "code", hashUrl, "error", err) + Logger.Error("failed to query to create short link", "original_url", parsedUrl, "code", hashUrl, "error", err) utils.ServerErrTempl(w, "An error occur when saving the URL to the database") return } @@ -75,7 +74,7 @@ func Shawty(w http.ResponseWriter, r *http.Request) { return } - Logger.Info("URL exists in the database", "url", inputUrl, "code", hashUrl, "user-agent", r.UserAgent()) + Logger.Info("URL exists in the database", "url", parsedUrl, "code", hashUrl, "user-agent", r.UserAgent()) w.WriteHeader(http.StatusCreated) data := ShortLink{ ShortUrl: code, diff --git a/validate/validate_url.go b/validate/validate_url.go index 93c0482..a79eeea 100644 --- a/validate/validate_url.go +++ b/validate/validate_url.go @@ -53,38 +53,46 @@ func (*DomainTooLong) Error() string { func (link *InvalidUrlPath) Error() string { return fmt.Sprintf("URL path contains invalid characters: %s", link.path) } -func ValidateUrl(link string) error { + +func ValidateUrl(link string) (string, error) { parsedUrl, err := url.Parse(link) if err != nil { - return err + return link, err } - if parsedUrl.Scheme != "" && parsedUrl.Scheme != "https" { - return &InvalidUrlSchema{schema: parsedUrl.Scheme} + // Check if the scheme is empty, if so default to https + if parsedUrl.Scheme == "" { + link = "https://" + link + parsedUrl, err = url.Parse(link) + if err != nil { + return link, err + } } - //TODO: add https to the start if the url schem is empty + if parsedUrl.Scheme != "https" { + return link, &InvalidUrlSchema{schema: parsedUrl.Scheme} + } // Check URL length if len(link) > 1000 { - return &UrlTooLong{url: uint(len(link))} + return link, &UrlTooLong{url: uint(len(link))} } domain := parsedUrl.Hostname() path := parsedUrl.Path if err := validateDomain(domain); err != nil { - return err + return link, err } // Check for ASCII path characters if path != "" { if !utils.IsASCII(path) { - return &InvalidUrlPath{path: path} + return link, &InvalidUrlPath{path: path} } } - return nil + return link, nil } func validateDomain(domain string) error { @@ -98,20 +106,17 @@ func validateDomain(domain string) error { // Check for allowed domain characters for _, c := range domain { if !(utils.IsValidChar(c) || c == '-' || c == '.') { - fmt.Println(domain) return &InvalidDomainFormat{} } } if strings.Contains(domain, " ") { - fmt.Println("h3ell") return &InvalidDomainFormat{} } // Check for consecutive dashes re := regexp.MustCompile(`-{2,}`) if re.MatchString(domain) { - fmt.Println("h2cell") return &InvalidDomainFormat{} } @@ -138,7 +143,6 @@ func isValidDomainPart(part string) error { // Check for leading or trailing dashes & empty parts if strings.HasPrefix(part, "-") || strings.HasSuffix(part, "-") || part == "" { - fmt.Println("h2cellwq") return &InvalidDomainFormat{} }