BK-Tree in C++

A BK-tree is an efficient data structure for storing items drawn from a metric space, that is, items for which a distance function is defined between any two of them. A common use is to store words for approximate string matching, using the Levenshtein distance as the metric. Another use is to store phrases under some measure of semantic distance in order to implement an AI system such as a chatterbot.

This is a simple implementation of a BK-tree in C++.

#ifndef BK_TREE_HPP
#define BK_TREE_HPP

#include <map>
#include <memory>
#include <type_traits>

namespace MB
{

template <typename T, typename Unit, typename Metric> class bktree;

namespace detail
{

// One node of a BK-tree: a stored value plus its child subtrees, keyed by the
// distance between the child's value and this node's value.
//
// Everything is private; only bktree<T, Unit, Metric> (a friend) creates and
// uses nodes.
template <typename T, typename Unit, typename Metric>
class bktree_node
{
    friend class bktree<T, Unit, Metric>;
    typedef bktree_node<T, Unit, Metric> node_type;
    typedef std::unique_ptr<node_type> node_ptr_type;

    T value_;
    std::map<Unit, node_ptr_type> children_;

    explicit bktree_node(const T& value)
        : value_(value)
    {
    }

    // Inserts `value` into the subtree rooted at this node.
    //
    // Returns true if a new node was created, false if a value at distance
    // zero (or less) from `value` is already stored on the path followed.
    bool insert(const T& value, const Metric& distance)
    {
        node_type* node = this;
        for (;;) {
            const Unit dist = distance(value, node->value_);
            if (!(dist > 0)) {
                // Already here: report it instead of falling off the end of
                // the function.
                return false;
            }
            auto edge = node->children_.find(dist);
            if (edge == node->children_.end()) {
                // No edge with this distance yet: attach a new leaf here.
                // The constructor is private, so the node is created with
                // `new` from inside the class rather than via a helper.
                node->children_.emplace(dist, node_ptr_type(new node_type(value)));
                return true;
            }
            // Follow the existing edge and try again one level down.
            node = edge->second.get();
        }
    }

    // Writes a std::pair(value, distance) through `it` for every value in this
    // subtree whose distance to `value` is at most `limit`.
    //
    // Returns the output iterator positioned one past the last element
    // written.
    template <class OutputIt>
    OutputIt find(const T& value, const Unit& limit, const Metric& distance, OutputIt it) const
    {
        const Unit dist = distance(value, this->value_);
        if (dist <= limit) {
            // Found one
            *it++ = std::make_pair(this->value_, dist);
        }
        // Only edges whose key lies within `limit` of `dist` are followed.
        // The gap is computed as (larger - smaller) so that it cannot wrap
        // around for an unsigned Unit or overflow for a very large limit.
        for (auto iter = children_.begin(); iter != children_.end(); ++iter) {
            const Unit edge = iter->first;
            if (edge > dist) {
                if (edge - dist > limit) {
                    // Keys are visited in ascending order, so every remaining
                    // edge is even further away.
                    break;
                }
            }
            else if (dist - edge > limit) {
                continue;
            }
            // Carry the advanced iterator forward; discarding it would make
            // later writes overwrite earlier ones for non-inserter iterators.
            it = iter->second->find(value, limit, distance, it);
        }
        return it;
    }
};

} // namespace detail

// BK-tree over values of type T.
//
// Template parameters:
//   T      - stored value type (copied into the tree).
//   Unit   - integral type used to hold distances.
//   Metric - callable invoked as metric(a, b) whose result is stored in a
//            Unit; a result of zero means the two values are the same item.
template <typename T, typename Unit, typename Metric>
class bktree
{
    static_assert(std::is_integral<Unit>::value, "Integral unit type required.");

private:
    typedef typename detail::bktree_node<T, Unit, Metric>::node_type node_type;
    typedef typename detail::bktree_node<T, Unit, Metric>::node_ptr_type node_ptr_type;

    node_ptr_type head_;
    const Metric distance_;
    size_t size_;

public:
    // Creates an empty tree that measures distance with a copy of `distance`.
    bktree(const Metric& distance = Metric())
        : head_(nullptr), distance_(distance), size_(0)
    {
    }

    // Adds `value` to the tree.
    //
    // Returns true if the value was stored, false if a value at distance zero
    // from it was already found on the insertion path.
    bool insert(const T& value)
    {
        if (!head_) {
            // Inserting the first value
            head_.reset(new node_type(value));
            size_ = 1;
            return true;
        }
        if (head_->insert(value, distance_)) {
            ++size_;
            return true;
        }
        return false;
    }

    // Writes a std::pair(stored value, distance) through `it` for every stored
    // value whose distance to `value` is at most `limit`.
    //
    // Returns the output iterator one past the last element written; on an
    // empty tree nothing is written and `it` is returned unchanged.
    template <class OutputIt>
    OutputIt find(const T& value, const Unit& limit, OutputIt it) const
    {
        if (!head_) {
            return it;
        }
        return head_->find(value, limit, distance_, it);
    }

    // Number of values stored in the tree.
    size_t size() const
    {
        return size_;
    }
};

} // namespace MB

#endif // BK_TREE_HPP

Here is an example program. It loads a dictionary from a file, then repeatedly prompts the user to enter a word and a distance limit, and prints every word within that Levenshtein distance of the word entered. I used this dictionary with it:

#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "bk_tree.hpp"
#include "levenshtein_distance.hpp"

// Interactive demo: loads whitespace-separated words from "dictionary.txt"
// into a BK-tree, then repeatedly reads a word and a distance limit from
// standard input and prints every stored word within that distance.
//
// Returns 1 if the dictionary cannot be opened, 0 otherwise. The loop ends
// when the user enters "quit!" or when standard input is exhausted.
int main()
{
    const char filename[] = "dictionary.txt";
    std::ifstream ifs(filename);
    if (!ifs) {
        std::cerr << "Couldn't open " << filename << " for reading\n";
        return 1;
    }

    // levenshtein_distance is declared inside namespace MB, so it must be
    // qualified here.
    MB::bktree<std::string, int, MB::levenshtein_distance> tree;
    std::string word;
    while (ifs >> word) {
        tree.insert(word);
    }
    std::cout << "Inserted " << tree.size() << " words\n";

    std::vector<std::pair<std::string, int> > results;
    for (;;) {
        std::cout << "Enter a word (\"quit!\" to quit)> ";
        // A failed read (e.g. end of input) leaves `word` untouched, so it
        // must end the loop explicitly or the prompt would repeat forever.
        if (!(std::cin >> word) || word == "quit!") {
            break;
        }

        int limit = 0;
        std::cout << "Enter a limit> ";
        if (!(std::cin >> limit)) {
            if (std::cin.eof()) {
                break;
            }
            // Non-numeric input: reset the stream, drop the rest of the line
            // and ask again.
            std::cin.clear();
            std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
            std::cerr << "Limit must be an integer\n";
            continue;
        }

        results.clear();
        tree.find(word, limit, std::back_inserter(results));
        for (const auto& match : results) {
            std::cout << match.first << "(distance " << match.second << ")\n";
        }
    }
    return 0;
}

Example run:

$ ./dictionary
Inserted 127142 words
Enter a word ("quit!" to quit)> mispelling
Enter a limit> 2
spelling(distance 2)
misdealing(distance 2)
impelling(distance 2)
miscalling(distance 2)
mispenning(distance 2)
misrelying(distance 2)
respelling(distance 2)
dispelling(distance 1)
misbilling(distance 2)
misspelling(distance 1)
misspellings(distance 2)
Enter a word ("quit!" to quit)> quit!
$

Here is an implementation of the Levenshtein distance:

#include <string>
#include <vector>
#include <algorithm>

namespace MB
{

namespace detail
{

// Returns the smallest of the three arguments.
template <typename T>
T min3(const T& a, const T& b, const T& c)
{
    return std::min(std::min(a, b), c);
}

} // namespace detail

// Function object computing the Levenshtein (edit) distance between two
// strings, counting single-character insertions, deletions and substitutions.
//
// The dynamic-programming matrix is kept between calls and only ever grown,
// so repeated calls reuse the same storage. Because operator() is const but
// writes to that shared scratch matrix, a single instance should not be
// called from several threads at once.
class levenshtein_distance
{
    mutable std::vector<std::vector<unsigned int> > matrix_;

public:
    // Pre-sizes the scratch matrix to initial_size x initial_size cells.
    // Any size, including zero, is accepted; the matrix grows on demand.
    explicit levenshtein_distance(size_t initial_size = 8)
        : matrix_(initial_size, std::vector<unsigned int>(initial_size))
    {
    }

    // Returns the edit distance between `s` and `t`.
    unsigned int operator()(const std::string& s, const std::string& t) const
    {
        const size_t m = s.size();
        const size_t n = t.size();
        // The distance between a string and the empty string is the string's length
        if (m == 0) {
            return static_cast<unsigned int>(n);
        }
        if (n == 0) {
            return static_cast<unsigned int>(m);
        }

        // Size the matrix as necessary. Each row that will be used is checked
        // on its own, so this works when the matrix starts out empty and does
        // not rely on every row having the same width as row 0.
        if (matrix_.size() < m + 1) {
            matrix_.resize(m + 1);
        }
        for (size_t row = 0; row <= m; ++row) {
            if (matrix_[row].size() < n + 1) {
                matrix_[row].resize(n + 1);
            }
        }

        // The top row and left column are prefixes that can be reached by
        // insertions and deletions alone
        matrix_[0][0] = 0;
        for (size_t i = 1; i <= m; ++i) {
            matrix_[i][0] = static_cast<unsigned int>(i);
        }
        for (size_t j = 1; j <= n; ++j) {
            matrix_[0][j] = static_cast<unsigned int>(j);
        }

        // Fill in the rest of the matrix row by row, matching the storage
        // layout; each cell depends only on its upper, left and upper-left
        // neighbours, which are already filled in this order.
        for (size_t i = 1; i <= m; ++i) {
            for (size_t j = 1; j <= n; ++j) {
                const unsigned int substitution_cost = s[i - 1] == t[j - 1] ? 0 : 1;
                matrix_[i][j] =
                    detail::min3(matrix_[i - 1][j] + 1,         // Deletion
                    matrix_[i][j - 1] + 1,                      // Insertion
                    matrix_[i - 1][j - 1] + substitution_cost); // Substitution
            }
        }
        return matrix_[m][n];
    }
};

} // namespace MB