abhi divekar - 8 months ago 42

C++ Question

I have the following question, which is actually from a coding test I recently took:

Consider the following function:

`f(n) = a*n + b*n*(floor(log(n)/log(2))) + c*n*n*n`

At a particular value, let

`f(n) = k`

Given `k, a, b, c`, find `n`.

For a given value of `k`, if no such `n` exists, return 0.

Constraints:

- `1 <= n < 2^63 - 1`
- `0 < a, b < 100`
- `0 <= c < 100`
- `0 < k < 2^63 - 1`

The logic here is that since `f(n)` is strictly increasing in `n`, I can find `n` with a binary search.

The code I wrote was as follows:

```
#include<iostream>
#include<stdlib.h>
#include<math.h>
using namespace std;

unsigned long long logToBase2Floor(unsigned long long n){
    return (unsigned long long)(double(log(n))/double(log(2)));
}

#define f(n, a, b, c) (a*n + b*n*(logToBase2Floor(n)) + c*n*n*n)

unsigned long long findNByBinarySearch(unsigned long long k, unsigned long long a, unsigned long long b, unsigned long long c){
    unsigned long long low = 1;
    unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
    unsigned long long n;
    while(low<=high){
        n = (low+high)/2;
        cout<<"\n\n k= "<<k;
        cout<<"\n f(n,a,b,c)= "<<f(n,a,b,c)<<" low = "<<low<<" mid="<<n<<" high = "<<high;
        if(f(n,a,b,c) == k)
            return n;
        else if(f(n,a,b,c) < k)
            low = n+1;
        else high = n-1;
    }
    return 0;
}
```

I then tried it with a few test cases:

```
int main(){
    unsigned long long n, a, b, c;
    n = (unsigned long long)pow(2,63)-1;
    a = 99;
    b = 99;
    c = 99;
    cout<<"\nn="<<n<<" a="<<a<<" b="<<b<<" c="<<c<<" k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;

    n = 1000;
    cout<<"\nn="<<n<<" a="<<a<<" b="<<b<<" c="<<c<<" k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;
    return 0;
}
```

Then something weird happened.

The code works for the test case `n = (unsigned long long)pow(2,63)-1;`, but not for `n=1000`. This is the output for `n=1000`:

```
n=1000 a=99 b=99 c=99 k = 99000990000

 k= 99000990000
 f(n,a,b,c)= 4611686018427387904 low = 1 mid=4611686018427387904 high = 9223372036854775807
...
...
 k= 99000990000
 f(n,a,b,c)= 172738215936 low = 1 mid=67108864 high = 134217727

 k= 99000990000
 f(n,a,b,c)= 86369107968 low = 1 mid=33554432 high = 67108863

 k= 99000990000
 f(n,a,b,c)= 129553661952 low = 33554433 mid=50331648 high = 67108863
...
...
 k= 99000990000
 f(n,a,b,c)= 423215328047139441 low = 37748737 mid=37748737 high = 37748737
ANSWER: 0
```

Something didn't seem right mathematically. How could `f(1000)` be greater than `f(33554432)`?

So I tried the same code in Python, and got the following values:

```
>>> f(1000, 99, 99, 99)
99000990000L
>>> f(33554432, 99, 99, 99)
3740114254432845378355200L
```

So the true value of `f(33554432)` is definitely larger.

- What is happening exactly?
- How should I solve it?

Answer

The problem is here:

```
unsigned long long low = 1;
// Side note: this is simply (2ULL << 62) - 1
unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
unsigned long long n;
while (/* irrelevant */) {
    n = (low + high) / 2;
    // Some stuff that does not modify n...
    f(n, a, b, c); // <-- Here!
}
```

In the first iteration, you have `low = 1` and `high = 2^63 - 1`, which means that `n = 2^63 / 2 = 2^62`. Now, let's look at `f`:

```
#define f(n, a, b, c) (/* I do not care about this... */ + c*n*n*n)
```

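You have `n^3` in `f`, so for `n = 2^62`, `n^3 = 2^186`, which is far too large for your `unsigned long long` (at least 64 bits, and in practice exactly 64). Unsigned arithmetic in C++ wraps around modulo 2^64, so the computed `f(n)` no longer grows with `n`, and the binary search's assumption that `f` is increasing breaks down.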

As a preamble: I am using `ull_t` because I am lazy. You should avoid macros in C++; prefer a function and let the compiler inline it. I also prefer a loop over the `log` function to compute the log2 of an `unsigned long long`.

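```
using ull_t = unsigned long long;

// floor(log2(n)) for n >= 1, computed with shifts (no floating point)
constexpr auto log2 (ull_t n) {
    ull_t log = 0;
    while (n >>= 1) ++log;
    return log;
}

// Same formula as the macro, written as a function
constexpr auto f (ull_t n, ull_t a, ull_t b, ull_t c) {
    return a * n + b * n * log2(n) + c * n * n * n;
}
```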

The main problem here is how you choose your bounds for `n`: you should try to find a better upper (and maybe lower) bound for `n` to start with. You should also split your function into two cases:

- If **c is 0**, you have `f(n, a, b, 0) = a * n + b * n * log2(n) = n * (a + b * log2(n))`.

Depending on the values of `a` and `b`, getting a bound for `n` may be difficult (see comments). I personally would try all possible values of `log2(n)` (there are at most 64 of them given the range of `n`, so this has the same complexity as a binary search) and, for each of these values, check whether the corresponding `n` matches the given `k`:

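```
// Find n when c = 0
constexpr auto find_n (ull_t k, ull_t a, ull_t b) {
    // log2(n) can be 0 (n = 1), so start at l = 0
    for (ull_t l = 0; l < 64; ++l) {
        auto n = k / (a + b * l);
        // Only accept n if its log2 really is l; then f(n) <= k cannot overflow
        if (log2(n) == l && f(n, a, b, 0) == k)
            return n;
    }
    return 0ULL;
}
```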

If **c is not 0**, the dominant term is clearly `c * n * n * n` when `n` is large (we do not care when `n` is small), so you should start with something like the following (an approximation; this will not work for all values of `k`/`n`): `high = 2^((64 - log2(c)) / 4)`