package smtp import ( "context" "crypto/rand" "crypto/tls" "encoding/base64" "encoding/hex" "fmt" "net" "net/smtp" "strings" "time" ) type SMTPMode string const ( SMTPModeSSL SMTPMode = "ssl" SMTPModeTLS SMTPMode = "tls" SMTPModeUnencrypted SMTPMode = "unencrypted" ) type SMTPConfig struct { Host string Port int Username string Password string From string Mode SMTPMode Auth string } type Mailer interface { Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error } type SMTPMailer struct { cfg SMTPConfig } const defaultSMTPOperationTimeout = 30 * time.Second func NewSMTPMailer(cfg SMTPConfig) *SMTPMailer { return &SMTPMailer{cfg: cfg} } func (m *SMTPMailer) Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error { if strings.TrimSpace(m.cfg.Host) == "" || m.cfg.Port == 0 || strings.TrimSpace(m.cfg.From) == "" { return fmt.Errorf("smtp is not configured") } if ctx == nil { ctx = context.Background() } if err := ctx.Err(); err != nil { return err } now := time.Now() addr := fmt.Sprintf("%s:%d", m.cfg.Host, m.cfg.Port) deadline := smtpSendDeadline(ctx, now) if !deadline.After(now) { if err := ctx.Err(); err != nil { return err } return context.DeadlineExceeded } var client *smtp.Client var err error switch m.cfg.Mode { case SMTPModeSSL: client, err = dialSSL(ctx, addr, m.cfg.Host, deadline) case SMTPModeUnencrypted: client, err = dialPlain(ctx, addr, m.cfg.Host, deadline) case SMTPModeTLS: fallthrough default: client, err = dialStartTLS(ctx, addr, m.cfg.Host, deadline) } if err != nil { return err } defer client.Close() done := make(chan struct{}) go func() { select { case <-ctx.Done(): _ = client.Close() case <-done: } }() defer close(done) if m.cfg.Username != "" { if err := authenticate(client, m.cfg); err != nil { if ctx.Err() != nil { return ctx.Err() } return err } } if err := client.Mail(m.cfg.From); err != nil { if ctx.Err() != nil { return ctx.Err() } return fmt.Errorf("smtp mail from: %w", err) } if err := client.Rcpt(to); err != nil { if ctx.Err() != nil { return ctx.Err() } return fmt.Errorf("smtp rcpt to: %w", err) } writer, err := client.Data() if err != nil { if ctx.Err() != nil { return ctx.Err() } return fmt.Errorf("smtp data: %w", err) } message := buildMIMEMessage(m.cfg.From, to, subject, textBody, htmlBody) if _, err := writer.Write([]byte(message)); err != nil { _ = writer.Close() if ctx.Err() != nil { return ctx.Err() } return fmt.Errorf("smtp write body: %w", err) } if err := writer.Close(); err != nil { if ctx.Err() != nil { return ctx.Err() } return fmt.Errorf("smtp close body: %w", err) } if err := client.Quit(); err != nil { if ctx.Err() != nil { return ctx.Err() } return fmt.Errorf("smtp quit: %w", err) } return nil } func smtpSendDeadline(ctx context.Context, now time.Time) time.Time { if ctx != nil { if deadline, ok := ctx.Deadline(); ok { return deadline } } return now.Add(defaultSMTPOperationTimeout) } func authenticate(client *smtp.Client, cfg SMTPConfig) error { authMethod := strings.ToLower(strings.TrimSpace(cfg.Auth)) if authMethod == "" { authMethod = "auto" } _, methodsRaw := client.Extension("AUTH") methods := strings.ToUpper(methodsRaw) tryLogin := strings.Contains(methods, "LOGIN") tryPlain := strings.Contains(methods, "PLAIN") switch authMethod { case "none": return nil case "login": return authWithLogin(client, cfg) case "plain": return authWithPlain(client, cfg) case "auto": if tryLogin { if err := authWithLogin(client, cfg); err == nil { return nil } } if tryPlain { if err := authWithPlain(client, cfg); err == nil { return nil } } if !tryLogin && !tryPlain { // Last fallback if server did not advertise methods. if err := authWithLogin(client, cfg); err == nil { return nil } if err := authWithPlain(client, cfg); err == nil { return nil } } return fmt.Errorf("smtp auth failed for available methods") default: return fmt.Errorf("unsupported smtp auth method: %s", authMethod) } } func authWithPlain(client *smtp.Client, cfg SMTPConfig) error { auth := smtp.PlainAuth("", cfg.Username, cfg.Password, cfg.Host) if err := client.Auth(auth); err != nil { return fmt.Errorf("smtp plain auth: %w", err) } return nil } func authWithLogin(client *smtp.Client, cfg SMTPConfig) error { if err := client.Auth(loginAuth{username: cfg.Username, password: cfg.Password}); err != nil { return fmt.Errorf("smtp login auth: %w", err) } return nil } type loginAuth struct { username string password string } func (a loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) { return "LOGIN", []byte{}, nil } func (a loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { if !more { return nil, nil } challengeRaw := strings.TrimSpace(string(fromServer)) challenge := strings.ToLower(challengeRaw) if decoded, err := base64.StdEncoding.DecodeString(challengeRaw); err == nil { challenge = strings.ToLower(string(decoded)) } switch { case strings.Contains(challenge, "username"): return []byte(a.username), nil case strings.Contains(challenge, "password"): return []byte(a.password), nil } return nil, fmt.Errorf("unexpected smtp login challenge") } func dialSSL(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) { timeout := time.Until(deadline) if timeout <= 0 { if err := ctx.Err(); err != nil { return nil, err } return nil, context.DeadlineExceeded } dialer := &tls.Dialer{ NetDialer: &net.Dialer{Timeout: timeout}, Config: &tls.Config{ServerName: host}, } conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { if ctx.Err() != nil { return nil, ctx.Err() } return nil, fmt.Errorf("smtp ssl dial: %w", err) } _ = conn.SetDeadline(deadline) client, err := smtp.NewClient(conn, host) if err != nil { return nil, fmt.Errorf("smtp new client ssl: %w", err) } return client, nil } func dialPlain(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) { timeout := time.Until(deadline) if timeout <= 0 { if err := ctx.Err(); err != nil { return nil, err } return nil, context.DeadlineExceeded } dialer := &net.Dialer{Timeout: timeout} conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { if ctx.Err() != nil { return nil, ctx.Err() } return nil, fmt.Errorf("smtp plain dial: %w", err) } _ = conn.SetDeadline(deadline) client, err := smtp.NewClient(conn, host) if err != nil { return nil, fmt.Errorf("smtp new client plain: %w", err) } return client, nil } func dialStartTLS(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) { client, err := dialPlain(ctx, addr, host, deadline) if err != nil { return nil, err } if ok, _ := client.Extension("STARTTLS"); !ok { _ = client.Close() return nil, fmt.Errorf("smtp server does not support STARTTLS") } if err := client.StartTLS(&tls.Config{ServerName: host}); err != nil { _ = client.Close() if ctx.Err() != nil { return nil, ctx.Err() } return nil, fmt.Errorf("smtp starttls: %w", err) } return client, nil } func buildMIMEMessage(from, to, subject, textBody, htmlBody string) string { boundary := randomMIMEBoundary() return strings.Join([]string{ fmt.Sprintf("From: %s", from), fmt.Sprintf("To: %s", to), fmt.Sprintf("Subject: %s", subject), "MIME-Version: 1.0", fmt.Sprintf("Content-Type: multipart/alternative; boundary=%s", boundary), "", fmt.Sprintf("--%s", boundary), "Content-Type: text/plain; charset=UTF-8", "", textBody, fmt.Sprintf("--%s", boundary), "Content-Type: text/html; charset=UTF-8", "", htmlBody, fmt.Sprintf("--%s--", boundary), "", }, "\r\n") } func randomMIMEBoundary() string { raw := make([]byte, 12) if _, err := rand.Read(raw); err != nil { return fmt.Sprintf("mime-boundary-%d", time.Now().UnixNano()) } return "mime-boundary-" + hex.EncodeToString(raw) }