Recursively counting numbers with fixed bit counts

I ran across this problem in a job ad in the Reddit sidebar and was intrigued by the task (description paraphrased to decrease googleability):

Write a function
`uint64_t bitsquares(uint64_t a, uint64_t b);`
that returns the number of integers in [a,b] that have a square number of bits set to 1. Your function should run in less than O(b-a) time.
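Any faster version is easiest to check against a direct O(b-a) reference on small ranges. The sketch below is exactly the brute force the task rules out, so it is useful only for testing. `bitsquaresNaive` and `isSquare` are hypothetical names, and the sketch counts 0 set bits as a square.

```haskell
import Data.Bits (popCount)
import Data.Word (Word64)

-- True for the perfect squares 0, 1, 4, ..., 64 (the only ones a Word64 can hit)
isSquare :: Int -> Bool
isSquare n = n `elem` [k * k | k <- [0 .. 8]]

-- Naive O(b-a) reference: test every integer in [a,b] directly.
-- Only practical for small ranges; meant for checking a faster version.
bitsquaresNaive :: Word64 -> Word64 -> Word64
bitsquaresNaive a b
  | a > b     = 0
  | otherwise = fromIntegral (length [x | x <- [a .. b], isSquare (popCount x)])
```

In GHCi, `bitsquaresNaive 0 7` counts 0, 1, 2 and 4, the numbers in [0,7] with 0 or 1 set bits.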

I think I see how to do it in something like logarithmic time. Here's how:

First off, we can list all the squares between 0 and 64: these are 0, 1, 4, 9, 16, 25, 36, 49, and 64. The function I propose runs through a binary tree of depth 64, taking shortcuts through branches whenever it can. In fact, changing implementation language completely, I wonder whether I can even write it comprehensively in Haskell.
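The list of candidate bit counts can also be generated instead of typed out, which avoids accidentally skipping one. A small sketch; `squaresUpTo64` is a hypothetical name:

```haskell
-- Perfect squares that can occur as the popcount of a 64-bit word:
-- take k*k for k = 0, 1, 2, ... until the square exceeds 64.
squaresUpTo64 :: [Int]
squaresUpTo64 = takeWhile (<= 64) [k * k | k <- [0 ..]]
```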

The key insight I had concerns counting, between 0b0000...0000 and 0b000...01111...11, the numbers whose bit count matches some element of a given list. That count reduces to simple binomial coefficients. If the upper bound has n ones, then n choose k is the number of values with exactly k bits set among those n lowest bits, so the answer is the sum of n choose k over the k in the list. Furthermore, we can shrink the problem by removing the prefix that the two bounds share.
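Here is a minimal sketch of this base case, assuming the targets are given as plain integers. This `choose` uses a product formula in `Integer`, so coefficients like 64 choose 32 cannot overflow. `countAllOnes` is a hypothetical name.

```haskell
-- Binomial coefficient C(n, k), exact, computed in Integer.
-- Out-of-range k (negative or larger than n) gives 0.
choose :: Integer -> Integer -> Integer
choose n k
  | k < 0 || k > n = 0
  | otherwise      = product [n - k + 1 .. n] `div` product [1 .. k]

-- Numbers in [0, 2^n - 1] (all n-bit patterns) whose popcount is in ks:
-- each target k contributes C(n, k) patterns.
countAllOnes :: Integer -> [Integer] -> Integer
countAllOnes n ks = sum [choose n k | k <- ks]
```

For example, with n = 3 and targets 0 and 1 it gives 1 + 3 = 4, matching 0, 1, 2 and 4.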

Hence, we find how many of the top bits agree between the two bounds. We count the set bits in this shared prefix and subtract that count from each entry in the list of squares, dropping entries smaller than it. This gives the bit counts we still need to hit in the remaining low bits.
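A sketch of this step, assuming a < b. It locates the highest bit where the bounds differ with `countLeadingZeros` of their XOR, counts the 1s above that bit, and lowers the targets. `sharedPrefix` is a hypothetical helper, not part of the original code.

```haskell
import Data.Bits (countLeadingZeros, testBit, xor)
import Data.Word (Word64)

-- For a < b: returns (index of the highest differing bit,
--                     number of 1s in the shared prefix above it,
--                     targets left for the bits below the prefix).
sharedPrefix :: Word64 -> Word64 -> [Int] -> (Int, Int, [Int])
sharedPrefix a b targets = (d, ones, remaining)
  where
    d         = 63 - countLeadingZeros (a `xor` b)       -- highest differing bit
    ones      = length [i | i <- [d + 1 .. 63], testBit a i]
    remaining = [t - ones | t <- targets, t >= ones]     -- drop unreachable targets
```

For a = 44 (0b101100) and b = 47 (0b101111), the highest differing bit is bit 1. The shared prefix 0b1011 holds three 1s, so the targets 0, 1, 4, 9 become 1 and 6.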

Write a' for a with the agreeing prefix and the first disagreeing bit removed, and similarly for b'. Since a < b, that disagreeing bit is 0 in a and 1 in b. So [a,b] splits into the numbers where this varying bit is 0 and those where it is 1. The total count is then the count for the reduced range from a' to 0b000...01...111 plus the count for the reduced range from 0 to b'. The reduction for the b' half needs to be 1 larger than the one for the a' half. In the first case we work with the prefix before the varying bit increases, and in the second, with the prefix after it increases. The latter count is not really from 0 to b', but it is a useful proxy for the count from 0b0000...010...000 to b' with the additional high bit set.
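One step of this split can be sketched as follows, assuming a < b and targets given as plain integers. `splitRange` is a hypothetical name, and it returns the two reduced subproblems instead of counting them.

```haskell
import Data.Bits (bit, countLeadingZeros, testBit, xor, (.&.))
import Data.Word (Word64)

-- One split step for a < b. Returns two subproblems (lo, hi, targets):
--   [a', 0b0...011...1] with targets lowered by the prefix's 1-count,
--   [0,  b']            with targets lowered by that count + 1
--                       (the varying bit is 1 in this half).
splitRange :: Word64 -> Word64 -> [Int]
           -> ((Word64, Word64, [Int]), (Word64, Word64, [Int]))
splitRange a b targets = ((a', allOnes, lower ones), (0, b', lower (ones + 1)))
  where
    d       = 63 - countLeadingZeros (a `xor` b)   -- the varying bit
    ones    = length [i | i <- [d + 1 .. 63], testBit a i]
    allOnes = bit d - 1                            -- the d low bits set
    a'      = a .&. allOnes                        -- strip prefix and varying bit
    b'      = b .&. allOnes
    lower k = [t - k | t <- targets, t >= k]
```

For a = 5 and b = 12 with targets 0, 1, 4, 9, the varying bit is bit 3 and the prefix above it is empty. The result is [5,7] with the targets unchanged and [0,4] with targets 0, 3, 8. These stand for [5,7] and [8,12] in the original range.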

In code, I managed to boil this down to:
```haskell
import Data.Word (Word64)
import Data.Bits (testBit, popCount, (.&.))
import Data.List (elemIndices)

-- # integers in [a,b] with a square # of 1s
bitsquare :: Word64 -> Word64 -> Word64
bitsquare a b = bitcountin a b squares

-- every square from 0 to 64 is a possible popcount of a Word64
squares :: [Word64]
squares = [0,1,4,9,16,25,36,49,64]

-- 0b1, 0b11, ..., 0b111...1 (64 ones)
allones :: [Word64]
allones = [fromIntegral (2^k - 1 :: Integer) | k <- [1..64 :: Int]]

-- binomial coefficient; Integer avoids overflow (choose 63 31 * 64 > 2^64)
choose :: Integer -> Integer -> Integer
choose _ 0 = 1
choose 0 _ = 0
choose n k = choose (n-1) (k-1) * n `div` k

-- Data.Bits.popCount returns an Int; convert it to Word64
popCount64 :: Word64 -> Word64
popCount64 = fromIntegral . popCount

-- # integers in [a,b] with 1-counts in counts
bitcountin :: Word64 -> Word64 -> [Word64] -> Word64
bitcountin a b counts
  | a > b = 0
  | a == b = if popCount64 b `elem` counts then 1 else 0
  -- [0, 2^n - 1]: sum of n choose k over the target counts k
  | a == 0 && b `elem` allones =
      fromIntegral (sum [choose n (fromIntegral k) | k <- counts])
  | otherwise = bitcountin a' low [c - lobits | c <- counts, c >= lobits]
              + bitcountin hi b' [c - hibits | c <- counts, c >= hibits]
  where
    n           = fromIntegral (popCount64 b)
    agreements  = [testBit a i == testBit b i | i <- [0..63]]
    prefixIndex = last (elemIndices False agreements)  -- highest differing bit
    prefixCount = fromIntegral (length [x | x <- [prefixIndex..63], testBit a x])
    a'          = a .&. (2^prefixIndex - 1)
    b'          = b .&. (2^prefixIndex - 1)
    low         = 2^prefixIndex - 1
    hi          = 0
    lobits      = prefixCount
    hibits      = prefixCount + 1
```
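As an independent cross-check, here is a different technique from the one above: the usual bit-by-bit count of x ≤ n with a given popcount, combined as f(b) - f(a-1). It is a sketch under the same assumption that 0 counts as a square. `countUpTo` and `countBetween` are hypothetical names.

```haskell
import Data.Bits (testBit)
import Data.Word (Word64)

choose :: Integer -> Integer -> Integer
choose n k
  | k < 0 || k > n = 0
  | otherwise      = product [n - k + 1 .. n] `div` product [1 .. k]

-- How many x in [0, n] have popcount(x) in ks.
-- Scan n from the top bit: where bit i of n is 1, putting a 0 there instead
-- leaves i free low bits, so C(i, k - onesSoFar) numbers qualify for each k.
countUpTo :: Word64 -> [Integer] -> Integer
countUpTo n ks = go 63 0
  where
    go :: Int -> Integer -> Integer
    go i ones
      | i < 0       = if ones `elem` ks then 1 else 0   -- x == n itself
      | testBit n i = sum [choose (fromIntegral i) (k - ones) | k <- ks]
                      + go (i - 1) (ones + 1)
      | otherwise   = go (i - 1) ones

-- # integers in [a, b] with popcount in ks
countBetween :: Word64 -> Word64 -> [Integer] -> Integer
countBetween a b ks
  | a > b     = 0
  | a == 0    = countUpTo b ks
  | otherwise = countUpTo b ks - countUpTo (a - 1) ks
```

Comparing `countBetween a b [0,1,4,9,16,25,36,49,64]` with `bitsquare a b` on random bounds is a cheap way to gain confidence in either version.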
