diff --git a/doc/SPEC.md b/doc/SPEC.md index 44d497c..8bb160c 100644 --- a/doc/SPEC.md +++ b/doc/SPEC.md @@ -2,15 +2,14 @@ ## Models -Example Config +Example Config (anonymous mode) ListenHost: "the:yggdrasil:ip:address:of:the:autoygg:server" ListenPort: 8080 GatewayOwner: "You " GatewayDescription: "This is an Yggdrasil gateway operated for fun and profit" RequireRegistration: false - RequireApproval: false - AccessListEnabled: true + AccessListEnabled: false StateDir: "/var/lib/autoygg" MaxClients: 10 LeaseTimeoutSeconds: 14400 @@ -26,23 +25,14 @@ Registration Model gorm.Model YggIP string // Client Yggdrasil IP address PublicKey string // Client Yggdrasil PublicKey - ClientName string // Registration name (optional) - ClientEmail string // Registration email (optional) - ClientPhone string // Registration phone (optional) - Error string - Approved Bool - } - -Lease Model - - type lease struct { - gorm.Model - YggIP string // Client Yggdrasil IP address - PublicKey string // Client Yggdrasil PublicKey - GatewayPublicKey string + ClientName string // Registration name (optional depending on operating mode) + ClientEmail string // Registration email (optional depending on operating mode) + ClientPhone string // Registration phone (optional depending on operating mode) ClientIP string // The tunnel IP address assigned to the client ClientNetMask int // The tunnel netmask ClientGateway string + Error string + Approved Bool LeaseExpires time.Time } @@ -56,43 +46,59 @@ ACL Model ## Operating Modes ### Full Anonymous -* Allows anybody to directly `GET /lease` to use the gateway, subject to ACL config +* Allows anybody to do `POST /register` without sending personal information +* Access granted automatically * RequireRegistration = false +* AccessListEnabled = false ### Registration -* Requires all gateway users to first `POST /register` to store personal information with the gateway before requesting `POST /lease` +* Requires all users to do `POST /register` with personal information (name, phone, e-mail) +* Access granted automatically * RequireRegistration = true -* RequireApproval = false +* AccessListEnabled = false ### Registration & Approval -* Requires all gateway users to `POST /register` and wait for the gateway admin to manually approve the registration before the user is allowed to `POST /lease` +* Requires all users to do `POST /register` with personal information (name, phone, e-mail) +* Must wait for the gateway admin to manually approve the registration to use the gateway by adding an entry to the AccessList * RequireRegistration = true -* RequireApproval = true +* AccessListEnabled = true + +### Full Anonymous & Approval +* Allows anybody to do `POST /register` without sending personal information +* Must wait for the gateway admin to manually approve the registration to use the gateway by adding an entry to the AccessList +* RequireRegistration = false +* AccessListEnabled = true + +## ACL Modes +### ACL disabled +* Allows anyone with a valid registration to use the gateway +* AccessListEnabled = false + +### ACL enabled +* Allows only valid registrations with an ACL entry set to `access: true` to use the gateway +* AccessListEnabled = true + +## ACL Check Routine: +* If ACL entry exists for client IP with Access: false + * Return access error +* If AccessListEnabled=true and ACL entry does not exist for client IP with Access: true + * Return access error ## Endpoints - * `GET /info`: Returns GatewayOwner, Description, RequireRegistration, ACLEnabled + * `GET /info`: Returns GatewayOwner, Description, RequireRegistration, AccessListEnabled * `GET /register`: - * Return access error if ACL check fails - * If RequireRegistration=false: Disabled - * If RequireRegistration=true: Return registration status for user if found or 404 + * If AccessListEnabled=true, apply ACLs, return access error if access denied + * If Registration is found, return status, otherwise return error * `POST /register`: - * Return access error if ACL check fails - * If RequireRegistration=false, Disabled - * If ACLEnabled=true, apply ACLFile based on ACLMode, give access error if conditions not met - * If RequireRegistration=true, Store registration information with Approved=false - * Storing unapproved feels like the safer thing to do in case someone switches RequireApproval on and off + * If AccessListEnabled=true, apply ACLs, return access error if access denied + * If RequireRegistration=true: require ClientName, ClientEmail, ClientPhone to be populated, otherwise return error + * Create Registration, provision client * `POST /renew`: - * Return access error if ACL check fails - * If RequireRegistration=true: Deny unless approved registration found - * If ACLEnabled=true, apply ACLFile based on ACLMode, give access error if conditions not met - * Assign lease, provision lease, and store in leases table + * If AccessListEnabled=true, apply ACLs, return access error if access denied + * If RequireRegistration=true: require ClientName, ClientEmail, ClientPhone to be populated, otherwise return error + * If Registration is found, extend lease expiry date, otherwise return error * `POST /release`: - * Return access error if ACL check fails - * Remove lease from leases, teardown lease, and return success. Return 404 if lease doesn't exist - * ACL Check Routine: - * If acl entry exists for client IP with Access: false - * Return access error - * If AccessListEnabled=true and acl entry does not exist for client IP with Access: true - * Return access error + * If AccessListEnabled=true, apply ACLs, return access error if access denied + * Remove Registration, unprovision client, and return success. Return 404 if lease doesn't exist # Client Operating Model diff --git a/internal/client.go b/internal/client.go index 45f9997..0320974 100644 --- a/internal/client.go +++ b/internal/client.go @@ -76,11 +76,11 @@ Options: fmt.Fprintln(os.Stderr, "") } -func doRequestWorker(fs *flag.FlagSet, verb string, action string, gatewayHost string, gatewayPort string) (response []byte, err error) { +func doRequestWorker(fs *flag.FlagSet, verb string, action string, gatewayHost string, gatewayPort string, i info) (response []byte, err error) { validActions := map[string]bool{ - "register": true, - "renew": true, - "release": true, + "register": true, // register and request a lease + "renew": true, // renew an existing lease + "release": true, // release an existing lease } if !validActions[action] { err = errors.New("Invalid action: " + action) @@ -92,9 +92,12 @@ func doRequestWorker(fs *flag.FlagSet, verb string, action string, gatewayHost s if err != nil { return } - r.ClientName = cViper.GetString("clientname") - r.ClientEmail = cViper.GetString("clientemail") - r.ClientPhone = cViper.GetString("clientphone") + // Only send ClientName, ClientEmail and ClientPhone when registration is required + if i.RequireRegistration { + r.ClientName = cViper.GetString("clientname") + r.ClientEmail = cViper.GetString("clientemail") + r.ClientPhone = cViper.GetString("clientphone") + } r.ClientVersion = version req, err := json.Marshal(r) if err != nil { @@ -285,7 +288,7 @@ func clientLoadConfig(path string) { } } -func clientCreateFlagSet() (fs *flag.FlagSet) { +func clientCreateFlagSet(args []string) (fs *flag.FlagSet) { fs = flag.NewFlagSet("Autoygg", flag.ContinueOnError) fs.Usage = func() { clientUsage(fs) } @@ -310,7 +313,7 @@ func clientCreateFlagSet() (fs *flag.FlagSet) { fs.Bool("help", false, "print usage and exit") fs.Bool("version", false, "print version and exit") - err := fs.Parse(os.Args[1:]) + err := fs.Parse(args) if err != nil { Fatal(err) } @@ -334,7 +337,7 @@ func doInfoRequest(fs *flag.FlagSet, gatewayHost string, gatewayPort string) (i client := http.Client{ Transport: &http.Transport{ Dial: (&net.Dialer{ - Timeout: 200 * time.Millisecond, + Timeout: 500 * time.Millisecond, }).Dial, }, } @@ -356,8 +359,17 @@ func doInfoRequest(fs *flag.FlagSet, gatewayHost string, gatewayPort string) (i func doRequest(fs *flag.FlagSet, action string, gatewayHost string, gatewayPort string, State state) (r registration, newState state, err error) { newState = State + + // Do an info request to know if registration is required + i, err := handleInfoWorker(fs) + if err != nil { + handleError(err, cViper, false) + return + } + + verb := "post" log.Printf("Sending `" + action + "` request to autoygg") - response, err := doRequestWorker(fs, "post", action, gatewayHost, gatewayPort) + response, err := doRequestWorker(fs, verb, action, gatewayHost, gatewayPort, i) if err != nil { handleError(err, cViper, false) return @@ -434,7 +446,7 @@ func saveState(State state) { } func clientValidateConfig() (fs *flag.FlagSet) { - fs = clientCreateFlagSet() + fs = clientCreateFlagSet(os.Args[1:]) if cViper.GetBool("UseConfig") { cViper.SetConfigType("yaml") @@ -487,27 +499,29 @@ func clientValidateConfig() (fs *flag.FlagSet) { return } -func handleInfo(fs *flag.FlagSet) { - i, err := doInfoRequest(fs, cViper.GetString("GatewayHost"), cViper.GetString("GatewayPort")) +func handleInfoWorker(fs *flag.FlagSet) (i info, err error) { + i, err = doInfoRequest(fs, cViper.GetString("GatewayHost"), cViper.GetString("GatewayPort")) if err != nil { if os.IsTimeout(err) { - logAndExit(fmt.Sprintf("Timeout: could not connect to gateway at %s", cViper.GetString("GatewayHost")), 1) - } else { - logAndExit(err.Error(), 1) + err = fmt.Errorf("Timeout: could not connect to gateway at %s", cViper.GetString("GatewayHost")) } } - json, err := json.MarshalIndent(i, "", " ") + return +} + +func handleInfo(fs *flag.FlagSet, i info) { + infoJson, err := json.MarshalIndent(i, "", " ") if err != nil { logAndExit(err.Error(), 1) } - fmt.Printf("%s\n", json) + fmt.Printf("%s\n", infoJson) os.Exit(0) } // ClientMain is the main() function for the client program func ClientMain() { cViper = viper.New() - setupLogWriters(cViper) + setupLogWriters(cViper, true) fs := clientValidateConfig() @@ -537,14 +551,21 @@ func ClientMain() { } if cViper.GetString("Action") == "info" { - handleInfo(fs) - } else if cViper.GetString("Action") == "register" { - State.DesiredState = "connected" - } else if cViper.GetString("Action") == "release" { - State.DesiredState = "disconnected" - State, err = clientTearDownRoutes(State.ClientIP, State.ClientNetMask, State.ClientGateway, State.GatewayPublicKey, State) + i, err := handleInfoWorker(fs) + // if the 'info' request failed bail out here if err != nil { - Fatal(err) + logAndExit(err.Error(), 1) + } + handleInfo(fs, i) + } else { + if cViper.GetString("Action") == "register" || cViper.GetString("Action") == "renew" { + State.DesiredState = "connected" + } else if cViper.GetString("Action") == "release" { + State.DesiredState = "disconnected" + State, err = clientTearDownRoutes(State.ClientIP, State.ClientNetMask, State.ClientGateway, State.GatewayPublicKey, State) + if err != nil { + Fatal(err) + } } } diff --git a/internal/client_test.go b/internal/client_test.go index 71ab5c2..36ce62e 100644 --- a/internal/client_test.go +++ b/internal/client_test.go @@ -29,6 +29,9 @@ type Suite struct{} var YggAddress string var serverConfigDir string +var clientEmail string +var clientName string +var clientPhone string var srv *http.Server var db *gorm.DB @@ -42,9 +45,9 @@ func StopServer(c *check.C) { func StartServer(c *check.C) { sViper = viper.New() - serverLoadConfig(serverConfigDir) + serverLoadConfig(serverConfigDir, []string{}) - db = setupDB("sqlite3", sViper.GetString("StateDir")+"/autoygg.db") + db = setupDB("sqlite3", sViper.GetString("StateDir")+"/autoygg.db", false) r := setupRouter(db) srv = &http.Server{ @@ -115,7 +118,7 @@ func (*Suite) TestConfigLoading(c *check.C) { // Load default config cViper = viper.New() - clientCreateFlagSet() + clientCreateFlagSet([]string{}) // Test defaults c.Assert(cViper.GetBool("daemon"), check.Equals, true) @@ -145,7 +148,11 @@ func (*Suite) TestConfigLoading(c *check.C) { func (*Suite) TestInfo(c *check.C) { // Load default config - fs := clientCreateFlagSet() + cViper = viper.New() + fs := clientCreateFlagSet([]string{}) + + // Populate a custom config file + writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"")) i, err := doInfoRequest(fs, YggAddress, GatewayPort) @@ -157,7 +164,6 @@ func (*Suite) TestInfo(c *check.C) { c.Check(i.GatewayInfoURL, check.Equals, "") c.Check(i.SoftwareVersion, check.Equals, "dev") c.Check(i.RequireRegistration, check.Equals, true) - c.Check(i.RequireApproval, check.Equals, true) c.Check(i.AccessListEnabled, check.Equals, true) } @@ -168,7 +174,7 @@ func CustomClientConfig(c *check.C) (tmpDir string) { } // Populate a custom config file - clientYaml := []byte("---\nGatewayHost: \"" + YggAddress + "\"\nGatewayPort: \"" + GatewayPort + "\"\nStateDir: \"" + tmpDir + "\"\n") + clientYaml := []byte("---\nGatewayHost: \"" + YggAddress + "\"\nGatewayPort: \"" + GatewayPort + "\"\nStateDir: \"" + tmpDir + "\"\nClientEmail: \"" + clientEmail + "\"\nClientName: \"" + clientName + "\"\nClientPhone: \"" + clientPhone + "\"") configFile := filepath.Join(tmpDir, "client.yaml") err = ioutil.WriteFile(configFile, clientYaml, 0644) if err != nil { @@ -179,16 +185,145 @@ func CustomClientConfig(c *check.C) (tmpDir string) { return } +func (*Suite) TestRegistrationAndApproval(c *check.C) { + // Load default config + cViper = viper.New() + fs := clientCreateFlagSet([]string{}) + + clientEmail = "test@example.com" + clientName = "Joe Tester" + clientPhone = "555-1234567" + + tmpDir := CustomClientConfig(c) + defer os.RemoveAll(tmpDir) + + // Load default config + clientLoadConfig(tmpDir) + + var err error + var State state + State, err = loadState(State) + c.Assert(err, check.Equals, nil) + + writeAccessList(c, []byte("---\nAccessList:\n")) + writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"\n")) + + // Try to register when our address is not on the accesslist + r, State, err := doRequest(fs, "register", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "Registration not allowed") + + // Add our address to the accesslist + writeAccessList(c, []byte("---\nAccessList:\n - yggip: "+YggAddress+"\n access: true\n comment: TestRegistration\n")) + + r, State, err = doRequest(fs, "register", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + // We should have sent the personal information, double check that. + c.Assert(r.ClientName, check.Not(check.Equals), "") + c.Assert(r.ClientEmail, check.Not(check.Equals), "") + c.Assert(r.ClientPhone, check.Not(check.Equals), "") + + loadedState, err := loadState(state{}) + c.Assert(err, check.Equals, nil) + c.Assert(loadedState.State, check.Equals, "connected") + + r, State, err = doRequest(fs, "renew", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + r, State, err = doRequest(fs, "release", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + loadedState, err = loadState(state{}) + c.Assert(err, check.Equals, nil) + c.Assert(loadedState.State, check.Equals, "disconnected") + + // Release non-existent lease + r, State, err = doRequest(fs, "release", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "Registration not found") + + // Renew non-existent lease + r, _, err = doRequest(fs, "renew", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "Registration not found") +} + func (*Suite) TestRegistration(c *check.C) { // Load default config - fs := clientCreateFlagSet() + cViper = viper.New() + fs := clientCreateFlagSet([]string{}) + + clientEmail = "test@example.com" + clientName = "Joe Tester" + clientPhone = "555-1234567" tmpDir := CustomClientConfig(c) defer os.RemoveAll(tmpDir) + // Load default config + clientLoadConfig(tmpDir) + + var err error + var State state + State, err = loadState(State) + c.Assert(err, check.Equals, nil) + + writeAccessList(c, []byte("---\nAccessList:\n")) + writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"\nAccessListEnabled: false\n")) + + r, State, err := doRequest(fs, "register", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + // We should have sent the personal information, double check that. + c.Assert(r.ClientName, check.Not(check.Equals), "") + c.Assert(r.ClientEmail, check.Not(check.Equals), "") + c.Assert(r.ClientPhone, check.Not(check.Equals), "") + + loadedState, err := loadState(state{}) + c.Assert(err, check.Equals, nil) + c.Assert(loadedState.State, check.Equals, "connected") + + r, State, err = doRequest(fs, "renew", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + r, State, err = doRequest(fs, "release", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + loadedState, err = loadState(state{}) + c.Assert(err, check.Equals, nil) + c.Assert(loadedState.State, check.Equals, "disconnected") + + // Release non-existent lease + r, State, err = doRequest(fs, "release", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "Registration not found") + + // Renew non-existent lease + r, _, err = doRequest(fs, "renew", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "Registration not found") +} + +func (*Suite) TestAnonymous(c *check.C) { // Load default config cViper = viper.New() - clientCreateFlagSet() + fs := clientCreateFlagSet([]string{}) + + clientEmail = "" + clientName = "" + clientPhone = "" + + tmpDir := CustomClientConfig(c) + defer os.RemoveAll(tmpDir) + + // Load default config clientLoadConfig(tmpDir) var err error @@ -197,7 +332,66 @@ func (*Suite) TestRegistration(c *check.C) { c.Assert(err, check.Equals, nil) writeAccessList(c, []byte("---\nAccessList:\n")) - writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"\n")) + writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"\nRequireRegistration: false\nAccessListEnabled: false\n")) + + r, State, err := doRequest(fs, "register", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + // We should not have sent the personal information, double check that. + c.Assert(r.ClientName, check.Equals, "") + c.Assert(r.ClientEmail, check.Equals, "") + c.Assert(r.ClientPhone, check.Equals, "") + + loadedState, err := loadState(state{}) + c.Assert(err, check.Equals, nil) + c.Assert(loadedState.State, check.Equals, "connected") + + r, State, err = doRequest(fs, "renew", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + r, State, err = doRequest(fs, "release", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "") + + loadedState, err = loadState(state{}) + c.Assert(err, check.Equals, nil) + c.Assert(loadedState.State, check.Equals, "disconnected") + + // Release non-existent lease + r, State, err = doRequest(fs, "release", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "Registration not found") + + // Renew non-existent lease + r, _, err = doRequest(fs, "renew", YggAddress, GatewayPort, State) + c.Assert(err, check.Equals, nil) + c.Assert(r.Error, check.Equals, "Registration not found") +} + +func (*Suite) TestAnonymousAndApproval(c *check.C) { + // Load default config + cViper = viper.New() + fs := clientCreateFlagSet([]string{}) + + clientEmail = "" + clientName = "" + clientPhone = "" + + tmpDir := CustomClientConfig(c) + defer os.RemoveAll(tmpDir) + + // Load default config + clientLoadConfig(tmpDir) + + var err error + var State state + State, err = loadState(State) + c.Assert(err, check.Equals, nil) + + writeAccessList(c, []byte("---\nAccessList:\n")) + writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"\nRequireRegistration: false\n")) // Try to register when our address is not on the accesslist r, State, err := doRequest(fs, "register", YggAddress, GatewayPort, State) @@ -211,6 +405,11 @@ func (*Suite) TestRegistration(c *check.C) { c.Assert(err, check.Equals, nil) c.Assert(r.Error, check.Equals, "") + // We should not have sent the personal information, double check that. + c.Assert(r.ClientName, check.Equals, "") + c.Assert(r.ClientEmail, check.Equals, "") + c.Assert(r.ClientPhone, check.Equals, "") + loadedState, err := loadState(state{}) c.Assert(err, check.Equals, nil) c.Assert(loadedState.State, check.Equals, "connected") @@ -240,7 +439,7 @@ func (*Suite) TestRegistration(c *check.C) { func (*Suite) TestLeaseExpiration(c *check.C) { // Load default config - fs := clientCreateFlagSet() + fs := clientCreateFlagSet([]string{}) tmpDir := CustomClientConfig(c) defer os.RemoveAll(tmpDir) @@ -253,7 +452,7 @@ func (*Suite) TestLeaseExpiration(c *check.C) { c.Assert(err, check.Equals, nil) // Set LeaseTimeoutSeconds to zero seconds - writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"\nLeaseTimeoutSeconds: 0\n")) + writeServerConfig(c, []byte("---\nListenHost: \""+YggAddress+"\"\nListenPort: "+GatewayPort+"\nStateDir: \""+serverConfigDir+"\"\nLeaseTimeoutSeconds: 0\nRequireRegistration: false\n")) writeAccessList(c, []byte("---\nAccessList:\n - yggip: "+YggAddress+"\n access: true\n comment: TestRegistration\n")) r, State, err := doRequest(fs, "register", YggAddress, GatewayPort, State) diff --git a/internal/server.go b/internal/server.go index cab6f0b..0bbfb76 100644 --- a/internal/server.go +++ b/internal/server.go @@ -64,36 +64,26 @@ func enablePrometheusEndpoint() (p *ginprometheus.Prometheus) { return } -func registrationAllowed(address string) bool { - if !sViper.GetBool("RequireRegistration") { - // Registration is disabled. Reject. - debug("Registration is not required, rejecting request from %s\n", address) - return false - } - +func checkACL(address string) bool { if sViper.GetBool("AccessListEnabled") { if _, found := accesslist[address]; found && accesslist[address].Access { // The address is on the accesslist. Accept. - debug("This address is accesslisted, accepted request from %s\n", address) + log.Printf("This address is accesslisted, accepted request from %s\n", address) return true + } else { + log.Printf("AccessList enabled and this address is not on the access list, rejected request from %s\n", address) + return false } - } else { - // The accesslist is disabled and registration is required. Accept. - debug("AccessList disabled and registration is required, accepted request from %s\n", address) - return true } - debug("AccessList enabled and registration is required, address not on accesslist, rejected request from %s\n", address) - return false + // The accesslist is disabled. Accept. + log.Println("AccessList disabled, accepted request from", address) + return true } func registerHandler(db *gorm.DB, c *gin.Context) { var existingRegistration registration statusCode := http.StatusOK - if !validateRegistration(c) { - return - } - if result := db.Where("ygg_ip = ?", c.ClientIP()).First(&existingRegistration); result.Error != nil { // IsRecordNotFound is normal if we haven't seen this public key before if gorm.IsRecordNotFoundError(result.Error) { @@ -105,6 +95,11 @@ func registerHandler(db *gorm.DB, c *gin.Context) { return } } + + if !validateRegistration(c, existingRegistration) { + return + } + if existingRegistration.State == "pending" { statusCode = http.StatusAccepted } else if existingRegistration.State == "open" { @@ -164,14 +159,25 @@ func bindRegistration(c *gin.Context) (r registration, err error) { return } -func validateRegistration(c *gin.Context) bool { +func validateRegistration(c *gin.Context, r registration) bool { // Is this address allowed to register? - if !registrationAllowed(c.ClientIP()) { + if !checkACL(c.ClientIP()) { c.JSON(http.StatusForbidden, registration{Error: "Registration not allowed"}) c.Abort() - incErrorCount("registration_denied") + incErrorCount("registration_denied_acl") return false } + // When RequireRegistration is set, make sure there are values in Name/Phone/E-mail + if sViper.GetBool("RequireRegistration") { + if r.ClientEmail == "" || + r.ClientPhone == "" || + r.ClientName == "" { + c.JSON(http.StatusForbidden, registration{Error: "Client name, e-mail and phone must be supplied"}) + c.Abort() + incErrorCount("registration_denied_missing_client_info") + return false + } + } return true } @@ -180,10 +186,6 @@ func authorized(db *gorm.DB, c *gin.Context) (r registration, existingRegistrati if err != nil { return } - if !validateRegistration(c) { - err = errors.New("Registration not allowed") - return - } if result := db.Where("ygg_ip = ?", c.ClientIP()).First(&existingRegistration); result.Error != nil { if gorm.IsRecordNotFoundError(result.Error) { @@ -267,7 +269,7 @@ func newRegistrationHandler(db *gorm.DB, c *gin.Context) { return } - if !validateRegistration(c) { + if !validateRegistration(c, newRegistration) { return } @@ -300,7 +302,7 @@ func newRegistrationHandler(db *gorm.DB, c *gin.Context) { // new lease newRegistration.LeaseExpires = time.Now().UTC().Add(time.Duration(sViper.GetInt("LeaseTimeoutSeconds")) * time.Second) - log.Printf("new registration: %+v\n", newRegistration) + debug("New registration: %+v\n", newRegistration) mutex.Lock() defer mutex.Unlock() if result := db.Save(&newRegistration); result.Error != nil { @@ -397,7 +399,6 @@ func setupRouter(db *gorm.DB) (r *gin.Engine) { Location: sViper.GetString("GatewayLocation"), GatewayInfoURL: sViper.GetString("GatewayInfoURL"), RequireRegistration: sViper.GetBool("RequireRegistration"), - RequireApproval: sViper.GetBool("RequireApproval"), AccessListEnabled: sViper.GetBool("AccessListEnabled"), SoftwareVersion: version, } @@ -420,13 +421,15 @@ func setupRouter(db *gorm.DB) (r *gin.Engine) { return } -func setupDB(driver string, credentials string) (db *gorm.DB) { +func setupDB(driver string, credentials string, databaseDebug bool) (db *gorm.DB) { db, err := gorm.Open(driver, credentials) if err != nil { fmt.Printf("%s\n", err) Fatal("Couldn't initialize database connection") } - db.LogMode(true) + if databaseDebug { + db.LogMode(true) + } // Migrate the schema db.AutoMigrate(®istration{}) @@ -443,7 +446,6 @@ func serverLoadConfigDefaults() { sViper.SetDefault("GatewayLocation", "Physical location of the gateway") sViper.SetDefault("GatewayInfoURL", "") sViper.SetDefault("RequireRegistration", true) - sViper.SetDefault("RequireApproval", true) sViper.SetDefault("MaxClients", 10) sViper.SetDefault("LeaseTimeoutSeconds", 14400) // Default to 4 hours sViper.SetDefault("GatewayTunnelIP", "10.42.0.1") @@ -454,6 +456,7 @@ func serverLoadConfigDefaults() { sViper.SetDefault("AccessListFile", "accesslist") // Name of the file that contains the accesslist. Omit .yaml extension. sViper.SetDefault("YggdrasilInterface", "tun0") // Name of the yggdrasil tunnel interface sViper.SetDefault("Debug", false) + sViper.SetDefault("DatabaseDebug", false) sViper.SetDefault("Version", false) sViper.SetDefault("GatewayPublicKey", "") // Set up rudimentary firewall rules that will permit @@ -478,7 +481,7 @@ func serverLoadConfigDefaults() { sViper.SetDefault("DelIpRouteTableMeshCommand", "ip ro del default dev %%GatewayWanInterface%% table %%RoutingTableNumber%%") } -func serverLoadConfig(path string) (fs *flag.FlagSet) { +func serverLoadConfig(path string, args []string) (fs *flag.FlagSet) { viperLoadSharedDefaults(sViper) serverLoadConfigDefaults() @@ -514,7 +517,7 @@ func serverLoadConfig(path string) (fs *flag.FlagSet) { fs.Bool("help", false, "print usage and exit") fs.Bool("version", false, "print version and exit") - err = fs.Parse(os.Args[1:]) + err = fs.Parse(args) if err != nil { Fatal(err) } @@ -539,12 +542,11 @@ func serverLoadConfig(path string) (fs *flag.FlagSet) { } if configErr != nil { - Fatal(fmt.Sprintln("Fatal error reading config file:", err.Error())) + Fatal(fmt.Sprintln("Fatal error reading config file:", configErr.Error())) } initializeViperList("AccessList", path, &accesslist) - sViper.WatchConfig() // Automatically reload the main config when it changes sViper.OnConfigChange(func(e fsnotify.Event) { if sViper.GetBool("Debug") { debug = debugLog.Printf @@ -557,6 +559,7 @@ func serverLoadConfig(path string) (fs *flag.FlagSet) { debug(dumpConfiguration(sViper, "server")) debug("+=+=+=+=+=+=+=+=+=+=+=") }) + sViper.WatchConfig() // Automatically reload the main config when it changes return } @@ -585,11 +588,15 @@ func initializeViperList(name string, path string, list *map[string]acl) { *list = loadList(name, localViper) localViper.WatchConfig() // Automatically reload the config files when they change localViper.OnConfigChange(func(e fsnotify.Event) { - log.Println("Config file changed:", e.Name) - debug("Current configuration:") - debug("+=+=+=+=+=+=+=+=+=+=+=") - debug(dumpConfiguration(localViper, "server")) - debug("+=+=+=+=+=+=+=+=+=+=+=") + // It would be nice to dump the config here + log.Println(name+" file changed:", e.Name) + // If the configuration file that is reloaded is an access list, the localViper config object is empty + if name == "Config" { + debug("New configuration:") + debug("+=+=+=+=+=+=+=+=+=+=+=") + debug(dumpConfiguration(localViper, "server")) + debug("+=+=+=+=+=+=+=+=+=+=+=") + } *list = loadList(name, localViper) }) } @@ -601,7 +608,7 @@ func loadList(name string, localViper *viper.Viper) map[string]acl { list := make(map[string]acl) var slice []acl if !sViper.GetBool(name + "Enabled") { - fmt.Printf("%sEnabled is not set", name) + fmt.Printf("%sEnabled is not set\n", name) return list } err := localViper.UnmarshalKey("accesslist", &slice) @@ -892,7 +899,6 @@ func expireLeasesWorker(db *gorm.DB, mutex *sync.Mutex) { return } debug("Found %d leases to expire\n", result.RowsAffected) - log.Printf("Found %d leases to expire\n", result.RowsAffected) // These leases are expired, mark them as such and make sure that Yggdrasil doesn't route them anymore for _, r := range registrations { @@ -922,12 +928,12 @@ func expireLeasesWorker(db *gorm.DB, mutex *sync.Mutex) { // ServerMain is the main() function for the server program func ServerMain() { sViper = viper.New() - setupLogWriters(sViper) + setupLogWriters(sViper, false) // Enable the Prometheus endpoint enablePrometheus = true - fs := serverLoadConfig("") + fs := serverLoadConfig("", os.Args[1:]) // if GatewayPublicKey is not set in the config, calculate it here. // This has the advantage that --help and --version are already handled @@ -963,7 +969,7 @@ func ServerMain() { } } - db := setupDB("sqlite3", sViper.GetString("StateDir")+"/autoygg.db") + db := setupDB("sqlite3", sViper.GetString("StateDir")+"/autoygg.db", sViper.GetBool("DatabaseDebug")) defer db.Close() r := setupRouter(db) diff --git a/internal/shared.go b/internal/shared.go index e8e6e5b..5209e98 100644 --- a/internal/shared.go +++ b/internal/shared.go @@ -33,7 +33,8 @@ var ( ) type logWriter struct { - quiet bool + quiet bool + interactive bool } func command(name string, arg ...string) (cmd *exec.Cmd) { @@ -43,7 +44,9 @@ func command(name string, arg ...string) (cmd *exec.Cmd) { } func (writer logWriter) Write(bytes []byte) (int, error) { - if !writer.quiet { + if !writer.interactive && !writer.quiet { + return fmt.Printf("%s", string(bytes)) + } else if !writer.quiet { // Strip the last character, it's a newline! return fmt.Printf("%-70s", string(bytes[:len(bytes)-1])) } @@ -58,7 +61,6 @@ type info struct { GatewayInfoURL string SoftwareVersion string RequireRegistration bool - RequireApproval bool AccessListEnabled bool } @@ -533,14 +535,15 @@ func handleError(err error, lViper *viper.Viper, terminateOnFail bool) { } } -func setupLogWriters(lViper *viper.Viper) { - // Initialize our own logWriter that right justifies all lines at 70 characters - // and removes the trailing newline from log statements. Used for status lines - // where we want to write something, then execute a command, and follow with - // [ok] or [FAIL] on the same line. +func setupLogWriters(lViper *viper.Viper, interactive bool) { + // Initialize our own logWriter. In 'interactive mode', it right justifies + // all lines at 70 characters and removes the trailing newline from log + // statements. Used for status lines where we want to write something, then + // execute a command, and follow with [ok] or [FAIL] on the same line. log.SetFlags(0) writer := new(logWriter) writer.quiet = lViper.GetBool("Quiet") + writer.interactive = interactive log.SetOutput(writer) }