diff --git a/_example/main.go b/_example/main.go index 6bdb0e7..e8e3df8 100644 --- a/_example/main.go +++ b/_example/main.go @@ -7,8 +7,7 @@ import ( "time" "github.com/libdns/libdns" - - "github.com/libdns/vultr" + "github.com/libdns/vultr/v2" ) func main() { @@ -22,6 +21,7 @@ func main() { fmt.Printf("ZONE not set\n") return } + shouldCleanup := os.Getenv("DELETE_RECORDS") == "true" provider := vultr.Provider{APIToken: token} @@ -44,42 +44,46 @@ func main() { for _, record := range records { fmt.Printf("%s (.%s): %s, %s\n", record.RR().Name, zone, record.RR().Data, record.RR().Type) - recordId, err := vultr.GetRecordID(record) - if err != nil { - fmt.Printf("ERROR: %s\n", err.Error()) - } - if record.RR().Name == testName { - testId = recordId + testId = record.(vultr.VultrRecord).ID } } if testId != "" { - // fmt.Printf("Delete entry for %s (id:%s)\n", testName, testId) - // _, err = provider.DeleteRecords(context.TODO(), zone, []libdns.Record{libdns.Record{ - // ID: testId, - // }}) - // if err != nil { - // fmt.Printf("ERROR: %s\n", err.Error()) - // } - // Set only works if we have a record.ID - fmt.Printf("Replacing entry for %s\n", testName) - _, err = provider.SetRecords(context.TODO(), zone, []libdns.Record{libdns.TXT{ - Name: testName, - Text: fmt.Sprintf("Replacement test entry created by libdns %s", time.Now()), - TTL: time.Duration(30) * time.Second, - ProviderData: testId, - }}) - if err != nil { - fmt.Printf("ERROR: %s\n", err.Error()) + if shouldCleanup { + fmt.Printf("Delete entry for %s (id:%s)\n", testName, testId) + _, err = provider.DeleteRecords(context.TODO(), zone, []libdns.Record{vultr.VultrRecord{ + ID: testId, + }}) + if err != nil { + fmt.Printf("ERROR: %s\n", err.Error()) + } + } else { + // Set only works if we have a record.ID + fmt.Printf("Replacing entry for %s\n", testName) + _, err = provider.SetRecords(context.TODO(), zone, []libdns.Record{vultr.VultrRecord{ + Record: libdns.RR{ + Name: testName, + Type: "TXT", + Data: fmt.Sprintf("Replacement test entry created by libdns %s", time.Now()), + TTL: time.Duration(90) * time.Second, + }, + ID: testId, + }}) + if err != nil { + fmt.Printf("ERROR: %s\n", err.Error()) + } } } else { fmt.Printf("Creating new entry for %s\n", testName) - _, err = provider.AppendRecords(context.TODO(), zone, []libdns.Record{libdns.RR{ - Type: "TXT", - Name: testName, - Data: fmt.Sprintf("This is a test entry created by libdns %s", time.Now()), - TTL: time.Duration(30) * time.Second, + _, err = provider.AppendRecords(context.TODO(), zone, []libdns.Record{vultr.VultrRecord{ + Record: libdns.RR{ + Type: "TXT", + Name: testName, + Data: fmt.Sprintf("This is a test entry created by libdns %s", time.Now()), + TTL: time.Duration(60) * time.Second, + }, + ID: testId, }}) if err != nil { fmt.Printf("ERROR: %s\n", err.Error()) diff --git a/client.go b/client.go index 9469731..fdad71e 100644 --- a/client.go +++ b/client.go @@ -42,11 +42,7 @@ func (p *Provider) getDNSEntries(ctx context.Context, domain string) ([]libdns.R } for _, entry := range dns_entries { - record, err := libdnsRecord(entry, domain) - if err != nil { - return records, err - } - + record := fromAPIRecord(entry, domain) records = append(records, record) } @@ -66,22 +62,14 @@ func (p *Provider) addDNSRecord(ctx context.Context, domain string, r libdns.Rec p.getClient() - rr := r.RR() + domainRecordReq := toDomainRecordReq(r) - domainRecordReq, err := vultrRecordReq(rr) + rec, _, err := p.client.vultr.DomainRecord.Create(ctx, domain, &domainRecordReq) if err != nil { return r, err } - rec, _, err := p.client.vultr.DomainRecord.Create(ctx, domain, &domainRecordReq) - if err != nil { - return nil, err - } - - record, err := libdnsRecord(*rec, domain) - if err != nil { - return nil, err - } + record := fromLibdnsRecord(r, rec.ID) return record, nil } @@ -92,7 +80,7 @@ func (p *Provider) removeDNSRecord(ctx context.Context, domain string, record li p.getClient() - recordId, err := GetRecordID(record) + recordId, err := getRecordId(record) if err != nil { return record, err } @@ -111,15 +99,12 @@ func (p *Provider) updateDNSRecord(ctx context.Context, domain string, record li p.getClient() - recordId, err := GetRecordID(record) + recordId, err := getRecordId(record) if err != nil { return record, err } - domainRecordReq, err := vultrRecordReq(record) - if err != nil { - return nil, err - } + domainRecordReq := toDomainRecordReq(record) err = p.client.vultr.DomainRecord.Update(ctx, domain, recordId, &domainRecordReq) if err != nil { diff --git a/helpers.go b/helpers.go index d5afb67..dec622e 100644 --- a/helpers.go +++ b/helpers.go @@ -2,167 +2,73 @@ package vultr import ( "fmt" - "net/netip" - "strconv" - "strings" "time" "github.com/libdns/libdns" "github.com/vultr/govultr/v3" ) -// Converts `govultr.DomainRecord` to `libdns.Record“ -// Taken from libdns/cloudflare, adapted for Vultr's specific format -func libdnsRecord(r govultr.DomainRecord, zone string) (libdns.Record, error) { +type VultrRecord struct { + Record libdns.RR + ID string +} + +func (r VultrRecord) RR() libdns.RR { + return r.Record +} + +// Converts a govultr.DomainRecord to libdns.Record +// Taken from libdns/digitalocean +func fromAPIRecord(r govultr.DomainRecord, zone string) VultrRecord { name := libdns.RelativeName(r.Name, zone) ttl := time.Duration(r.TTL) * time.Second - switch r.Type { - case "A", "AAAA": - addr, err := netip.ParseAddr(r.Data) - if err != nil { - return libdns.Address{}, fmt.Errorf("invalid IP address %q: %v", r.Data, err) - } + // Vultr uses a custom priority field for MX records + data := r.Data + if r.Type == "MX" { + data = fmt.Sprintf("%d %s", r.Priority, r.Data) + } - return libdns.Address{ - Name: name, - TTL: ttl, - IP: addr, - ProviderData: r.ID, - }, nil - case "CAA": - dataParts := strings.SplitN(r.Data, " ", 3) - if len(dataParts) < 3 { - return libdns.SRV{}, fmt.Errorf("record %v does not contain enough data fields; expected format: ' '", name) - } - - flags, err := strconv.Atoi(dataParts[0]) - if err != nil { - return libdns.SRV{}, fmt.Errorf("record %v contains invalid value for flags: %v", name, err) - } - - return libdns.CAA{ - Name: name, - TTL: ttl, - Flags: uint8(flags), - Tag: dataParts[1], - Value: dataParts[2], - ProviderData: r.ID, - }, nil - case "CNAME": - return libdns.CNAME{ - Name: name, - TTL: ttl, - Target: r.Data, - ProviderData: r.ID, - }, nil - case "MX": - return libdns.MX{ - Name: name, - TTL: ttl, - Preference: uint16(r.Priority), - Target: r.Data, - ProviderData: r.ID, - }, nil - case "NS": - return libdns.NS{ - Name: name, - TTL: ttl, - Target: r.Data, - ProviderData: r.ID, - }, nil - case "SRV": - // Vultr doesn't append the zone to the SRV record name, so we just need - // to parse 2 parts - parts := strings.SplitN(r.Name, ".", 2) - if len(parts) < 2 { - return libdns.SRV{}, fmt.Errorf("name %v does not contain enough fields; expected format: '_service._proto.name'", name) - } - - dataParts := strings.SplitN(r.Data, " ", 3) - if len(dataParts) < 3 { - return libdns.SRV{}, fmt.Errorf("record %v does not contain enough data fields; expected format: 'weight port target'", name) - } - - weight, err := strconv.Atoi(dataParts[0]) - if err != nil { - return libdns.SRV{}, fmt.Errorf("record %v contains invalid value for weight: %v", name, err) - } - - port, err := strconv.Atoi(dataParts[1]) - if err != nil { - return libdns.SRV{}, fmt.Errorf("record %v contains invalid value for port: %v", name, err) - } - - return libdns.SRV{ - Service: strings.TrimPrefix(parts[0], "_"), - Transport: strings.TrimPrefix(parts[1], "_"), - Name: zone, - TTL: ttl, - Priority: uint16(r.Priority), - Weight: uint16(weight), - Port: uint16(port), - Target: dataParts[2], - ProviderData: r.ID, - }, nil - case "TXT": - return libdns.TXT{ - Name: name, - TTL: ttl, - Text: r.Data, - ProviderData: r.ID, - }, nil - default: - return libdns.RR{ + return VultrRecord{ + Record: libdns.RR{ Name: name, TTL: ttl, Type: r.Type, - Data: r.Data, - }.Parse() + Data: data, + }, + ID: r.ID, } } -// Converts `libdns.Record` to `govultr.DomainRecordReq`, to be used with API -// requests. -func vultrRecordReq(r libdns.Record) (govultr.DomainRecordReq, error) { +// Converts a libdns.Record to VultrRecord with an optional ID +func fromLibdnsRecord(r libdns.Record, id string) VultrRecord { + rr := r.RR() + return VultrRecord{ + Record: rr, + ID: id, + } +} + +// Converts a libdns.Record to a govultr.DomainRecordReq +func toDomainRecordReq(r libdns.Record) govultr.DomainRecordReq { + rr := r.RR() return govultr.DomainRecordReq{ - Name: r.RR().Name, - Type: r.RR().Type, - TTL: int(r.RR().TTL.Seconds()), - Data: r.RR().Data, - }, nil + Name: rr.Name, + Type: rr.Type, + TTL: int(rr.TTL.Seconds()), + Data: rr.Data, + } } -func GetRecordID(r libdns.Record) (string, error) { - var recordId string - - switch r := r.(type) { - case libdns.Address: - recordId = r.ProviderData.(string) - return recordId, nil - case libdns.CAA: - recordId = r.ProviderData.(string) - return recordId, nil - case libdns.CNAME: - recordId = r.ProviderData.(string) - return recordId, nil - case libdns.MX: - recordId = r.ProviderData.(string) - return recordId, nil - case libdns.NS: - recordId = r.ProviderData.(string) - return recordId, nil - case libdns.SRV: - recordId = r.ProviderData.(string) - return recordId, nil - case libdns.ServiceBinding: - recordId = r.ProviderData.(string) - return recordId, nil - case libdns.TXT: - recordId = r.ProviderData.(string) - return recordId, nil - default: +func getRecordId(r libdns.Record) (string, error) { + var id string + if vr, err := r.(VultrRecord); err { + id = vr.ID } - return "", fmt.Errorf("libdns record has no provider record ID") + if id == "" { + return "", fmt.Errorf("record has no ID: %v", r) + } + + return id, nil }