346 lines
8.0 KiB
Go
346 lines
8.0 KiB
Go
package smtp
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net"
|
|
"net/smtp"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type SMTPMode string
|
|
|
|
const (
|
|
SMTPModeSSL SMTPMode = "ssl"
|
|
SMTPModeTLS SMTPMode = "tls"
|
|
SMTPModeUnencrypted SMTPMode = "unencrypted"
|
|
)
|
|
|
|
type SMTPConfig struct {
|
|
Host string
|
|
Port int
|
|
Username string
|
|
Password string
|
|
From string
|
|
Mode SMTPMode
|
|
Auth string
|
|
}
|
|
|
|
type Mailer interface {
|
|
Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error
|
|
}
|
|
|
|
type SMTPMailer struct {
|
|
cfg SMTPConfig
|
|
}
|
|
|
|
const defaultSMTPOperationTimeout = 30 * time.Second
|
|
|
|
func NewSMTPMailer(cfg SMTPConfig) *SMTPMailer {
|
|
return &SMTPMailer{cfg: cfg}
|
|
}
|
|
|
|
func (m *SMTPMailer) Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error {
|
|
if strings.TrimSpace(m.cfg.Host) == "" || m.cfg.Port == 0 || strings.TrimSpace(m.cfg.From) == "" {
|
|
return fmt.Errorf("smtp is not configured")
|
|
}
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
now := time.Now()
|
|
addr := fmt.Sprintf("%s:%d", m.cfg.Host, m.cfg.Port)
|
|
deadline := smtpSendDeadline(ctx, now)
|
|
if !deadline.After(now) {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
return context.DeadlineExceeded
|
|
}
|
|
|
|
var client *smtp.Client
|
|
var err error
|
|
|
|
switch m.cfg.Mode {
|
|
case SMTPModeSSL:
|
|
client, err = dialSSL(ctx, addr, m.cfg.Host, deadline)
|
|
case SMTPModeUnencrypted:
|
|
client, err = dialPlain(ctx, addr, m.cfg.Host, deadline)
|
|
case SMTPModeTLS:
|
|
fallthrough
|
|
default:
|
|
client, err = dialStartTLS(ctx, addr, m.cfg.Host, deadline)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer client.Close()
|
|
done := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = client.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
defer close(done)
|
|
|
|
if m.cfg.Username != "" {
|
|
if err := authenticate(client, m.cfg); err != nil {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := client.Mail(m.cfg.From); err != nil {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
return fmt.Errorf("smtp mail from: %w", err)
|
|
}
|
|
if err := client.Rcpt(to); err != nil {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
return fmt.Errorf("smtp rcpt to: %w", err)
|
|
}
|
|
|
|
writer, err := client.Data()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
return fmt.Errorf("smtp data: %w", err)
|
|
}
|
|
|
|
message := buildMIMEMessage(m.cfg.From, to, subject, textBody, htmlBody)
|
|
if _, err := writer.Write([]byte(message)); err != nil {
|
|
_ = writer.Close()
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
return fmt.Errorf("smtp write body: %w", err)
|
|
}
|
|
if err := writer.Close(); err != nil {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
return fmt.Errorf("smtp close body: %w", err)
|
|
}
|
|
if err := client.Quit(); err != nil {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
return fmt.Errorf("smtp quit: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func smtpSendDeadline(ctx context.Context, now time.Time) time.Time {
|
|
if ctx != nil {
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
return deadline
|
|
}
|
|
}
|
|
return now.Add(defaultSMTPOperationTimeout)
|
|
}
|
|
|
|
func authenticate(client *smtp.Client, cfg SMTPConfig) error {
|
|
authMethod := strings.ToLower(strings.TrimSpace(cfg.Auth))
|
|
if authMethod == "" {
|
|
authMethod = "auto"
|
|
}
|
|
|
|
_, methodsRaw := client.Extension("AUTH")
|
|
methods := strings.ToUpper(methodsRaw)
|
|
|
|
tryLogin := strings.Contains(methods, "LOGIN")
|
|
tryPlain := strings.Contains(methods, "PLAIN")
|
|
|
|
switch authMethod {
|
|
case "none":
|
|
return nil
|
|
case "login":
|
|
return authWithLogin(client, cfg)
|
|
case "plain":
|
|
return authWithPlain(client, cfg)
|
|
case "auto":
|
|
if tryLogin {
|
|
if err := authWithLogin(client, cfg); err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
if tryPlain {
|
|
if err := authWithPlain(client, cfg); err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
if !tryLogin && !tryPlain {
|
|
// Last fallback if server did not advertise methods.
|
|
if err := authWithLogin(client, cfg); err == nil {
|
|
return nil
|
|
}
|
|
if err := authWithPlain(client, cfg); err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("smtp auth failed for available methods")
|
|
default:
|
|
return fmt.Errorf("unsupported smtp auth method: %s", authMethod)
|
|
}
|
|
}
|
|
|
|
func authWithPlain(client *smtp.Client, cfg SMTPConfig) error {
|
|
auth := smtp.PlainAuth("", cfg.Username, cfg.Password, cfg.Host)
|
|
if err := client.Auth(auth); err != nil {
|
|
return fmt.Errorf("smtp plain auth: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func authWithLogin(client *smtp.Client, cfg SMTPConfig) error {
|
|
if err := client.Auth(loginAuth{username: cfg.Username, password: cfg.Password}); err != nil {
|
|
return fmt.Errorf("smtp login auth: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type loginAuth struct {
|
|
username string
|
|
password string
|
|
}
|
|
|
|
func (a loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
|
|
return "LOGIN", []byte{}, nil
|
|
}
|
|
|
|
func (a loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
|
if !more {
|
|
return nil, nil
|
|
}
|
|
challengeRaw := strings.TrimSpace(string(fromServer))
|
|
challenge := strings.ToLower(challengeRaw)
|
|
if decoded, err := base64.StdEncoding.DecodeString(challengeRaw); err == nil {
|
|
challenge = strings.ToLower(string(decoded))
|
|
}
|
|
switch {
|
|
case strings.Contains(challenge, "username"):
|
|
return []byte(a.username), nil
|
|
case strings.Contains(challenge, "password"):
|
|
return []byte(a.password), nil
|
|
}
|
|
return nil, fmt.Errorf("unexpected smtp login challenge")
|
|
}
|
|
|
|
func dialSSL(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
|
timeout := time.Until(deadline)
|
|
if timeout <= 0 {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, context.DeadlineExceeded
|
|
}
|
|
|
|
dialer := &tls.Dialer{
|
|
NetDialer: &net.Dialer{Timeout: timeout},
|
|
Config: &tls.Config{ServerName: host},
|
|
}
|
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
return nil, fmt.Errorf("smtp ssl dial: %w", err)
|
|
}
|
|
_ = conn.SetDeadline(deadline)
|
|
client, err := smtp.NewClient(conn, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("smtp new client ssl: %w", err)
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func dialPlain(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
|
timeout := time.Until(deadline)
|
|
if timeout <= 0 {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, context.DeadlineExceeded
|
|
}
|
|
|
|
dialer := &net.Dialer{Timeout: timeout}
|
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
return nil, fmt.Errorf("smtp plain dial: %w", err)
|
|
}
|
|
_ = conn.SetDeadline(deadline)
|
|
client, err := smtp.NewClient(conn, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("smtp new client plain: %w", err)
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func dialStartTLS(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
|
client, err := dialPlain(ctx, addr, host, deadline)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if ok, _ := client.Extension("STARTTLS"); !ok {
|
|
_ = client.Close()
|
|
return nil, fmt.Errorf("smtp server does not support STARTTLS")
|
|
}
|
|
if err := client.StartTLS(&tls.Config{ServerName: host}); err != nil {
|
|
_ = client.Close()
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
return nil, fmt.Errorf("smtp starttls: %w", err)
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func buildMIMEMessage(from, to, subject, textBody, htmlBody string) string {
|
|
boundary := randomMIMEBoundary()
|
|
return strings.Join([]string{
|
|
fmt.Sprintf("From: %s", from),
|
|
fmt.Sprintf("To: %s", to),
|
|
fmt.Sprintf("Subject: %s", subject),
|
|
"MIME-Version: 1.0",
|
|
fmt.Sprintf("Content-Type: multipart/alternative; boundary=%s", boundary),
|
|
"",
|
|
fmt.Sprintf("--%s", boundary),
|
|
"Content-Type: text/plain; charset=UTF-8",
|
|
"",
|
|
textBody,
|
|
fmt.Sprintf("--%s", boundary),
|
|
"Content-Type: text/html; charset=UTF-8",
|
|
"",
|
|
htmlBody,
|
|
fmt.Sprintf("--%s--", boundary),
|
|
"",
|
|
}, "\r\n")
|
|
}
|
|
|
|
func randomMIMEBoundary() string {
|
|
raw := make([]byte, 12)
|
|
if _, err := rand.Read(raw); err != nil {
|
|
return fmt.Sprintf("mime-boundary-%d", time.Now().UnixNano())
|
|
}
|
|
return "mime-boundary-" + hex.EncodeToString(raw)
|
|
}
|