add core lib
This commit is contained in:
345
smtp/smtp_mailer.go
Normal file
345
smtp/smtp_mailer.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SMTPMode string
|
||||
|
||||
const (
|
||||
SMTPModeSSL SMTPMode = "ssl"
|
||||
SMTPModeTLS SMTPMode = "tls"
|
||||
SMTPModeUnencrypted SMTPMode = "unencrypted"
|
||||
)
|
||||
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
From string
|
||||
Mode SMTPMode
|
||||
Auth string
|
||||
}
|
||||
|
||||
type Mailer interface {
|
||||
Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error
|
||||
}
|
||||
|
||||
type SMTPMailer struct {
|
||||
cfg SMTPConfig
|
||||
}
|
||||
|
||||
const defaultSMTPOperationTimeout = 30 * time.Second
|
||||
|
||||
func NewSMTPMailer(cfg SMTPConfig) *SMTPMailer {
|
||||
return &SMTPMailer{cfg: cfg}
|
||||
}
|
||||
|
||||
func (m *SMTPMailer) Send(ctx context.Context, to string, subject string, htmlBody string, textBody string) error {
|
||||
if strings.TrimSpace(m.cfg.Host) == "" || m.cfg.Port == 0 || strings.TrimSpace(m.cfg.From) == "" {
|
||||
return fmt.Errorf("smtp is not configured")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
addr := fmt.Sprintf("%s:%d", m.cfg.Host, m.cfg.Port)
|
||||
deadline := smtpSendDeadline(ctx, now)
|
||||
if !deadline.After(now) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
|
||||
var client *smtp.Client
|
||||
var err error
|
||||
|
||||
switch m.cfg.Mode {
|
||||
case SMTPModeSSL:
|
||||
client, err = dialSSL(ctx, addr, m.cfg.Host, deadline)
|
||||
case SMTPModeUnencrypted:
|
||||
client, err = dialPlain(ctx, addr, m.cfg.Host, deadline)
|
||||
case SMTPModeTLS:
|
||||
fallthrough
|
||||
default:
|
||||
client, err = dialStartTLS(ctx, addr, m.cfg.Host, deadline)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = client.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
if m.cfg.Username != "" {
|
||||
if err := authenticate(client, m.cfg); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.Mail(m.cfg.From); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("smtp mail from: %w", err)
|
||||
}
|
||||
if err := client.Rcpt(to); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("smtp rcpt to: %w", err)
|
||||
}
|
||||
|
||||
writer, err := client.Data()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("smtp data: %w", err)
|
||||
}
|
||||
|
||||
message := buildMIMEMessage(m.cfg.From, to, subject, textBody, htmlBody)
|
||||
if _, err := writer.Write([]byte(message)); err != nil {
|
||||
_ = writer.Close()
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("smtp write body: %w", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("smtp close body: %w", err)
|
||||
}
|
||||
if err := client.Quit(); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("smtp quit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func smtpSendDeadline(ctx context.Context, now time.Time) time.Time {
|
||||
if ctx != nil {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
return deadline
|
||||
}
|
||||
}
|
||||
return now.Add(defaultSMTPOperationTimeout)
|
||||
}
|
||||
|
||||
func authenticate(client *smtp.Client, cfg SMTPConfig) error {
|
||||
authMethod := strings.ToLower(strings.TrimSpace(cfg.Auth))
|
||||
if authMethod == "" {
|
||||
authMethod = "auto"
|
||||
}
|
||||
|
||||
_, methodsRaw := client.Extension("AUTH")
|
||||
methods := strings.ToUpper(methodsRaw)
|
||||
|
||||
tryLogin := strings.Contains(methods, "LOGIN")
|
||||
tryPlain := strings.Contains(methods, "PLAIN")
|
||||
|
||||
switch authMethod {
|
||||
case "none":
|
||||
return nil
|
||||
case "login":
|
||||
return authWithLogin(client, cfg)
|
||||
case "plain":
|
||||
return authWithPlain(client, cfg)
|
||||
case "auto":
|
||||
if tryLogin {
|
||||
if err := authWithLogin(client, cfg); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if tryPlain {
|
||||
if err := authWithPlain(client, cfg); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if !tryLogin && !tryPlain {
|
||||
// Last fallback if server did not advertise methods.
|
||||
if err := authWithLogin(client, cfg); err == nil {
|
||||
return nil
|
||||
}
|
||||
if err := authWithPlain(client, cfg); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("smtp auth failed for available methods")
|
||||
default:
|
||||
return fmt.Errorf("unsupported smtp auth method: %s", authMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func authWithPlain(client *smtp.Client, cfg SMTPConfig) error {
|
||||
auth := smtp.PlainAuth("", cfg.Username, cfg.Password, cfg.Host)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp plain auth: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func authWithLogin(client *smtp.Client, cfg SMTPConfig) error {
|
||||
if err := client.Auth(loginAuth{username: cfg.Username, password: cfg.Password}); err != nil {
|
||||
return fmt.Errorf("smtp login auth: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type loginAuth struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
func (a loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
|
||||
return "LOGIN", []byte{}, nil
|
||||
}
|
||||
|
||||
func (a loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if !more {
|
||||
return nil, nil
|
||||
}
|
||||
challengeRaw := strings.TrimSpace(string(fromServer))
|
||||
challenge := strings.ToLower(challengeRaw)
|
||||
if decoded, err := base64.StdEncoding.DecodeString(challengeRaw); err == nil {
|
||||
challenge = strings.ToLower(string(decoded))
|
||||
}
|
||||
switch {
|
||||
case strings.Contains(challenge, "username"):
|
||||
return []byte(a.username), nil
|
||||
case strings.Contains(challenge, "password"):
|
||||
return []byte(a.password), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected smtp login challenge")
|
||||
}
|
||||
|
||||
func dialSSL(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
||||
timeout := time.Until(deadline)
|
||||
if timeout <= 0 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, context.DeadlineExceeded
|
||||
}
|
||||
|
||||
dialer := &tls.Dialer{
|
||||
NetDialer: &net.Dialer{Timeout: timeout},
|
||||
Config: &tls.Config{ServerName: host},
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return nil, fmt.Errorf("smtp ssl dial: %w", err)
|
||||
}
|
||||
_ = conn.SetDeadline(deadline)
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("smtp new client ssl: %w", err)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func dialPlain(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
||||
timeout := time.Until(deadline)
|
||||
if timeout <= 0 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, context.DeadlineExceeded
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return nil, fmt.Errorf("smtp plain dial: %w", err)
|
||||
}
|
||||
_ = conn.SetDeadline(deadline)
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("smtp new client plain: %w", err)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func dialStartTLS(ctx context.Context, addr string, host string, deadline time.Time) (*smtp.Client, error) {
|
||||
client, err := dialPlain(ctx, addr, host, deadline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok, _ := client.Extension("STARTTLS"); !ok {
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("smtp server does not support STARTTLS")
|
||||
}
|
||||
if err := client.StartTLS(&tls.Config{ServerName: host}); err != nil {
|
||||
_ = client.Close()
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return nil, fmt.Errorf("smtp starttls: %w", err)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func buildMIMEMessage(from, to, subject, textBody, htmlBody string) string {
|
||||
boundary := randomMIMEBoundary()
|
||||
return strings.Join([]string{
|
||||
fmt.Sprintf("From: %s", from),
|
||||
fmt.Sprintf("To: %s", to),
|
||||
fmt.Sprintf("Subject: %s", subject),
|
||||
"MIME-Version: 1.0",
|
||||
fmt.Sprintf("Content-Type: multipart/alternative; boundary=%s", boundary),
|
||||
"",
|
||||
fmt.Sprintf("--%s", boundary),
|
||||
"Content-Type: text/plain; charset=UTF-8",
|
||||
"",
|
||||
textBody,
|
||||
fmt.Sprintf("--%s", boundary),
|
||||
"Content-Type: text/html; charset=UTF-8",
|
||||
"",
|
||||
htmlBody,
|
||||
fmt.Sprintf("--%s--", boundary),
|
||||
"",
|
||||
}, "\r\n")
|
||||
}
|
||||
|
||||
func randomMIMEBoundary() string {
|
||||
raw := make([]byte, 12)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return fmt.Sprintf("mime-boundary-%d", time.Now().UnixNano())
|
||||
}
|
||||
return "mime-boundary-" + hex.EncodeToString(raw)
|
||||
}
|
||||
126
smtp/smtp_mailer_additional_test.go
Normal file
126
smtp/smtp_mailer_additional_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSendReturnsConfigurationErrorWhenSMTPMissing(t *testing.T) {
|
||||
mailer := NewSMTPMailer(SMTPConfig{Host: "", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS})
|
||||
|
||||
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>body</p>", "body")
|
||||
if err == nil {
|
||||
t.Fatal("expected configuration error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "smtp is not configured") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendReturnsDeadlineExceededForExpiredContext(t *testing.T) {
|
||||
mailer := NewSMTPMailer(SMTPConfig{Host: "smtp.example.com", Port: 587, From: "noreply@example.com", Mode: SMTPModeTLS})
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||
defer cancel()
|
||||
|
||||
err := mailer.Send(ctx, "to@example.com", "subject", "<p>body</p>", "body")
|
||||
if err == nil {
|
||||
t.Fatal("expected context deadline exceeded")
|
||||
}
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Fatalf("expected context deadline exceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginAuthNextHandlesChallenges(t *testing.T) {
|
||||
auth := loginAuth{username: "alice", password: "s3cret"}
|
||||
|
||||
proto, initial, err := auth.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected start error: %v", err)
|
||||
}
|
||||
if proto != "LOGIN" {
|
||||
t.Fatalf("expected LOGIN auth proto, got %q", proto)
|
||||
}
|
||||
if len(initial) != 0 {
|
||||
t.Fatalf("expected empty initial response, got %q", string(initial))
|
||||
}
|
||||
|
||||
value, err := auth.Next([]byte("Username:"), true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected username challenge error: %v", err)
|
||||
}
|
||||
if string(value) != "alice" {
|
||||
t.Fatalf("expected username response alice, got %q", string(value))
|
||||
}
|
||||
|
||||
value, err = auth.Next([]byte("UGFzc3dvcmQ6"), true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected password challenge error: %v", err)
|
||||
}
|
||||
if string(value) != "s3cret" {
|
||||
t.Fatalf("expected password response s3cret, got %q", string(value))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginAuthNextHandlesTerminalAndUnexpectedChallenge(t *testing.T) {
|
||||
auth := loginAuth{username: "alice", password: "s3cret"}
|
||||
|
||||
value, err := auth.Next([]byte("ignored"), false)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when more=false, got %v", err)
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatalf("expected nil value when more=false, got %q", string(value))
|
||||
}
|
||||
|
||||
if _, err := auth.Next([]byte("realm"), true); err == nil {
|
||||
t.Fatal("expected error for unexpected login challenge")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMIMEMessageContainsMultipartSections(t *testing.T) {
|
||||
msg := buildMIMEMessage("from@example.com", "to@example.com", "Subject", "plain body", "<p>html body</p>")
|
||||
|
||||
checks := []string{
|
||||
"From: from@example.com",
|
||||
"To: to@example.com",
|
||||
"Subject: Subject",
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: text/plain; charset=UTF-8",
|
||||
"plain body",
|
||||
"Content-Type: text/html; charset=UTF-8",
|
||||
"<p>html body</p>",
|
||||
}
|
||||
for _, snippet := range checks {
|
||||
if !strings.Contains(msg, snippet) {
|
||||
t.Fatalf("expected MIME message to contain %q", snippet)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(msg, "\r\n") {
|
||||
t.Fatal("expected CRLF separators in MIME message")
|
||||
}
|
||||
|
||||
if !strings.Contains(msg, "Content-Type: multipart/alternative; boundary=mime-boundary-") {
|
||||
t.Fatalf("expected random mime boundary header, got %q", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "--mime-boundary-") {
|
||||
t.Fatalf("expected mime boundary delimiters in message, got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialHelpersReturnDeadlineExceededWhenDeadlineHasPassed(t *testing.T) {
|
||||
deadline := time.Now().Add(-time.Second)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := dialPlain(ctx, "smtp.example.com:25", "smtp.example.com", deadline); err != context.DeadlineExceeded {
|
||||
t.Fatalf("dialPlain expected context deadline exceeded, got %v", err)
|
||||
}
|
||||
if _, err := dialSSL(ctx, "smtp.example.com:465", "smtp.example.com", deadline); err != context.DeadlineExceeded {
|
||||
t.Fatalf("dialSSL expected context deadline exceeded, got %v", err)
|
||||
}
|
||||
if _, err := dialStartTLS(ctx, "smtp.example.com:587", "smtp.example.com", deadline); err != context.DeadlineExceeded {
|
||||
t.Fatalf("dialStartTLS expected context deadline exceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
424
smtp/smtp_mailer_integration_test.go
Normal file
424
smtp/smtp_mailer_integration_test.go
Normal file
@@ -0,0 +1,424 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type smtpTestServerConfig struct {
|
||||
extensions []string
|
||||
failAuth bool
|
||||
failMail bool
|
||||
failRcpt bool
|
||||
failData bool
|
||||
failQuit bool
|
||||
}
|
||||
|
||||
type smtpTestServerState struct {
|
||||
authCalls int
|
||||
mailFrom string
|
||||
rcptTo string
|
||||
data string
|
||||
}
|
||||
|
||||
type smtpTestServer struct {
|
||||
host string
|
||||
port int
|
||||
state *smtpTestServerState
|
||||
stop func()
|
||||
}
|
||||
|
||||
func requireSMTPIntegration(t *testing.T) {
|
||||
t.Helper()
|
||||
if testing.Short() {
|
||||
t.Skip("skipping smtp integration test in short mode")
|
||||
}
|
||||
}
|
||||
|
||||
func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer {
|
||||
t.Helper()
|
||||
requireSMTPIntegration(t)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen smtp test server: %v", err)
|
||||
}
|
||||
|
||||
state := &smtpTestServerState{}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
r := bufio.NewReader(conn)
|
||||
w := bufio.NewWriter(conn)
|
||||
write := func(line string) bool {
|
||||
if _, err := w.WriteString(line + "\r\n"); err != nil {
|
||||
return false
|
||||
}
|
||||
return w.Flush() == nil
|
||||
}
|
||||
if !write("220 localhost ESMTP ready") {
|
||||
return
|
||||
}
|
||||
|
||||
inData := false
|
||||
var dataBuf strings.Builder
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
|
||||
if inData {
|
||||
if line == "." {
|
||||
state.data = dataBuf.String()
|
||||
if !write("250 2.0.0 queued") {
|
||||
return
|
||||
}
|
||||
inData = false
|
||||
continue
|
||||
}
|
||||
dataBuf.WriteString(line)
|
||||
dataBuf.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, " ", 2)
|
||||
cmd := strings.ToUpper(parts[0])
|
||||
arg := ""
|
||||
if len(parts) > 1 {
|
||||
arg = parts[1]
|
||||
}
|
||||
|
||||
switch cmd {
|
||||
case "EHLO", "HELO":
|
||||
lines := append([]string{"localhost"}, cfg.extensions...)
|
||||
for i, ext := range lines {
|
||||
prefix := "250-"
|
||||
if i == len(lines)-1 {
|
||||
prefix = "250 "
|
||||
}
|
||||
if !write(prefix + ext) {
|
||||
return
|
||||
}
|
||||
}
|
||||
case "AUTH":
|
||||
state.authCalls++
|
||||
if cfg.failAuth {
|
||||
if !write("535 5.7.8 auth failed") {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
authArg := strings.ToUpper(strings.TrimSpace(arg))
|
||||
switch {
|
||||
case strings.HasPrefix(authArg, "PLAIN"):
|
||||
if !write("235 2.7.0 authenticated") {
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(authArg, "LOGIN"):
|
||||
if !write("334 VXNlcm5hbWU6") {
|
||||
return
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return
|
||||
}
|
||||
if !write("334 UGFzc3dvcmQ6") {
|
||||
return
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return
|
||||
}
|
||||
if !write("235 2.7.0 authenticated") {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !write("504 5.5.4 unsupported auth mechanism") {
|
||||
return
|
||||
}
|
||||
}
|
||||
case "MAIL":
|
||||
state.mailFrom = arg
|
||||
if cfg.failMail {
|
||||
if !write("550 5.1.0 sender rejected") {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !write("250 2.1.0 ok") {
|
||||
return
|
||||
}
|
||||
case "RCPT":
|
||||
state.rcptTo = arg
|
||||
if cfg.failRcpt {
|
||||
if !write("550 5.1.1 recipient rejected") {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !write("250 2.1.5 ok") {
|
||||
return
|
||||
}
|
||||
case "DATA":
|
||||
if cfg.failData {
|
||||
if !write("554 5.5.0 data rejected") {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !write("354 end data with <CR><LF>.<CR><LF>") {
|
||||
return
|
||||
}
|
||||
inData = true
|
||||
dataBuf.Reset()
|
||||
case "QUIT":
|
||||
if cfg.failQuit {
|
||||
_ = write("554 5.5.1 quit failed")
|
||||
return
|
||||
}
|
||||
_ = write("221 2.0.0 bye")
|
||||
return
|
||||
default:
|
||||
if !write("250 2.0.0 ok") {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
tcpAddr := ln.Addr().(*net.TCPAddr)
|
||||
return &smtpTestServer{
|
||||
host: "127.0.0.1",
|
||||
port: tcpAddr.Port,
|
||||
state: state,
|
||||
stop: func() {
|
||||
_ = ln.Close()
|
||||
wg.Wait()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendUnencryptedSuccess(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeUnencrypted,
|
||||
})
|
||||
|
||||
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
if !strings.Contains(srv.state.mailFrom, "noreply@example.com") {
|
||||
t.Fatalf("expected mail from to include sender, got %q", srv.state.mailFrom)
|
||||
}
|
||||
if !strings.Contains(srv.state.rcptTo, "to@example.com") {
|
||||
t.Fatalf("expected rcpt to include recipient, got %q", srv.state.rcptTo)
|
||||
}
|
||||
if !strings.Contains(srv.state.data, "Subject: subject") {
|
||||
t.Fatalf("expected message data to include subject, got %q", srv.state.data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendUnencryptedErrorPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg smtpTestServerConfig
|
||||
wantErr string
|
||||
}{
|
||||
{name: "mail from error", cfg: smtpTestServerConfig{failMail: true}, wantErr: "smtp mail from"},
|
||||
{name: "rcpt error", cfg: smtpTestServerConfig{failRcpt: true}, wantErr: "smtp rcpt to"},
|
||||
{name: "data error", cfg: smtpTestServerConfig{failData: true}, wantErr: "smtp data"},
|
||||
{name: "quit error", cfg: smtpTestServerConfig{failQuit: true}, wantErr: "smtp quit"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, tc.cfg)
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeUnencrypted,
|
||||
})
|
||||
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
|
||||
t.Fatalf("expected %q error, got %v", tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendAuthModes(t *testing.T) {
|
||||
t.Run("auth none does not issue AUTH command", func(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN LOGIN"}})
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeUnencrypted,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
Auth: "none",
|
||||
})
|
||||
|
||||
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
if srv.state.authCalls != 0 {
|
||||
t.Fatalf("expected no AUTH command, got %d", srv.state.authCalls)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth plain", func(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH PLAIN"}})
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeUnencrypted,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
Auth: "plain",
|
||||
})
|
||||
|
||||
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
if srv.state.authCalls == 0 {
|
||||
t.Fatal("expected AUTH command for plain auth")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth login", func(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, smtpTestServerConfig{extensions: []string{"AUTH LOGIN"}})
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeUnencrypted,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
Auth: "login",
|
||||
})
|
||||
|
||||
if err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello"); err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
if srv.state.authCalls == 0 {
|
||||
t.Fatal("expected AUTH command for login auth")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth auto failure", func(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, smtpTestServerConfig{failAuth: true})
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeUnencrypted,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
Auth: "auto",
|
||||
})
|
||||
|
||||
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||
if err == nil || !strings.Contains(err.Error(), "smtp auth failed") {
|
||||
t.Fatalf("expected smtp auth failed error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported auth method", func(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeUnencrypted,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
Auth: "weird",
|
||||
})
|
||||
|
||||
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||
if err == nil || !strings.Contains(err.Error(), "unsupported smtp auth method") {
|
||||
t.Fatalf("expected unsupported auth method error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendTLSFailsWhenStartTLSExtensionMissing(t *testing.T) {
|
||||
srv := startSMTPTestServer(t, smtpTestServerConfig{})
|
||||
defer srv.stop()
|
||||
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: srv.host,
|
||||
Port: srv.port,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeTLS,
|
||||
})
|
||||
|
||||
err := mailer.Send(context.Background(), "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||
if err == nil || !strings.Contains(err.Error(), "does not support STARTTLS") {
|
||||
t.Fatalf("expected STARTTLS unsupported error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialHelpersWrapDialErrors(t *testing.T) {
|
||||
requireSMTPIntegration(t)
|
||||
|
||||
ctx := context.Background()
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen ephemeral: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
_ = ln.Close()
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("split host port: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
t.Fatalf("parse port: %v", err)
|
||||
}
|
||||
unusedAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
if _, err := dialPlain(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp plain dial") {
|
||||
t.Fatalf("expected wrapped plain dial error, got %v", err)
|
||||
}
|
||||
if _, err := dialSSL(ctx, unusedAddr, host, deadline); err == nil || !strings.Contains(err.Error(), "smtp ssl dial") {
|
||||
t.Fatalf("expected wrapped ssl dial error, got %v", err)
|
||||
}
|
||||
}
|
||||
47
smtp/smtp_mailer_test.go
Normal file
47
smtp/smtp_mailer_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSendHonorsCanceledContext(t *testing.T) {
|
||||
mailer := NewSMTPMailer(SMTPConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
From: "noreply@example.com",
|
||||
Mode: SMTPModeTLS,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := mailer.Send(ctx, "to@example.com", "subject", "<p>hello</p>", "hello")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected %v, got %v", context.Canceled, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMTPSendDeadlineUsesContextDeadline(t *testing.T) {
|
||||
now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC)
|
||||
want := now.Add(2 * time.Minute)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), want)
|
||||
defer cancel()
|
||||
|
||||
got := smtpSendDeadline(ctx, now)
|
||||
if !got.Equal(want) {
|
||||
t.Fatalf("expected context deadline %v, got %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMTPSendDeadlineUsesDefaultWhenContextHasNoDeadline(t *testing.T) {
|
||||
now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC)
|
||||
want := now.Add(defaultSMTPOperationTimeout)
|
||||
|
||||
got := smtpSendDeadline(context.Background(), now)
|
||||
if !got.Equal(want) {
|
||||
t.Fatalf("expected default deadline %v, got %v", want, got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user