Skip to content

Commit

Permalink
Use parsed identity file
Browse files Browse the repository at this point in the history
  • Loading branch information
mingan committed Mar 7, 2018
1 parent d2ead24 commit 829e6de
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
57 changes: 32 additions & 25 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type SSHClient struct {
sess *ssh.Session
user string
host string
identityFile string
remoteStdin io.WriteCloser
remoteStdout io.Reader
remoteStderr io.Reader
Expand Down Expand Up @@ -80,34 +81,40 @@ var initAuthMethodOnce sync.Once
var authMethod ssh.AuthMethod

// initAuthMethod initiates SSH authentication method.
func initAuthMethod() {
var signers []ssh.Signer

// If there's a running SSH Agent, try to use its Private keys.
sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil {
agent := agent.NewClient(sock)
signers, _ = agent.Signers()
}

// Try to read user's SSH private keys form the standard paths.
files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*")
for _, file := range files {
if strings.HasSuffix(file, ".pub") {
continue // Skip public keys.
}
data, err := ioutil.ReadFile(file)
if err != nil {
continue
func initAuthMethod(identityFilePath string) func() {
return func() {
var signers []ssh.Signer

// If there's a running SSH Agent, try to use its Private keys.
sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil {
agent := agent.NewClient(sock)
signers, _ = agent.Signers()
}
signer, err := ssh.ParsePrivateKey(data)
if err != nil {
continue

// Try to read user's SSH private keys form the standard paths.
files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*")
// Add nonstandard path
if identityFilePath != "" {
files = append(files, identityFilePath)
}
signers = append(signers, signer)
for _, file := range files {
if strings.HasSuffix(file, ".pub") {
continue // Skip public keys.
}
data, err := ioutil.ReadFile(file)
if err != nil {
continue
}
signer, err := ssh.ParsePrivateKey(data)
if err != nil {
continue
}
signers = append(signers, signer)

}
authMethod = ssh.PublicKeys(signers...)
}
authMethod = ssh.PublicKeys(signers...)
}

// SSHDialFunc can dial an ssh server and return a client
Expand All @@ -127,7 +134,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
return fmt.Errorf("Already connected")
}

initAuthMethodOnce.Do(initAuthMethod)
initAuthMethodOnce.Do(initAuthMethod(c.identityFile))

err := c.parseHost(host)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions sup.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command)

// SSH client.
remote := &SSHClient{
env: env + `export SUP_HOST="` + host + `";`,
user: network.User,
color: Colors[i%len(Colors)],
env: env + `export SUP_HOST="` + host + `";`,
user: network.User,
color: Colors[i%len(Colors)],
identityFile: network.IdentityFile,
}

if bastion != nil {
Expand Down

0 comments on commit 829e6de

Please sign in to comment.