425 lines
10 KiB
Go
425 lines
10 KiB
Go
package smtp
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// smtpTestServerConfig scripts how the fake SMTP server answers a session.
type smtpTestServerConfig struct {
	// extensions are advertised, in order, after the "localhost" line of the
	// EHLO/HELO reply.
	extensions []string

	// Each fail* flag makes the server reject the matching command instead of
	// accepting it: AUTH (535), MAIL (550), RCPT (550), DATA (554) and
	// QUIT (554, after which the server closes the connection).
	failAuth, failMail, failRcpt, failData, failQuit bool
}
|
|
|
|
// smtpTestServerState records what the fake SMTP server observed during its
// single session.
type smtpTestServerState struct {
	// authCalls counts AUTH commands received, including rejected ones.
	authCalls int

	// mailFrom and rcptTo hold the argument text that followed the most recent
	// MAIL and RCPT commands; they are recorded even when the command is
	// rejected.
	mailFrom, rcptTo string

	// data is the most recent DATA body: every line followed by "\n", without
	// the terminating "." line. It is set only once that "." line arrives.
	data string
}
|
|
|
|
type smtpTestServer struct {
|
|
host string
|
|
port int
|
|
state *smtpTestServerState
|
|
stop func()
|
|
}
|
|
|
|
// requireSMTPIntegration skips the calling test when `go test -short` is in
// effect, because the SMTP tests open real loopback TCP connections.
func requireSMTPIntegration(t *testing.T) {
	t.Helper()
	if !testing.Short() {
		return
	}
	t.Skip("skipping smtp integration test in short mode")
}
|
|
|
|
func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer {
|
|
t.Helper()
|
|
requireSMTPIntegration(t)
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen smtp test server: %v", err)
|
|
}
|
|
|
|
state := &smtpTestServerState{}
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
r := bufio.NewReader(conn)
|
|
w := bufio.NewWriter(conn)
|
|
write := func(line string) bool {
|
|
if _, err := w.WriteString(line + "\r\n"); err != nil {
|
|
return false
|
|
}
|
|
return w.Flush() == nil
|
|
}
|
|
if !write("220 localhost ESMTP ready") {
|
|
return
|
|
}
|
|
|
|
inData := false
|
|
var dataBuf strings.Builder
|
|
for {
|
|
line, err := r.ReadString('\n')
|
|
if err != nil {
|
|
return
|
|
}
|
|
line = strings.TrimRight(line, "\r\n")
|
|
|
|
if inData {
|
|
if line == "." {
|
|
state.data = dataBuf.String()
|
|
if !write("250 2.0.0 queued") {
|
|
return
|
|
}
|
|
inData = false
|
|
continue
|
|
}
|
|
dataBuf.WriteString(line)
|
|
dataBuf.WriteString("\n")
|
|
continue
|
|
}
|
|
|
|
parts := strings.SplitN(line, " ", 2)
|
|
cmd := strings.ToUpper(parts[0])
|
|
arg := ""
|
|
if len(parts) > 1 {
|
|
arg = parts[1]
|
|
}
|
|
|
|
switch cmd {
|
|
case "EHLO", "HELO":
|
|
lines := append([]string{"localhost"}, cfg.extensions...)
|
|
for i, ext := range lines {
|
|
prefix := "250-"
|
|
if i == len(lines)-1 {
|
|
prefix = "250 "
|
|
}
|
|
if !write(prefix + ext) {
|
|
return
|
|
}
|
|
}
|
|
case "AUTH":
|
|
state.authCalls++
|
|
if cfg.failAuth {
|
|
if !write("535 5.7.8 auth failed") {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
authArg := strings.ToUpper(strings.TrimSpace(arg))
|
|
switch {
|
|
case strings.HasPrefix(authArg, "PLAIN"):
|
|
if !write("235 2.7.0 authenticated") {
|
|
return
|
|
}
|
|
case strings.HasPrefix(authArg, "LOGIN"):
|
|
if !write("334 VXNlcm5hbWU6") {
|
|
return
|
|
}
|
|
if _, err := r.ReadString('\n'); err != nil {
|
|
return
|
|
}
|
|
if !write("334 UGFzc3dvcmQ6") {
|
|
return
|
|
}
|
|
if _, err := r.ReadString('\n'); err != nil {
|
|
return
|
|
}
|
|
if !write("235 2.7.0 authenticated") {
|
|
return
|
|
}
|
|
default:
|
|
if !write("504 5.5.4 unsupported auth mechanism") {
|
|
return
|
|
}
|
|
}
|
|
case "MAIL":
|
|
state.mailFrom = arg
|
|
if cfg.failMail {
|
|
if !write("550 5.1.0 sender rejected") {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
if !write("250 2.1.0 ok") {
|
|
return
|
|
}
|
|
case "RCPT":
|
|
state.rcptTo = arg
|
|
if cfg.failRcpt {
|
|
if !write("550 5.1.1 recipient rejected") {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
if !write("250 2.1.5 ok") {
|
|
return
|
|
}
|
|
case "DATA":
|
|
if cfg.failData {
|
|
if !write("554 5.5.0 data rejected") {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
if !write("354 end data with <CR><LF>.<CR><LF>") {
|
|
return
|
|
}
|
|
inData = true
|
|
dataBuf.Reset()
|
|
case "QUIT":
|
|
if cfg.failQuit {
|
|
_ = write("554 5.5.1 quit failed")
|
|
return
|
|
}
|
|
_ = write("221 2.0.0 bye")
|
|
return
|
|
default:
|
|
if !write("250 2.0.0 ok") {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
tcpAddr := ln.Addr().(*net.TCPAddr)
|
|
return &smtpTestServer{
|
|
host: "127.0.0.1",
|
|
port: tcpAddr.Port,
|
|
state: state,
|
|
stop: func() {
|
|
_ = ln.Close()
|
|
wg.Wait()
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestSendUnencryptedSuccess(t *testing.T) {
|
|
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeUnencrypted,
|
|
})
|
|
|
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
|
if err != nil {
|
|
t.Fatalf("Send: %v", err)
|
|
}
|
|
if !strings.Contains(srv.state.mailFrom, "noreply@example.com") {
|
|
t.Fatalf("expected mail from to include sender, got %q", srv.state.mailFrom)
|
|
}
|
|
if !strings.Contains(srv.state.rcptTo, "to@example.com") {
|
|
t.Fatalf("expected rcpt to include recipient, got %q", srv.state.rcptTo)
|
|
}
|
|
if !strings.Contains(srv.state.data, "Subject: subject") {
|
|
t.Fatalf("expected message data to include subject, got %q", srv.state.data)
|
|
}
|
|
}
|
|
|
|
func TestSendUnencryptedErrorPaths(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg smtpTestServerConfig
|
|
wantErr string
|
|
}{
|
|
{name: "mail from error", cfg: smtpTestServerConfig{failMail: true}, wantErr: "smtp mail from"},
|
|
{name: "rcpt error", cfg: smtpTestServerConfig{failRcpt: true}, wantErr: "smtp rcpt to"},
|
|
{name: "data error", cfg: smtpTestServerConfig{failData: true}, wantErr: "smtp data"},
|
|
{name: "quit error", cfg: smtpTestServerConfig{failQuit: true}, wantErr: "smtp quit"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
srv := startSMTPTestServer(t, tc.cfg)
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeUnencrypted,
|
|
})
|
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
|
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
|
|
t.Fatalf("expected %q error, got %v", tc.wantErr, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSendAuthModes(t *testing.T) {
|
|
t.Run("auth none does not issue AUTH command", func(t *testing.T) {
|
|
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN LOGIN"}})
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeUnencrypted,
|
|
Username: "alice",
|
|
Password: "secret",
|
|
Auth: "none",
|
|
})
|
|
|
|
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
|
t.Fatalf("Send: %v", err)
|
|
}
|
|
if srv.state.authCalls != 0 {
|
|
t.Fatalf("expected no AUTH command, got %d", srv.state.authCalls)
|
|
}
|
|
})
|
|
|
|
t.Run("auth plain", func(t *testing.T) {
|
|
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN"}})
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeUnencrypted,
|
|
Username: "alice",
|
|
Password: "secret",
|
|
Auth: "plain",
|
|
})
|
|
|
|
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
|
t.Fatalf("Send: %v", err)
|
|
}
|
|
if srv.state.authCalls == 0 {
|
|
t.Fatal("expected AUTH command for plain auth")
|
|
}
|
|
})
|
|
|
|
t.Run("auth login", func(t *testing.T) {
|
|
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH LOGIN"}})
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeUnencrypted,
|
|
Username: "alice",
|
|
Password: "secret",
|
|
Auth: "login",
|
|
})
|
|
|
|
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
|
t.Fatalf("Send: %v", err)
|
|
}
|
|
if srv.state.authCalls == 0 {
|
|
t.Fatal("expected AUTH command for login auth")
|
|
}
|
|
})
|
|
|
|
t.Run("auth auto failure", func(t *testing.T) {
|
|
srv := startSMTPTestServer(t, smtpTestServerConfig{failAuth: true})
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeUnencrypted,
|
|
Username: "alice",
|
|
Password: "secret",
|
|
Auth: "auto",
|
|
})
|
|
|
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
|
if err == nil || !strings.Contains(err.Error(), "smtp auth failed") {
|
|
t.Fatalf("expected smtp auth failed error, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("unsupported auth method", func(t *testing.T) {
|
|
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeUnencrypted,
|
|
Username: "alice",
|
|
Password: "secret",
|
|
Auth: "weird",
|
|
})
|
|
|
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
|
if err == nil || !strings.Contains(err.Error(), "unsupported smtp auth method") {
|
|
t.Fatalf("expected unsupported auth method error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestSendTLSFailsWhenStartTLSExtensionMissing(t *testing.T) {
|
|
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
|
defer srv.stop()
|
|
|
|
mailer := NewSMTPMailer(SMTPConfig{
|
|
Host: srv.host,
|
|
Port: srv.port,
|
|
From: "noreply@example.com",
|
|
Mode: SMTPModeTLS,
|
|
})
|
|
|
|
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
|
if err == nil || !strings.Contains(err.Error(), "does not support STARTTLS") {
|
|
t.Fatalf("expected STARTTLS unsupported error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDialHelpersWrapDialErrors(t *testing.T) {
|
|
requireSMTPIntegration(t)
|
|
|
|
ctx := context.Background()
|
|
deadline := time.Now().Add(500 * time.Millisecond)
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen ephemeral: %v", err)
|
|
}
|
|
addr := ln.Addr().String()
|
|
_ = ln.Close()
|
|
host, portStr, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
t.Fatalf("split host port: %v", err)
|
|
}
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
t.Fatalf("parse port: %v", err)
|
|
}
|
|
unusedAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
|
|
|
if _, err := dialPlain(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp plain dial") {
|
|
t.Fatalf("expected wrapped plain dial error, got %v", err)
|
|
}
|
|
if _, err := dialSSL(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp ssl dial") {
|
|
t.Fatalf("expected wrapped ssl dial error, got %v", err)
|
|
}
|
|
}
|