package smtp

import (
	"bufio"
	"context"
	"net"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)

// smtpTestServerConnTimeout bounds how long the fake server keeps its single
// connection open, so a client/server protocol mismatch ends the session
// (and fails the test) instead of hanging it.
const smtpTestServerConnTimeout = 5 * time.Second

// testSendBody is the body argument the tests below pass to Send.
const testSendBody = "\n\nhello\n\n"

// smtpTestServerConfig controls which extensions the fake server advertises
// in its EHLO reply and which commands it rejects.
type smtpTestServerConfig struct {
	extensions []string // extra EHLO lines advertised after "localhost"
	failAuth   bool     // reply 535 to AUTH
	failMail   bool     // reply 550 to MAIL
	failRcpt   bool     // reply 550 to RCPT
	failData   bool     // reply 554 to DATA
	failQuit   bool     // reply 554 to QUIT
}

// smtpTestServerState records what the fake server observed. It is written
// by the server goroutine; read it only after calling smtpTestServer.stop,
// which waits for that goroutine to exit.
type smtpTestServerState struct {
	authCalls int    // number of AUTH commands received
	mailFrom  string // argument of the last MAIL command
	rcptTo    string // argument of the last RCPT command
	data      string // last DATA payload, each line terminated by "\n"
}

// smtpTestServer is a handle to a running fake SMTP server. Its stop func
// closes the listener and any accepted connection, then waits for the server
// goroutine to exit; it is safe to call more than once.
type smtpTestServer struct {
	host  string
	port  int
	state *smtpTestServerState
	stop  func()
}

// requireSMTPIntegration skips the calling test in -short mode.
func requireSMTPIntegration(t *testing.T) {
	t.Helper()
	if testing.Short() {
		t.Skip("skipping smtp integration test in short mode")
	}
}

// startSMTPTestServer starts a fake SMTP server on an ephemeral loopback port.
// It accepts exactly one connection and serves it according to cfg. The test
// is skipped in -short mode. Callers should defer srv.stop(), and must call
// srv.stop() before reading srv.state.
func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer {
	t.Helper()
	requireSMTPIntegration(t)

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen smtp test server: %v", err)
	}

	state := &smtpTestServerState{}

	var (
		wg      sync.WaitGroup
		mu      sync.Mutex // guards conn and stopped
		conn    net.Conn
		stopped bool
	)
	wg.Add(1)
	go func() {
		defer wg.Done()
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()

		mu.Lock()
		if stopped {
			mu.Unlock()
			return
		}
		conn = c
		mu.Unlock()

		// Best effort: if setting the deadline fails, stop() still closes
		// the connection, so the goroutine cannot outlive the test.
		_ = c.SetDeadline(time.Now().Add(smtpTestServerConnTimeout))
		serveSMTPTestConn(c, cfg, state)
	}()

	var once sync.Once
	tcpAddr := ln.Addr().(*net.TCPAddr)
	return &smtpTestServer{
		host:  "127.0.0.1",
		port:  tcpAddr.Port,
		state: state,
		stop: func() {
			once.Do(func() {
				_ = ln.Close()
				mu.Lock()
				stopped = true
				if conn != nil {
					// Unblock the server goroutine if the client left the
					// session open; otherwise wg.Wait could block forever.
					_ = conn.Close()
				}
				mu.Unlock()
				wg.Wait()
			})
		},
	}
}

// serveSMTPTestConn runs a minimal SMTP dialogue on conn until QUIT or a
// read/write error (including deadline expiry). It records observed commands
// in state and rejects commands as configured by cfg.
func serveSMTPTestConn(conn net.Conn, cfg smtpTestServerConfig, state *smtpTestServerState) {
	r := bufio.NewReader(conn)
	w := bufio.NewWriter(conn)
	write := func(line string) bool {
		if _, err := w.WriteString(line + "\r\n"); err != nil {
			return false
		}
		return w.Flush() == nil
	}

	if !write("220 localhost ESMTP ready") {
		return
	}

	inData := false
	var dataBuf strings.Builder
	for {
		line, err := r.ReadString('\n')
		if err != nil {
			return
		}
		line = strings.TrimRight(line, "\r\n")

		if inData {
			if line == "." {
				state.data = dataBuf.String()
				inData = false
				if !write("250 2.0.0 queued") {
					return
				}
				continue
			}
			// Undo dot-stuffing (RFC 5321 section 4.5.2): a leading "." on a
			// content line was added by the client for transparency.
			line = strings.TrimPrefix(line, ".")
			dataBuf.WriteString(line)
			dataBuf.WriteString("\n")
			continue
		}

		parts := strings.SplitN(line, " ", 2)
		cmd := strings.ToUpper(parts[0])
		arg := ""
		if len(parts) > 1 {
			arg = parts[1]
		}

		switch cmd {
		case "EHLO", "HELO":
			lines := append([]string{"localhost"}, cfg.extensions...)
			for i, ext := range lines {
				prefix := "250-"
				if i == len(lines)-1 {
					prefix = "250 "
				}
				if !write(prefix + ext) {
					return
				}
			}
		case "AUTH":
			state.authCalls++
			if cfg.failAuth {
				if !write("535 5.7.8 auth failed") {
					return
				}
				continue
			}
			authArg := strings.ToUpper(strings.TrimSpace(arg))
			switch {
			case strings.HasPrefix(authArg, "PLAIN"):
				if !write("235 2.7.0 authenticated") {
					return
				}
			case strings.HasPrefix(authArg, "LOGIN"):
				// Prompt for and consume the username and password lines.
				if !write("334 VXNlcm5hbWU6") {
					return
				}
				if _, err := r.ReadString('\n'); err != nil {
					return
				}
				if !write("334 UGFzc3dvcmQ6") {
					return
				}
				if _, err := r.ReadString('\n'); err != nil {
					return
				}
				if !write("235 2.7.0 authenticated") {
					return
				}
			default:
				if !write("504 5.5.4 unsupported auth mechanism") {
					return
				}
			}
		case "MAIL":
			state.mailFrom = arg
			if cfg.failMail {
				if !write("550 5.1.0 sender rejected") {
					return
				}
				continue
			}
			if !write("250 2.1.0 ok") {
				return
			}
		case "RCPT":
			state.rcptTo = arg
			if cfg.failRcpt {
				if !write("550 5.1.1 recipient rejected") {
					return
				}
				continue
			}
			if !write("250 2.1.5 ok") {
				return
			}
		case "DATA":
			if cfg.failData {
				if !write("554 5.5.0 data rejected") {
					return
				}
				continue
			}
			if !write("354 end data with .") {
				return
			}
			inData = true
			dataBuf.Reset()
		case "QUIT":
			if cfg.failQuit {
				_ = write("554 5.5.1 quit failed")
				return
			}
			_ = write("221 2.0.0 bye")
			return
		default:
			if !write("250 2.0.0 ok") {
				return
			}
		}
	}
}

func TestSendUnencryptedSuccess(t *testing.T) {
	srv := startSMTPTestServer(t, smtpTestServerConfig{})
	defer srv.stop()

	mailer := NewSMTPMailer(SMTPConfig{
		Host: srv.host,
		Port: srv.port,
		From: "noreply@example.com",
		Mode: SMTPModeUnencrypted,
	})
	err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello")
	if err != nil {
		t.Fatalf("Send: %v", err)
	}

	// Wait for the server goroutine before inspecting what it recorded.
	srv.stop()
	if !strings.Contains(srv.state.mailFrom, "noreply@example.com") {
		t.Fatalf("expected mail from to include sender, got %q", srv.state.mailFrom)
	}
	if !strings.Contains(srv.state.rcptTo, "to@example.com") {
		t.Fatalf("expected rcpt to include recipient, got %q", srv.state.rcptTo)
	}
	if !strings.Contains(srv.state.data, "Subject: subject") {
		t.Fatalf("expected message data to include subject, got %q", srv.state.data)
	}
}

func TestSendUnencryptedErrorPaths(t *testing.T) {
	tests := []struct {
		name    string
		cfg     smtpTestServerConfig
		wantErr string
	}{
		{name: "mail from error", cfg: smtpTestServerConfig{failMail: true}, wantErr: "smtp mail from"},
		{name: "rcpt error", cfg: smtpTestServerConfig{failRcpt: true}, wantErr: "smtp rcpt to"},
		{name: "data error", cfg: smtpTestServerConfig{failData: true}, wantErr: "smtp data"},
		{name: "quit error", cfg: smtpTestServerConfig{failQuit: true}, wantErr: "smtp quit"},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			srv := startSMTPTestServer(t, tc.cfg)
			defer srv.stop()

			mailer := NewSMTPMailer(SMTPConfig{
				Host: srv.host,
				Port: srv.port,
				From: "noreply@example.com",
				Mode: SMTPModeUnencrypted,
			})
			err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello")
			if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
				t.Fatalf("expected %q error, got %v", tc.wantErr, err)
			}
		})
	}
}

func TestSendAuthModes(t *testing.T) {
	t.Run("auth none does not issue AUTH command", func(t *testing.T) {
		srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN LOGIN"}})
		defer srv.stop()

		mailer := NewSMTPMailer(SMTPConfig{
			Host:     srv.host,
			Port:     srv.port,
			From:     "noreply@example.com",
			Mode:     SMTPModeUnencrypted,
			Username: "alice",
			Password: "secret",
			Auth:     "none",
		})
		if err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello"); err != nil {
			t.Fatalf("Send: %v", err)
		}
		srv.stop()
		if srv.state.authCalls != 0 {
			t.Fatalf("expected no AUTH command, got %d", srv.state.authCalls)
		}
	})

	t.Run("auth plain", func(t *testing.T) {
		srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN"}})
		defer srv.stop()

		mailer := NewSMTPMailer(SMTPConfig{
			Host:     srv.host,
			Port:     srv.port,
			From:     "noreply@example.com",
			Mode:     SMTPModeUnencrypted,
			Username: "alice",
			Password: "secret",
			Auth:     "plain",
		})
		if err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello"); err != nil {
			t.Fatalf("Send: %v", err)
		}
		srv.stop()
		if srv.state.authCalls == 0 {
			t.Fatal("expected AUTH command for plain auth")
		}
	})

	t.Run("auth login", func(t *testing.T) {
		srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH LOGIN"}})
		defer srv.stop()

		mailer := NewSMTPMailer(SMTPConfig{
			Host:     srv.host,
			Port:     srv.port,
			From:     "noreply@example.com",
			Mode:     SMTPModeUnencrypted,
			Username: "alice",
			Password: "secret",
			Auth:     "login",
		})
		if err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello"); err != nil {
			t.Fatalf("Send: %v", err)
		}
		srv.stop()
		if srv.state.authCalls == 0 {
			t.Fatal("expected AUTH command for login auth")
		}
	})

	t.Run("auth auto failure", func(t *testing.T) {
		srv := startSMTPTestServer(t, smtpTestServerConfig{failAuth: true})
		defer srv.stop()

		mailer := NewSMTPMailer(SMTPConfig{
			Host:     srv.host,
			Port:     srv.port,
			From:     "noreply@example.com",
			Mode:     SMTPModeUnencrypted,
			Username: "alice",
			Password: "secret",
			Auth:     "auto",
		})
		err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello")
		if err == nil || !strings.Contains(err.Error(), "smtp auth failed") {
			t.Fatalf("expected smtp auth failed error, got %v", err)
		}
	})

	t.Run("unsupported auth method", func(t *testing.T) {
		srv := startSMTPTestServer(t, smtpTestServerConfig{})
		defer srv.stop()

		mailer := NewSMTPMailer(SMTPConfig{
			Host:     srv.host,
			Port:     srv.port,
			From:     "noreply@example.com",
			Mode:     SMTPModeUnencrypted,
			Username: "alice",
			Password: "secret",
			Auth:     "weird",
		})
		err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello")
		if err == nil || !strings.Contains(err.Error(), "unsupported smtp auth method") {
			t.Fatalf("expected unsupported auth method error, got %v", err)
		}
	})
}

func TestSendTLSFailsWhenStartTLSExtensionMissing(t *testing.T) {
	srv := startSMTPTestServer(t, smtpTestServerConfig{})
	defer srv.stop()

	mailer := NewSMTPMailer(SMTPConfig{
		Host: srv.host,
		Port: srv.port,
		From: "noreply@example.com",
		Mode: SMTPModeTLS,
	})
	err := mailer.Send(context.Background(), "to@example.com", "subject", testSendBody, "hello")
	if err == nil || !strings.Contains(err.Error(), "does not support STARTTLS") {
		t.Fatalf("expected STARTTLS unsupported error, got %v", err)
	}
}

func TestDialHelpersWrapDialErrors(t *testing.T) {
	requireSMTPIntegration(t)

	ctx := context.Background()
	deadline := time.Now().Add(500 * time.Millisecond)

	// Reserve an ephemeral port and release it immediately so that dialing
	// it is very likely to be refused.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen ephemeral: %v", err)
	}
	addr := ln.Addr().String()
	_ = ln.Close()

	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		t.Fatalf("split host port: %v", err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		t.Fatalf("parse port: %v", err)
	}
	unusedAddr := net.JoinHostPort(host, strconv.Itoa(port))

	if _, err := dialPlain(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp plain dial") {
		t.Fatalf("expected wrapped plain dial error, got %v", err)
	}
	if _, err := dialSSL(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp ssl dial") {
		t.Fatalf("expected wrapped ssl dial error, got %v", err)
	}
}