Forum | Documentation | Website | Blog

Skip to content
Snippets Groups Projects
ssh_command.go 3.19 KiB
package ssh

import (
	"bytes"
	"context"
	"errors"
	"io"
	"io/ioutil"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"

	"gitlab.com/gitlab-org/gitlab-runner/helpers"
)

// Client runs commands on a remote host over SSH. Connection parameters
// (Host, Port, User, Password, IdentityFile) are read from the embedded
// Config; Connect fills in defaults for empty Host, User and Port.
type Client struct {
	Config

	// Stdout and Stderr receive the remote command's output streams in Exec
	// and Run. ConnectRetries is the number of dial attempts Connect makes;
	// zero selects the default of 3.
	Stdout         io.Writer
	Stderr         io.Writer
	ConnectRetries int

	// client is set by a successful Connect and used by Exec, Run and Cleanup;
	// Exec and Run return a "Not connected" error while it is nil.
	client *ssh.Client
}

// Command describes a remote invocation executed by Client.Run.
type Command struct {
	// Environment: each entry is shell-escaped and sent as an
	// "export <entry>" line on the remote stdin, ahead of Stdin (presumably
	// KEY=VALUE pairs — TODO confirm with callers).
	// Command: the argument vector; each element is shell-escaped and the
	// results are joined with single spaces to form the remote command line.
	// Stdin: written to the remote stdin after the export lines.
	Environment []string
	Command     []string
	Stdin       string
}

// ExitError is returned by Client.Run when waiting on the remote session
// yields an *ssh.ExitError; that error is stored in Inner. Other wait errors
// are returned by Run unwrapped.
type ExitError struct {
	// Inner is the wrapped error; it may be nil.
	Inner error
}

// Error returns the message of the wrapped error, or "error" when Inner is nil.
func (e *ExitError) Error() string {
	if e.Inner == nil {
		return "error"
	}
	return e.Inner.Error()
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can see
// through an ExitError, e.g. to recover the underlying *ssh.ExitError and
// its exit status. It returns nil when Inner is nil.
func (e *ExitError) Unwrap() error {
	return e.Inner
}

// getSSHKey loads the private key stored at identityFile and turns it into
// an ssh.Signer via ssh.ParsePrivateKey. A read failure is returned as-is
// with a nil signer; otherwise the result of ParsePrivateKey is returned
// unchanged.
func (s *Client) getSSHKey(identityFile string) (key ssh.Signer, err error) {
	var pemBytes []byte
	if pemBytes, err = ioutil.ReadFile(identityFile); err != nil {
		return nil, err
	}
	return ssh.ParsePrivateKey(pemBytes)
}

// getSSHAuthMethods builds the authentication methods offered to the server.
// Password authentication with s.Password is always the first entry, even
// when the password is empty. When IdentityFile is set, the key loaded from
// it is appended as public-key authentication; a failure to load that key
// aborts with the error and no methods.
func (s *Client) getSSHAuthMethods() ([]ssh.AuthMethod, error) {
	authMethods := []ssh.AuthMethod{ssh.Password(s.Password)}
	if s.IdentityFile == "" {
		return authMethods, nil
	}

	signer, err := s.getSSHKey(s.IdentityFile)
	if err != nil {
		return nil, err
	}
	return append(authMethods, ssh.PublicKeys(signer)), nil
}

// Connect dials the SSH server described by the embedded Config and stores
// the resulting client for use by Exec, Run and Cleanup.
//
// Empty Host, User and Port fields are filled in on s itself with
// "localhost", "root" and "22". Password authentication is always offered;
// public-key authentication is added when IdentityFile is set, and a key file
// that cannot be read or parsed is returned as an error before any dial.
//
// Up to ConnectRetries dial attempts are made (3 when ConnectRetries is zero
// or negative), pausing sshRetryInterval*time.Second between attempts. If no
// attempt succeeds, the error from the last one is returned.
func (s *Client) Connect() error {
	if s.Host == "" {
		s.Host = "localhost"
	}
	if s.User == "" {
		s.User = "root"
	}
	if s.Port == "" {
		s.Port = "22"
	}

	methods, err := s.getSSHAuthMethods()
	if err != nil {
		return err
	}

	// NOTE(review): no HostKeyCallback is set. Newer golang.org/x/crypto/ssh
	// releases reject a nil callback at dial time — confirm against the
	// vendored version.
	config := &ssh.ClientConfig{
		User: s.User,
		Auth: methods,
	}

	connectRetries := s.ConnectRetries
	if connectRetries <= 0 {
		// Negative values must not skip the loop: that would return nil
		// without ever establishing a connection.
		connectRetries = 3
	}

	address := s.Host + ":" + s.Port
	var lastErr error
	for attempt := 0; attempt < connectRetries; attempt++ {
		if attempt > 0 {
			// Wait only between attempts, never after the final failure.
			time.Sleep(sshRetryInterval * time.Second)
		}

		client, err := ssh.Dial("tcp", address, config)
		if err == nil {
			s.client = client
			return nil
		}
		lastErr = err
	}

	return lastErr
}

// Exec runs the raw command line cmd on a fresh session of the connected
// client, streaming remote output to s.Stdout and s.Stderr, and returns the
// session's Run error. If no client is set it fails with "Not connected"
// without contacting the server. The session is closed before returning.
func (s *Client) Exec(cmd string) error {
	if s.client == nil {
		return errors.New("Not connected")
	}

	sess, err := s.client.NewSession()
	if err != nil {
		return err
	}
	defer sess.Close()

	sess.Stdout, sess.Stderr = s.Stdout, s.Stderr
	return sess.Run(cmd)
}

// fullCommand renders the argument vector as a single command line: each
// element is passed through helpers.ShellEscape and the results are joined
// with single spaces. An empty Command yields "".
//
// TODO: like the original implementation, this quoting is only intended for
// Bourne-compatible shells.
func (c *Command) fullCommand() string {
	escaped := make([]string, len(c.Command))
	for i, arg := range c.Command {
		escaped[i] = helpers.ShellEscape(arg)
	}
	return strings.Join(escaped, " ")
}

// Run starts cmd on a fresh session of the connected client and waits for it
// to finish or for ctx to be cancelled.
//
// The remote stdin receives one "export <escaped entry>" line per element of
// cmd.Environment, followed by cmd.Stdin (presumably cmd.Command is a shell
// that consumes these lines — confirm with callers). Remote output goes to
// s.Stdout and s.Stderr, and the command line comes from cmd.fullCommand().
//
// A *ssh.ExitError reported by the session is returned wrapped in
// *ExitError; any other wait error is returned unchanged. On cancellation
// the session is sent SIGKILL and closed, and Run returns whatever the
// pending wait reports (not ctx.Err()). It fails with "Not connected" if no
// client is set.
func (s *Client) Run(ctx context.Context, cmd Command) error {
	if s.client == nil {
		return errors.New("Not connected")
	}

	session, err := s.client.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	// Same byte stream as before: all export lines, then the caller's stdin.
	var script bytes.Buffer
	for i := range cmd.Environment {
		script.WriteString("export ")
		script.WriteString(helpers.ShellEscape(cmd.Environment[i]))
		script.WriteByte('\n')
	}
	script.WriteString(cmd.Stdin)

	session.Stdin = &script
	session.Stdout, session.Stderr = s.Stdout, s.Stderr

	if err = session.Start(cmd.fullCommand()); err != nil {
		return err
	}

	// Both select branches receive from done, so this goroutine always exits.
	done := make(chan error)
	go func() {
		waitErr := session.Wait()
		if _, isExit := waitErr.(*ssh.ExitError); isExit {
			waitErr = &ExitError{Inner: waitErr}
		}
		done <- waitErr
	}()

	select {
	case waitErr := <-done:
		return waitErr
	case <-ctx.Done():
		session.Signal(ssh.SIGKILL)
		session.Close()
		return <-done
	}
}

// Cleanup closes the connection opened by Connect, if any, and clears it.
// Because the client is cleared, a repeated Cleanup is a no-op, and later
// Exec/Run calls fail with "Not connected" instead of using a closed client.
// Any Close error is discarded because Cleanup has no way to report it.
func (s *Client) Cleanup() {
	if s.client == nil {
		return
	}
	// Best-effort teardown: the method has no error return.
	_ = s.client.Close()
	s.client = nil
}