295 lines
8.7 KiB
Lua
295 lines
8.7 KiB
Lua
|
-- Constructor
--
-- Builds an SNTP client object that polls a list of servers in turn.
--
-- serv: a server name (string), a table of server names, or nil.  When nil,
--       the list defaults to the DHCP-advertised NTP server (if net.ifinfo(0)
--       reports one) followed by the nodemcu pool servers.
-- sc:   success callback (required); invoked as sc(result_table, server, self)
--       at the end of a pass that produced at least one usable answer.
-- fc:   failure callback (optional); invoked as fc(err, server, self [, pkt])
--       with err one of "timeout", "dns", "goaway", "kod", or as
--       fc("all", #servers, self) when a whole pass yields nothing.
-- now:  clock function (optional); defaults to rtctime.get.  Its results are
--       handed to sntppkt.make_ts and sntppkt.proc_pkt.
--
-- Returns a table with fields `servers`, `sync()` and `stop()`.
-- Raises an error on badly typed arguments or when no clock is available.
local new = function(serv, sc, fc, now)

  if type(serv) == "string" then
    serv = {serv}
  elseif serv == nil then
    -- net.ifinfo(0) may return nil, and the interface may carry no DHCP
    -- information at all; fall back to the pool in either case rather than
    -- indexing nil.
    local ni = net.ifinfo(0)
    local dhcp = ni and ni.dhcp
    serv = {
      (dhcp and dhcp.ntp_server) or "0.nodemcu.pool.ntp.org",
      "1.nodemcu.pool.ntp.org",
      "2.nodemcu.pool.ntp.org",
    }
  elseif type(serv) ~= "table" then
    error "Bad server table"
  end

  if type(sc) ~= "function" then error "Bad success callback type" end
  if fc ~= nil and type(fc) ~= "function" then error "Bad failure callback type" end
  if now ~= nil and type(now) ~= "function" then error "Bad clock type" end

  now = now or (rtctime and rtctime.get)
  if now == nil then error "Need clock function" end

  local _self = {servers = serv}

  local _tmr -- contains the currently running timer, if any
  local _udp -- the socket we're using to talk to the world

  local _kod -- kiss of death flags accumulated across syncs
  local _pbest -- best server from prior pass

  local _res -- the best result we've got so far this pass
  local _best -- best server this pass, for updating _pbest

  local _six -- index of the server in serv to whom we are speaking
  local _sat -- number of times we've tried to reach this server

  -- Shut down the state machine
  --
  -- upvals: _tmr, _udp, _six, _sat, _res, _best, _kod
  local function _stop()
    -- stop any time-based callbacks and drop _tmr
    _tmr = _tmr and _tmr:unregister()

    _six, _sat, _res, _best = nil, nil, nil, nil

    -- stop any UDP callbacks and drop the socket; to be safe against
    -- knots tied in the registry, explicitly unregister callbacks first
    if _udp then
      _udp:on("receive", nil)
      _udp:on("sent"   , nil)
      _udp:on("dns"    , nil)
      _udp:close()
      _udp = nil
    end

    -- Count down _kod entries
    if _kod then
      -- clearing existing keys while traversing with pairs is permitted
      for k,v in pairs(_kod) do _kod[k] = (v > 0) and (v - 1) or nil end
      -- _kod is keyed by IP address strings, so the length operator would
      -- always report 0; use next() to test for real emptiness
      if next(_kod) == nil then _kod = nil end
    end
  end

  local nextServer
  local doserver

  -- Try communicating with the current server
  --
  -- upvals: now, _tmr, _udp, _best, _kod, _pbest, _res, _six
  local function hail(ip)
    _tmr:alarm(5000 --[[const param: SNTP_TIMEOUT]], tmr.ALARM_SINGLE, function()
      _udp:on("sent", nil)
      _udp:on("receive", nil)
      return doserver("timeout")
    end)

    local txts = sntppkt.make_ts(now())

    _udp:on("receive",
      -- upvals: now, ip, txts, _tmr, _best, _kod, _pbest, _res, _six
      function(skt, d, port, rxip)
        -- many things constitute bad packets; drop with tmr running
        if rxip ~= ip and ip ~= "224.0.1.1" then return end -- wrong peer (unless multicast)
        if port ~= 123 then return end -- wrong port
        if #d < 48 then return end -- too short

        local pok, pkt = pcall(sntppkt.proc_pkt, d, txts, now())

        if not pok or pkt == nil then
          -- sntppkt can also reject the packet for having a bad cookie;
          -- this is important to prevent processing spurious or delayed responses
          return
        end

        _tmr:unregister()
        skt:on("receive", nil) -- skt == _udp
        skt:on("sent", nil)

        if type(pkt) == "string" then
          if pkt == "DENY" then -- KoD packet

            if _kod and _kod[rxip] then
              -- There was already a strike against this IP address, and now
              -- it's permanent.  We can't directly remove the IP from rotation,
              -- but we can remove the DNS that's resolving to it, which isn't
              -- great, but isn't the worst either.
              if fc then fc("kod", serv[_six], _self) end
              _kod[rxip] = nil
              table.remove(serv, _six)
              _six = _six - 1 -- nextServer will add one
            else
              _kod = _kod or {}
              _kod[rxip] = 2
              if fc then fc("goaway", serv[_six], _self, pkt) end
            end
          else
            if fc then fc("goaway", serv[_six], _self, pkt) end
          end
          return nextServer()
        end

        if _pbest == serv[_six] then
          -- this was our favorite server last time; if we don't have a
          -- result or if we'd rather this one than the result we have...
          if not _res or not pkt:pick(_res, true) then
            _res = pkt
            _best = _pbest
          end
        else
          -- this was not our favorite server; take this result if we have no
          -- other option or if it compares favorably to the one we have, which
          -- might be from our favorite from last pass.
          if not _res or _res:pick(pkt, _pbest == _best) then
            _res = pkt
            _best = serv[_six]
          end
        end

        return nextServer()
      end)

    return _udp:send(123, ip,
      -- '#' == 0x23: version 4, mode 3 (client), no LI
      "#\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
      .. txts)
  end

  -- Resolve the current server and hail it, giving up on this server once
  -- the attempt budget is spent (err names why the previous attempt failed)
  --
  -- upvals: _sat, _six, _udp, hail, _self
  function doserver(err)
    if _sat == 2 --[[const param: MAX_SERVER_ATTEMPTS]] then
      if fc then fc(err, serv[_six], _self) end
      return nextServer()
    end
    _sat = _sat + 1

    return _udp:dns(serv[_six], function(skt, ip)
      skt:on("dns", nil) -- skt == _udp
      if ip == nil then return doserver("dns") else return hail(ip) end
    end)
  end

  -- Move on to the next server or finish a pass
  --
  -- upvals: fc, serv, sc, _best, _pbest, _res, _sat, _six
  function nextServer()
    if _six >= #serv then
      if _res then
        _pbest = _best
        -- _stop() clears _res and _best, so capture them first
        local res = _res
        local best = _best
        _stop()
        return sc(res:totable(), best, _self)
      else
        _stop()
        if fc then return fc("all", #serv, _self) else return end
      end
    end

    _six = _six + 1
    _sat = 0
    return doserver()
  end

  -- Poke all the servers and invoke the user's callbacks
  --
  -- upvals: _stop, _udp, _ENV, _tmr, _six, nextServer
  function _self.sync()
    _stop()
    _udp = net.createUDPSocket()
    _tmr = tmr.create()
    _udp:listen() -- on random port
    _six = 0
    nextServer()
  end

  -- Abort any pass in progress; returns the best result gathered so far (as
  -- a table, or nil if there is none) and the server it came from
  function _self.stop()
    local res, best = _res, _best
    _stop()
    return res and res:totable(), best
  end

  return _self

end
|
||
|
|
||
|
-- A utility function which applies a result to the rtc
--
-- res: a sync result table; this function reads res.raw, res.theta_s and
--      res.theta_us.
-- obj: any table used to persist state between calls; this function reads
--      and writes obj.rtc_last and obj.rtc_err_int.
--
-- Returns true when only the clock rate was adjusted, false when the clock
-- itself was stepped.
local update_rtc = function(res, obj)
  local rate = nil
  if obj.rtc_last ~= nil then
    -- adjust drift compensation.  We have three pieces of information:
    --
    --   our idea of time at rx (res.rx_*),
    --   our idea of time at the last sync (obj.rtc_last.rx_*)
    --   the measured theta now (res.theta_us)
    --
    -- We're going to integrate the theta signal over time and use
    -- that to mediate the rate we set, making this a PI controller,
    -- but we might take big steps if theta gets too bad.
    local ok, err_int
    local raw = res.raw
    ok, rate, err_int = pcall(raw.drift_compensate, raw, obj.rtc_last,
                              obj.rtc_err_int or 0)
    if not ok then
      rate = nil          -- don't set the rate this time
      obj.rtc_last = nil  -- or next time
    else
      obj.rtc_last = res.raw
      obj.rtc_err_int = err_int
    end
  else
    obj.rtc_last = res.raw
  end

  if rate == nil then
    -- update time (and cut rate, in case it's gotten out of hand)
    local now_s, now_us, now_r = rtctime.get()
    local new_s, new_us = now_s + res.theta_s, now_us + res.theta_us
    -- normalise the microsecond field into [0, 1000000); exactly 1000000
    -- must carry into the seconds too
    if new_us >= 1000000 then
      new_s  = new_s + 1
      new_us = new_us - 1000000
    elseif new_us < 0 then
      -- defensive: whether theta_us can be negative is not visible from
      -- here, but a negative microsecond count is never a valid setting
      new_s  = new_s - 1
      new_us = new_us + 1000000
    end
    rtctime.set(new_s, new_us, now_r / 2)
  else
    -- just change the rate
    rtctime.set(nil, nil, rate)
  end

  return rate ~= nil
end
|
||
|
|
||
|
-- Default operation
--
-- Creates an SNTP object for `servs`, starts a repeating timer that
-- re-syncs it, and kicks off the first sync immediately.  Each successful
-- sync is applied to the rtc before the optional user callback `sc` runs;
-- `fc` is handed to the constructor unchanged.  `period` is the polling
-- interval used once the rate estimate has settled (1800000 if omitted).
-- Returns the SNTP object, whose `tmr` field holds the polling timer.
--
-- upvals: new, update_rtc
local go = function(servs, period, sc, fc)

  -- Success hook: manage the rtc and the polling frequency, then forward
  -- to the user's callback.
  local function on_success(res, serv, self)
    local settled = update_rtc(res, self)

    if settled and (self.rtc_burst or 0) == 0 then
      -- the rate estimator has things under control: poll only occasionally
      self.tmr:interval(period or 1800000)
      self.rtc_burst = nil
    else
      -- otherwise bother the server more frequently, in a "bursty" way
      self.tmr:interval(30000)
      local remaining = 40
      if settled and self.rtc_burst then
        remaining = self.rtc_burst
      end
      self.rtc_burst = remaining - 1
    end

    -- invoke the user's callback
    if sc then
      return sc(res, serv, self)
    end
  end

  local sntpobj = new(servs, on_success, fc)

  local poller = tmr.create()
  sntpobj.tmr = poller
  poller:alarm(60000, tmr.ALARM_AUTO, function()
    collectgarbage()
    sntpobj.sync()
  end)

  sntpobj.sync()

  return sntpobj
end
|
||
|
|
||
|
-- from sntppkt
--
-- Only the first string is actually bound to the local; the remaining
-- expressions are evaluated and discarded.  Presumably the point is just to
-- get these field names into the compiled chunk's constants -- confirm
-- against the build tooling before touching this line.
-- luacheck: ignore
local _lfs_strings = "theta_s", "theta_us", "delta", "delta_r", "epsilon_r",
  "leapind", "stratum", "rx_s", "rx_us"

-- Module exports
local exports = {}
exports.new = new
exports.go = go
exports.update_rtc = update_rtc
return exports
|