Kalman filter

This commit is contained in:
DTTerastar 2023-11-09 16:24:17 -05:00
parent 72d786c148
commit e1b1d5f908
3 changed files with 70 additions and 73 deletions

View File

@ -16,7 +16,7 @@ class ClientCallbacks : public BLEClientCallbacks {
static ClientCallbacks clientCB; static ClientCallbacks clientCB;
BleFingerprint::BleFingerprint(BLEAdvertisedDevice *advertisedDevice, float fcmin, float beta, float dcutoff) : filteredDistance{FilteredDistance(fcmin, beta, dcutoff)} { BleFingerprint::BleFingerprint(BLEAdvertisedDevice *advertisedDevice, float fcmin, float beta, float dcutoff) : filteredDistance{FilteredDistance(25, 0.1)} {
firstSeenMillis = millis(); firstSeenMillis = millis();
address = NimBLEAddress(advertisedDevice->getAddress()); address = NimBLEAddress(advertisedDevice->getAddress());
addressType = advertisedDevice->getAddressType(); addressType = advertisedDevice->getAddressType();

View File

@ -1,64 +1,73 @@
#include "FilteredDistance.h" #include "FilteredDistance.h"
#include <Arduino.h> #include <Arduino.h>
#include <cmath> FilteredDistance::FilteredDistance(float processNoise, float measurementNoise): processNoise(processNoise), measurementNoise(measurementNoise), isFirstMeasurement(true) {}
#include <numeric>
#include <vector>
FilteredDistance::FilteredDistance(float minCutoff, float beta, float dcutoff)
: minCutoff(minCutoff), beta(beta), dcutoff(dcutoff), x(0), dx(0), lastDist(0), lastTime(0), total(0), readIndex(0) {
}
void FilteredDistance::initSpike(float dist) {
for (size_t i = 0; i < NUM_READINGS; i++) {
readings[i] = dist;
}
total = dist * NUM_READINGS;
}
float FilteredDistance::removeSpike(float dist) {
total -= readings[readIndex]; // Subtract the last reading
readings[readIndex] = dist; // Read the sensor
total += readings[readIndex]; // Add the reading to the total
readIndex = (readIndex + 1) % NUM_READINGS; // Advance to the next position in the array
auto average = total / static_cast<float>(NUM_READINGS); // Calculate the average
if (std::fabs(dist - average) > SPIKE_THRESHOLD)
return average; // Spike detected, use the average as the filtered value
return dist; // No spike, return the new value
}
void FilteredDistance::addMeasurement(float dist) { void FilteredDistance::addMeasurement(float dist) {
const bool initialized = lastTime != 0; if (isFirstMeasurement) {
const unsigned long now = micros(); // Initialize state
const unsigned long elapsed = now - lastTime; state[0] = dist; // Distance
lastTime = now; state[1] = 0; // Rate of change in distance
if (!initialized) { // Initialize covariance matrix
x = dist; // Set initial filter state to the first reading covariance[0][0] = 1; // Initial guess
dx = 0; // Initial derivative is unknown, so we set it to zero covariance[0][1] = 0;
lastDist = dist; covariance[1][0] = 0;
initSpike(dist); covariance[1][1] = 1; // Initial guess
} else {
float dT = std::max(elapsed * 0.000001f, 0.05f); // Convert microseconds to seconds, enforce a minimum dT
const float alpha = getAlpha(minCutoff, dT);
const float dAlpha = getAlpha(dcutoff, dT);
dist = removeSpike(dist); lastUpdateTime = micros(); // Set the update time
x += alpha * (dist - x); isFirstMeasurement = false;
dx = dAlpha * ((dist - lastDist) / dT); return;
lastDist = x + beta * dx;
} }
// Calculate time delta for subsequent measurements
unsigned long currentTime = micros();
float deltaTime = (currentTime - lastUpdateTime) / 1.0e6; // Convert micros to seconds
lastUpdateTime = currentTime;
// Perform prediction and update
prediction(deltaTime);
update(dist);
} }
const float FilteredDistance::getDistance() const { const float FilteredDistance::getDistance() const {
return lastDist; unsigned long currentTime = micros();
float deltaTime = (currentTime - lastUpdateTime) / 1.0e6; // Convert micros to seconds
// Calculate predicted distance
float predictedDistance = state[0] + state[1] * deltaTime;
return predictedDistance;
} }
float FilteredDistance::getAlpha(float cutoff, float dT) { void FilteredDistance::prediction(float deltaTime) {
float tau = 1.0f / (2 * M_PI * cutoff); // Update state estimate
return 1.0f / (1.0f + tau / dT); state[0] += state[1] * deltaTime;
// Update covariance
covariance[0][0] += deltaTime * (covariance[1][0] + covariance[0][1]) + processNoise;
covariance[0][1] += deltaTime * covariance[1][1];
covariance[1][0] += deltaTime * covariance[1][1];
}
void FilteredDistance::update(float distanceMeasurement) {
// Kalman gain calculation
float S = covariance[0][0] + measurementNoise;
float K[2]; // Kalman gain
K[0] = covariance[0][0] / S;
K[1] = covariance[1][0] / S;
// Update state
float y = distanceMeasurement - state[0]; // measurement residual
state[0] += K[0] * y;
state[1] += K[1] * y;
// Update covariance
float covariance00_temp = covariance[0][0];
float covariance01_temp = covariance[0][1];
covariance[0][0] -= K[0] * covariance00_temp;
covariance[0][1] -= K[0] * covariance01_temp;
covariance[1][0] -= K[1] * covariance00_temp;
covariance[1][1] -= K[1] * covariance01_temp;
} }

View File

@ -1,35 +1,23 @@
#ifndef FILTEREDDISTANCE_H #ifndef FILTEREDDISTANCE_H
#define FILTEREDDISTANCE_H #define FILTEREDDISTANCE_H
#include <Arduino.h>
#define SPIKE_THRESHOLD 1.0f // Threshold for spike detection
#define NUM_READINGS 10 // Number of readings to keep track of
class FilteredDistance { class FilteredDistance {
public: public:
FilteredDistance(float minCutoff = 1.0f, float beta = 0.0f, float dcutoff = 1.0f); FilteredDistance(float processNoise, float measurementNoise);
void addMeasurement(float dist); void addMeasurement(float dist);
const float getMedianDistance() const;
const float getDistance() const; const float getDistance() const;
bool hasValue() const { return lastTime != 0; } const bool hasValue() const { return !isFirstMeasurement; }
private: private:
float minCutoff; float state[2]; // State: [0] is distance, [1] is rate of change in distance
float beta; float covariance[2][2]; // State covariance
float dcutoff; unsigned long lastUpdateTime; // Time of the last update
float x, dx; float processNoise; // Process noise (Q)
float lastDist; float measurementNoise; // Measurement noise (R)
unsigned long lastTime; bool isFirstMeasurement;
float getAlpha(float cutoff, float dT); void prediction(float deltaTime);
void update(float distanceMeasurement);
float readings[NUM_READINGS]; // Array to store readings
int readIndex; // Current position in the array
float total; // Total of the readings
void initSpike(float dist);
float removeSpike(float dist);
}; };
#endif // FILTEREDDISTANCE_H #endif // FILTEREDDISTANCE_H