Compare commits

..

5 Commits

Author SHA1 Message Date
Alexandre Almeida
eec1c11265 Remove fmt.Println from provider method 2025-07-23 14:14:03 +02:00
Alexandre Almeida
a45084a707 Query API for record ID if wrapped struct doesn't have it 2025-07-23 13:18:27 +02:00
Alexandre Almeida
44cf557ad2 Simplify if block in fromAPIRecord 2025-06-07 15:31:25 +02:00
Alexandre Almeida
08474842cc Fix parsing of MX and SRV records to use Vultr's priority field 2025-06-07 15:27:01 +02:00
Alexandre Almeida
62cb30921f Rewrite logic for libdns v1.0.0, do not use ProviderData field 2025-06-07 14:55:54 +02:00
3 changed files with 129 additions and 198 deletions

View File

@ -7,8 +7,7 @@ import (
"time"
"github.com/libdns/libdns"
"github.com/libdns/vultr"
"github.com/libdns/vultr/v2"
)
func main() {
@ -22,6 +21,7 @@ func main() {
fmt.Printf("ZONE not set\n")
return
}
shouldCleanup := os.Getenv("DELETE_RECORDS") == "true"
provider := vultr.Provider{APIToken: token}
@ -44,42 +44,46 @@ func main() {
for _, record := range records {
fmt.Printf("%s (.%s): %s, %s\n", record.RR().Name, zone, record.RR().Data, record.RR().Type)
recordId, err := vultr.GetRecordID(record)
if err != nil {
fmt.Printf("ERROR: %s\n", err.Error())
}
if record.RR().Name == testName {
testId = recordId
testId = record.(vultr.VultrRecord).ID
}
}
if testId != "" {
// fmt.Printf("Delete entry for %s (id:%s)\n", testName, testId)
// _, err = provider.DeleteRecords(context.TODO(), zone, []libdns.Record{libdns.Record{
// ID: testId,
// }})
// if err != nil {
// fmt.Printf("ERROR: %s\n", err.Error())
// }
// Set only works if we have a record.ID
fmt.Printf("Replacing entry for %s\n", testName)
_, err = provider.SetRecords(context.TODO(), zone, []libdns.Record{libdns.TXT{
Name: testName,
Text: fmt.Sprintf("Replacement test entry created by libdns %s", time.Now()),
TTL: time.Duration(30) * time.Second,
ProviderData: testId,
if shouldCleanup {
fmt.Printf("Delete entry for %s (id:%s)\n", testName, testId)
_, err = provider.DeleteRecords(context.TODO(), zone, []libdns.Record{vultr.VultrRecord{
ID: testId,
}})
if err != nil {
fmt.Printf("ERROR: %s\n", err.Error())
}
} else {
// Set only works if we have a record.ID
fmt.Printf("Replacing entry for %s\n", testName)
_, err = provider.SetRecords(context.TODO(), zone, []libdns.Record{vultr.VultrRecord{
Record: libdns.RR{
Name: testName,
Type: "TXT",
Data: fmt.Sprintf("Replacement test entry created by libdns %s", time.Now()),
TTL: time.Duration(90) * time.Second,
},
ID: testId,
}})
if err != nil {
fmt.Printf("ERROR: %s\n", err.Error())
}
}
} else {
fmt.Printf("Creating new entry for %s\n", testName)
_, err = provider.AppendRecords(context.TODO(), zone, []libdns.Record{libdns.RR{
_, err = provider.AppendRecords(context.TODO(), zone, []libdns.Record{vultr.VultrRecord{
Record: libdns.RR{
Type: "TXT",
Name: testName,
Data: fmt.Sprintf("This is a test entry created by libdns %s", time.Now()),
TTL: time.Duration(30) * time.Second,
TTL: time.Duration(60) * time.Second,
},
ID: testId,
}})
if err != nil {
fmt.Printf("ERROR: %s\n", err.Error())

View File

@ -2,6 +2,7 @@ package vultr
import (
"context"
"fmt"
"sync"
"golang.org/x/oauth2"
@ -27,9 +28,6 @@ func (p *Provider) getClient() error {
}
func (p *Provider) getDNSEntries(ctx context.Context, domain string) ([]libdns.Record, error) {
p.client.mutex.Lock()
defer p.client.mutex.Unlock()
p.getClient()
listOptions := &govultr.ListOptions{}
@ -42,11 +40,7 @@ func (p *Provider) getDNSEntries(ctx context.Context, domain string) ([]libdns.R
}
for _, entry := range dns_entries {
record, err := libdnsRecord(entry, domain)
if err != nil {
return records, err
}
record := fromAPIRecord(entry, domain)
records = append(records, record)
}
@ -66,22 +60,14 @@ func (p *Provider) addDNSRecord(ctx context.Context, domain string, r libdns.Rec
p.getClient()
rr := r.RR()
domainRecordReq := toDomainRecordReq(r)
domainRecordReq, err := vultrRecordReq(rr)
rec, _, err := p.client.vultr.DomainRecord.Create(ctx, domain, &domainRecordReq)
if err != nil {
return r, err
}
rec, _, err := p.client.vultr.DomainRecord.Create(ctx, domain, &domainRecordReq)
if err != nil {
return nil, err
}
record, err := libdnsRecord(*rec, domain)
if err != nil {
return nil, err
}
record := fromLibdnsRecord(r, rec.ID)
return record, nil
}
@ -92,9 +78,19 @@ func (p *Provider) removeDNSRecord(ctx context.Context, domain string, record li
p.getClient()
recordId, err := GetRecordID(record)
recordId, err := getRecordId(record)
if err != nil {
return record, err
// try to get the ID from API if we don't have it
records, err := p.getDNSEntries(ctx, domain)
if err != nil {
return record, fmt.Errorf("could not get record ID from API")
}
for _, rec := range records {
if rec.RR().Name == record.RR().Name {
recordId = rec.(VultrRecord).ID
}
}
}
err = p.client.vultr.DomainRecord.Delete(ctx, domain, recordId)
@ -111,15 +107,22 @@ func (p *Provider) updateDNSRecord(ctx context.Context, domain string, record li
p.getClient()
recordId, err := GetRecordID(record)
recordId, err := getRecordId(record)
if err != nil {
return record, err
// try to get the ID from API if we don't have it
records, err := p.getDNSEntries(ctx, domain)
if err != nil {
return record, fmt.Errorf("could not get record ID from API")
}
domainRecordReq, err := vultrRecordReq(record)
if err != nil {
return nil, err
for _, rec := range records {
if rec.RR().Data == record.RR().Data {
recordId = rec.(VultrRecord).ID
}
}
}
domainRecordReq := toDomainRecordReq(record)
err = p.client.vultr.DomainRecord.Update(ctx, domain, recordId, &domainRecordReq)
if err != nil {

View File

@ -2,8 +2,6 @@ package vultr
import (
"fmt"
"net/netip"
"strconv"
"strings"
"time"
@ -11,158 +9,84 @@ import (
"github.com/vultr/govultr/v3"
)
// Converts `govultr.DomainRecord` to `libdns.Record“
// Taken from libdns/cloudflare, adapted for Vultr's specific format
func libdnsRecord(r govultr.DomainRecord, zone string) (libdns.Record, error) {
type VultrRecord struct {
Record libdns.RR
ID string
}
func (r VultrRecord) RR() libdns.RR {
return r.Record
}
// Converts a govultr.DomainRecord to libdns.Record
// Taken from libdns/digitalocean
func fromAPIRecord(r govultr.DomainRecord, zone string) VultrRecord {
name := libdns.RelativeName(r.Name, zone)
ttl := time.Duration(r.TTL) * time.Second
switch r.Type {
case "A", "AAAA":
addr, err := netip.ParseAddr(r.Data)
if err != nil {
return libdns.Address{}, fmt.Errorf("invalid IP address %q: %v", r.Data, err)
// Vultr uses a custom priority field for MX and SRV records
data := r.Data
if r.Type == "MX" || r.Type == "SRV" {
data = fmt.Sprintf("%d %s", r.Priority, r.Data)
}
return libdns.Address{
Name: name,
TTL: ttl,
IP: addr,
ProviderData: r.ID,
}, nil
case "CAA":
dataParts := strings.SplitN(r.Data, " ", 3)
if len(dataParts) < 3 {
return libdns.SRV{}, fmt.Errorf("record %v does not contain enough data fields; expected format: '<flags> <tag> <value>'", name)
}
flags, err := strconv.Atoi(dataParts[0])
if err != nil {
return libdns.SRV{}, fmt.Errorf("record %v contains invalid value for flags: %v", name, err)
}
return libdns.CAA{
Name: name,
TTL: ttl,
Flags: uint8(flags),
Tag: dataParts[1],
Value: dataParts[2],
ProviderData: r.ID,
}, nil
case "CNAME":
return libdns.CNAME{
Name: name,
TTL: ttl,
Target: r.Data,
ProviderData: r.ID,
}, nil
case "MX":
return libdns.MX{
Name: name,
TTL: ttl,
Preference: uint16(r.Priority),
Target: r.Data,
ProviderData: r.ID,
}, nil
case "NS":
return libdns.NS{
Name: name,
TTL: ttl,
Target: r.Data,
ProviderData: r.ID,
}, nil
case "SRV":
// Vultr doesn't append the zone to the SRV record name, so we just need
// to parse 2 parts
parts := strings.SplitN(r.Name, ".", 2)
if len(parts) < 2 {
return libdns.SRV{}, fmt.Errorf("name %v does not contain enough fields; expected format: '_service._proto.name'", name)
}
dataParts := strings.SplitN(r.Data, " ", 3)
if len(dataParts) < 3 {
return libdns.SRV{}, fmt.Errorf("record %v does not contain enough data fields; expected format: 'weight port target'", name)
}
weight, err := strconv.Atoi(dataParts[0])
if err != nil {
return libdns.SRV{}, fmt.Errorf("record %v contains invalid value for weight: %v", name, err)
}
port, err := strconv.Atoi(dataParts[1])
if err != nil {
return libdns.SRV{}, fmt.Errorf("record %v contains invalid value for port: %v", name, err)
}
return libdns.SRV{
Service: strings.TrimPrefix(parts[0], "_"),
Transport: strings.TrimPrefix(parts[1], "_"),
Name: zone,
TTL: ttl,
Priority: uint16(r.Priority),
Weight: uint16(weight),
Port: uint16(port),
Target: dataParts[2],
ProviderData: r.ID,
}, nil
case "TXT":
return libdns.TXT{
Name: name,
TTL: ttl,
Text: r.Data,
ProviderData: r.ID,
}, nil
default:
return libdns.RR{
return VultrRecord{
Record: libdns.RR{
Name: name,
TTL: ttl,
Type: r.Type,
Data: r.Data,
}.Parse()
Data: data,
},
ID: r.ID,
}
}
// Converts `libdns.Record` to `govultr.DomainRecordReq`, to be used with API
// requests.
func vultrRecordReq(r libdns.Record) (govultr.DomainRecordReq, error) {
// Converts a libdns.Record to VultrRecord with an optional ID
func fromLibdnsRecord(r libdns.Record, id string) VultrRecord {
rr := r.RR()
return VultrRecord{
Record: rr,
ID: id,
}
}
// Converts a libdns.Record to a govultr.DomainRecordReq
func toDomainRecordReq(r libdns.Record) govultr.DomainRecordReq {
data := r.RR().Data
var priority int
// Vultr uses a custom priority field for MX and SRV records
if rec, ok := r.RR().Parse(); ok == nil {
if r.RR().Type == "MX" {
mx := rec.(libdns.MX)
priority = int(mx.Preference)
data = mx.Target
} else if r.RR().Type == "SRV" {
srv := rec.(libdns.SRV)
priority = int(srv.Priority)
data = data[strings.Index(data, " ")+1:]
}
}
rr := r.RR()
return govultr.DomainRecordReq{
Name: r.RR().Name,
Type: r.RR().Type,
TTL: int(r.RR().TTL.Seconds()),
Data: r.RR().Data,
}, nil
Name: rr.Name,
Type: rr.Type,
TTL: int(rr.TTL.Seconds()),
Data: data,
Priority: &priority,
}
}
func GetRecordID(r libdns.Record) (string, error) {
var recordId string
switch r := r.(type) {
case libdns.Address:
recordId = r.ProviderData.(string)
return recordId, nil
case libdns.CAA:
recordId = r.ProviderData.(string)
return recordId, nil
case libdns.CNAME:
recordId = r.ProviderData.(string)
return recordId, nil
case libdns.MX:
recordId = r.ProviderData.(string)
return recordId, nil
case libdns.NS:
recordId = r.ProviderData.(string)
return recordId, nil
case libdns.SRV:
recordId = r.ProviderData.(string)
return recordId, nil
case libdns.ServiceBinding:
recordId = r.ProviderData.(string)
return recordId, nil
case libdns.TXT:
recordId = r.ProviderData.(string)
return recordId, nil
default:
func getRecordId(r libdns.Record) (string, error) {
var id string
if vr, err := r.(VultrRecord); err {
id = vr.ID
}
return "", fmt.Errorf("libdns record has no provider record ID")
if id == "" {
return "", fmt.Errorf("record has no ID: %v", r)
}
return id, nil
}