Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions cmd/ddns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@ import (
"github.com/orvice/ddns/internal/wire"
)

var (
IPNotifyFormat = "[%s] ip changed, old IP: %s new IP: %s"
)

func main() {
app, err := wire.NewApp()
if err != nil {
slog.Error("init app error", "error", err)
os.Exit(1)
}
app.Run(context.Background())
if err := app.Run(context.Background()); err != nil {
slog.Error("app run error", "error", err)
os.Exit(1)
}
}
25 changes: 17 additions & 8 deletions dns/dns.go
Original file line number Diff line number Diff line change
@@ -1,35 +1,44 @@
package dns

import (
"fmt"

"github.com/libdns/alidns"
"github.com/libdns/cloudflare"
"github.com/libdns/libdns"
"github.com/orvice/ddns/internal/config"
)

const (
ProviderCloudflare = "cloudflare"
ProviderAliyun = "aliyun"
)

type LibDNS interface {
libdns.RecordGetter
libdns.RecordAppender
libdns.RecordSetter
}

func New(conf *config.Config) LibDNS {
// New returns the DNS provider selected by conf.DNSProvider. An empty value
// defaults to Cloudflare; an unknown value returns an error so misconfiguration
// fails loudly.
func New(conf *config.Config) (LibDNS, error) {
switch conf.DNSProvider {
case "cloudflare":
return NewCloudFlare(conf)
case "aliyun":
return NewAliyun(conf)
case "", ProviderCloudflare:
return NewCloudFlare(conf), nil
case ProviderAliyun:
return NewAliyun(conf), nil
default:
return nil, fmt.Errorf("unknown dns provider: %q", conf.DNSProvider)
}
return NewCloudFlare(conf)
}

// cloudflare
func NewCloudFlare(conf *config.Config) LibDNS {
provider := cloudflare.Provider{APIToken: conf.CFToken}
return &provider
}

// aliyun
func NewAliyun(conf *config.Config) LibDNS {
provider := alidns.Provider{
AccKeyID: conf.AliyunAccessKeyID,
Expand Down
21 changes: 21 additions & 0 deletions dns/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/libdns/cloudflare"
"github.com/orvice/ddns/internal/config"
)

func TestCloudFlare(t *testing.T) {
Expand All @@ -17,3 +18,23 @@ func TestCloudFlare(t *testing.T) {
}
t.Log(records)
}

func TestNew_DefaultsToCloudflare(t *testing.T) {
d, err := New(&config.Config{})
if err != nil || d == nil {
t.Fatalf("expected default cloudflare provider, got err=%v d=%v", err, d)
}
}

func TestNew_Aliyun(t *testing.T) {
d, err := New(&config.Config{DNSProvider: "aliyun"})
if err != nil || d == nil {
t.Fatalf("expected aliyun provider, got err=%v d=%v", err, d)
}
}

func TestNew_UnknownProvider(t *testing.T) {
if _, err := New(&config.Config{DNSProvider: "nope"}); err == nil {
t.Fatal("expected error for unknown provider")
}
}
118 changes: 67 additions & 51 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"os"
"strings"
"time"

Expand All @@ -17,6 +16,8 @@ import (

var (
IPNotifyFormat = "[%s] ip changed, old IP: %s new IP: %s"

DefaultInterval = 3 * time.Minute
)

type App struct {
Expand All @@ -25,6 +26,7 @@ type App struct {
dnsProvider dns.LibDNS
ipGetter ip.Getter
notifier notify.Notifier
interval time.Duration
}

func New(logger *slog.Logger, config *config.Config, dnsProvider dns.LibDNS, ipGetter ip.Getter, notifier notify.Notifier) *App {
Expand All @@ -34,85 +36,99 @@ func New(logger *slog.Logger, config *config.Config, dnsProvider dns.LibDNS, ipG
dnsProvider: dnsProvider,
ipGetter: ipGetter,
notifier: notifier,
interval: DefaultInterval,
}
}

func (a *App) Run(ctx context.Context) {
// WithInterval overrides the polling interval. Useful for tests.
func (a *App) WithInterval(d time.Duration) *App {
a.interval = d
return a
}

// Run polls and syncs until ctx is cancelled or an unrecoverable error occurs.
// It returns the cause; nil if ctx was cancelled cleanly.
func (a *App) Run(ctx context.Context) error {
ticker := time.NewTicker(a.interval)
defer ticker.Stop()

if err := a.UpdateIP(ctx); err != nil {
return err
}
for {
select {
case <-ctx.Done():
return
default:
err := a.updateIP(ctx)
if err != nil {
a.logger.Error("update ip error", "error", err.Error())
os.Exit(1)
return nil
case <-ticker.C:
if err := a.UpdateIP(ctx); err != nil {
return err
}
time.Sleep(time.Minute * 3)
}
}
}

func (a *App) updateIP(ctx context.Context) error {
ip, err := a.ipGetter.GetIP()
// UpdateIP performs one sync cycle: fetch current IP, look up the existing
// record, and create or update it as needed.
func (a *App) UpdateIP(ctx context.Context) error {
currentIP, err := a.ipGetter.GetIP()
if err != nil {
a.logger.Error("Get ip error", "error", err)
return err
}

name, zone := zoneFromDomain(a.config.Domain)
a.logger.Info("zone from domain",
"name", name,
"zone", zone)
a.logger.Info("zone from domain", "name", name, "zone", zone)

currentIP, err := a.dnsProvider.GetRecords(ctx, zone)
records, err := a.dnsProvider.GetRecords(ctx, zone)
if err != nil {
a.logger.Error("Get records error", "error", err)
return err
}

var found bool
var record *libdns.Record
for _, r := range currentIP {
if r.Name == name {
found = true
record = &r
break
}
record, found := findRecord(records, name)
if !found {
return a.appendRecord(ctx, zone, name, currentIP)
}
if found {
if record.Value == ip {
a.logger.Info("ip is same, skip update", "ip", ip)
return nil
}
oldIP := record.Value
record.Value = ip
_, err = a.dnsProvider.SetRecords(ctx, zone, []libdns.Record{
*record,
})
if err != nil {
a.logger.Error("Set records error", "error", err)
return err
}
_ = a.notifier.Send(ctx, fmt.Sprintf(IPNotifyFormat, a.config.Domain, oldIP, ip))
} else {
_, err = a.dnsProvider.AppendRecords(ctx, zone, []libdns.Record{
{
Name: name,
Value: ip,
Type: "A",
},
})
if err != nil {
a.logger.Error("Append records error", "error", err)
return err
}
if record.Value == currentIP {
a.logger.Info("ip is same, skip update", "ip", currentIP)
return nil
}
return a.updateRecord(ctx, zone, record, currentIP)
}

func (a *App) appendRecord(ctx context.Context, zone, name, ip string) error {
_, err := a.dnsProvider.AppendRecords(ctx, zone, []libdns.Record{
{Name: name, Value: ip, Type: "A"},
})
if err != nil {
a.logger.Error("Append records error", "error", err)
}
return err
}

func (a *App) updateRecord(ctx context.Context, zone string, record libdns.Record, newIP string) error {
oldIP := record.Value
record.Value = newIP
if _, err := a.dnsProvider.SetRecords(ctx, zone, []libdns.Record{record}); err != nil {
a.logger.Error("Set records error", "error", err)
return err
}
if a.notifier != nil {
_ = a.notifier.Send(ctx, fmt.Sprintf(IPNotifyFormat, a.config.Domain, oldIP, newIP))
}
return nil
}

// zoneFromDomain return zone and domain
func findRecord(records []libdns.Record, name string) (libdns.Record, bool) {
for _, r := range records {
if r.Name == name {
return r, true
}
}
return libdns.Record{}, false
}

// zoneFromDomain return name and zone.
func zoneFromDomain(domain string) (string, string) {
arr := strings.SplitN(domain, ".", 2)
if len(arr) == 1 {
Expand Down
Loading
Loading