2023-2-20

LeetCode and Linear Algebra (or: How I learned to stop worrying and love the math)

I was recently stumped by Naming a Company (LC hard). We are given an array of words (all lowercase, a-z). For each pair of words, we can form a new string by swapping their first letters. For example, (peanut, butter) -> beanut putter.

This new string is valid if neither of the swapped words is present in the original array. We’re supposed to count the number of valid strings.

Solving the problem

Original approach

My approach to this problem was to group all the words with the same suffix (i.e., the word with its first letter removed). Then, we keep track of the first letters used for each unique suffix.

After some thinking, we can see that words from different suffixes can only be paired if neither word’s first letter is common to both suffixes. To show this, consider the words aa, ba, bb, cb.

Here are the common suffixes and the first letters used for each:

  1. suffix a: a, b
  2. suffix b: b, c

ba cannot be paired with any word with the suffix b since after swapping its first letter with that word’s, bb is formed, which is an existing word. However, aa can be paired with cb since they have different suffixes and their first letters are not common to both suffixes.

Thus, the number of valid strings created for a pair of suffixes is the product of the number of first letters used by each suffix (but not both). To find the number of valid strings, we can sum that over all pairs of suffixes! Unfortunately, it’s $O(n^2)$, as in the worst case the suffixes never repeat. No dice.
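
For reference, here’s a rough sketch of this approach in Python (the names are my own, and it counts unordered pairs of words, matching the sum above):

from collections import defaultdict
from itertools import combinations

def count_by_suffix(words):
    # Map each unique suffix to the set of first letters used with it.
    firsts = defaultdict(set)
    for w in words:
        firsts[w[1:]].add(w[0])
    # For each pair of suffix groups, multiply the counts of first
    # letters unique to each group. Quadratic in the number of unique
    # suffixes, hence O(n^2) in the worst case.
    return sum(len(a - b) * len(b - a)
               for a, b in combinations(firsts.values(), 2))

print(count_by_suffix(["aa", "ba", "bb", "cb"]))  # 1: only (aa, cb)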

Published solution

At this point, I gave up and peeked at the solution.

As it turns out, instead of grouping words that have the same suffix, we can flip it around and group words that have the same first letter! Now, two words that have different first letters can be paired if neither of their suffixes is common to both prefixes.

We can use the same example as above (aa, ba, bb, cb):

  1. prefix a: a
  2. prefix b: a, b
  3. prefix c: b

aa cannot be paired with bb, as the suffix a is common to the groups of first letters a and b, but aa can be paired with cb.

To find the answer, we iterate through pairs of first letters and sum the number of suffixes unique to the first letter times the number of suffixes unique to the second. Since there are only 26 possible first letters, this has time complexity $O(26 \cdot 26 \cdot n)$ (ignoring hashing of the words). Much better!
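
Here’s the same sketch flipped around, grouping by first letter instead (again counting unordered pairs; names are mine):

from collections import defaultdict
from itertools import combinations

def count_by_first_letter(words):
    # Map each first letter to the set of suffixes used with it.
    suffixes = defaultdict(set)
    for w in words:
        suffixes[w[0]].add(w[1:])
    # At most 26 groups, so at most 26 * 26 pairs of groups; each set
    # difference costs O(n) in the worst case, giving O(26 * 26 * n).
    return sum(len(a - b) * len(b - a)
               for a, b in combinations(suffixes.values(), 2))

print(count_by_first_letter(["aa", "ba", "bb", "cb"]))  # 1, as before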

Using bits

We can represent both of these approaches using a boolean matrix (matrix of 1s and 0s).

The first approach

Since we have a mapping from suffix to used first letters, we have a matrix of size $m \times 26$ ($m$ being the number of unique suffixes).

Now, what does the solution algorithm look like when we represent the mapping using matrices?

for each pair of rows (r1, r2) in the matrix:
    let s = r1 & r2 (bitwise AND, the intersection of the two rows)
    let a = r1 & !s (elements in r1 that are not in the intersection of the two rows)
    let b = r2 & !s (elements in r2 that are not in the intersection of the two rows)
    add sum(a) * sum(b) to the answer

Using some boolean algebra, we can simplify a = r1 & !s:

a = r1 & !(r1 & r2)
  = r1 & (!r1 | !r2)
  = (r1 & !r1) | (r1 & !r2)
  = false | (r1 & !r2)
  = r1 & !r2

Using the same logic, we can see that b = !r1 & r2.
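
If you don’t trust the algebra, a brute-force check over all bit values confirms both identities:

for r1 in (0, 1):
    for r2 in (0, 1):
        s = r1 & r2
        assert r1 & ~s == r1 & ~r2  # a = r1 & !r2
        assert r2 & ~s == ~r1 & r2  # b = !r1 & r2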

We end up with a pretty elegant algorithm:

for each pair of rows (r1, r2) in the matrix:
    add sum(!r1 & r2) * sum(r1 & !r2)
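
In Python, each row can be stored as a 26-bit integer, one bit per first letter, and the algorithm becomes a few lines (a sketch; the names are mine):

from collections import defaultdict
from itertools import combinations

def count_bitwise(words):
    # One 26-bit row per unique suffix; bit i is set if the first
    # letter chr(ord('a') + i) appears with that suffix.
    rows = defaultdict(int)
    for w in words:
        rows[w[1:]] |= 1 << (ord(w[0]) - ord("a"))
    total = 0
    for r1, r2 in combinations(rows.values(), 2):
        a = r1 & ~r2  # first letters only in r1
        b = ~r1 & r2  # first letters only in r2
        total += bin(a).count("1") * bin(b).count("1")
    return total

print(count_bitwise(["aa", "ba", "bb", "cb"]))  # 1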

The second approach

For the first approach, we considered the rows of the matrix (i.e., the first letters for each unique suffix). For the second approach, we can instead use the columns of the matrix (i.e., the unique suffixes for each first letter).

The rest of the algorithm is the same:

for each pair of columns (c1, c2) in the matrix:
    add sum(!c1 & c2) * sum(c1 & !c2)

Wait a second… how could this get us the same result? How is it remotely possible that doing these weird bitwise ANDs, NOTs, products, and sums gives the same answer whether we go across pairs of rows or pairs of columns?

Linear algebra

Transpositions? Trace?

Those of you who have taken linear algebra might recognize that the second approach, which uses the columns instead of the rows, is really operating on the transpose. Indeed, we can transpose the matrix and use the same algorithm for both approaches.

But this still doesn’t explain why they both get the same result. If we rephrase our algorithm in math terms, maybe we can gain some insight.

How do we represent doing bitwise ANDs on all pairs of rows? Matrix multiplication!

Consider a simpler version of the above algorithm (removing the bitwise NOT):

for each pair of rows (r1, r2) in the matrix:
    add sum(r1 & r2) * sum(r1 & r2)

If we let $M$ be our matrix, each entry of $M M^T$ ($M$ multiplied by its transpose) is the sum of the bitwise AND of a pair of rows.

Since each sum is multiplied by itself, we can square each term. Thus, our algorithm is simply summing the square of each element of $M M^T$. We can now prove that transposition doesn’t affect the result using some math.

To do this, we can use some properties of a matrix’s trace, which is the sum of its elements along the main diagonal. For any matrix $A$, the trace of $A^T A$ (or $A A^T$) is the sum of the squares of each element of $A$. Furthermore, it can be proven that $\mathrm{tr}(AB) = \mathrm{tr}(BA)$ if $A, B$ are $m \times n$ and $n \times m$ real matrices, respectively.

Therefore,

the sum of squares of $M M^T$

$= \mathrm{tr}\left((M M^T)(M M^T)\right)$ (using the trace property above, since $M M^T$ is symmetric)

$= \mathrm{tr}\left(M (M^T M M^T)\right)$

$= \mathrm{tr}\left((M^T M M^T) M\right)$ as $\mathrm{tr}(AB) = \mathrm{tr}(BA)$

$= \mathrm{tr}\left((M^T M)(M^T M)\right)$

$=$ the sum of squares of $M^T M$.

Since we get the same sum by replacing $M$ with $M^T$, we can do the calculation across rows or columns and get the same result!
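
A quick numpy check of this NOT-free claim (the matrix here is just a random 0/1 test case):

import numpy as np

M = np.random.randint(2, size=(7, 26))  # random 0/1 matrix
# Sum of squares of the entries of M M^T vs. those of M^T M.
print(np.sum((M @ M.T) ** 2), np.sum((M.T @ M) ** 2))  # always equal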

What about the bitwise NOT?

“But this doesn’t explain why it works with the bitwise NOT!” We’re almost there, I promise.

To refresh your memory, here’s the final algorithm we found:

for each pair of rows (r1, r2) in the matrix:
    add sum(!r1 & r2) * sum(r1 & !r2)

To incorporate the bitwise NOT, let $f(M)$ be a function that outputs the bitwise NOT of a matrix (flipping every 0 to 1 and vice versa). We can then represent sum(!r1 & r2) as an entry of $f(M) M^T$. Similarly, sum(r1 & !r2) is an entry of $M f(M^T)$. Putting it all together, we must multiply these two matrices element-wise. This is called the Hadamard product (denoted by $A \circ B$, and totally new to me!).

Thus, the final output is simply the sum of each entry in $(f(M) M^T) \circ (M f(M^T))$. Our goal is to show that we get the same result if we replace $M$ with $M^T$.

Notice that transposing the bitwise NOT of a matrix is the same as taking the bitwise NOT of the transposed matrix. In other terms, $f(M)^T = f(M^T)$.

However, we still have a Hadamard product in our expression, which I’m not sure how to deal with. The final piece of the puzzle is to convert it back to a trace! Show for yourself that $\mathrm{tr}(A^T B)$ is equal to the sum of the elements of $A \circ B$.
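
(If you’d rather check that identity numerically first, a couple of random matrices will do:)

import numpy as np

A = np.random.randint(5, size=(4, 6))
B = np.random.randint(5, size=(4, 6))
# tr(A^T B) equals the sum of the elements of the Hadamard product A ∘ B.
print(np.trace(A.T @ B), np.sum(A * B))  # always equal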

Therefore,

the sum of elements of $(f(M) M^T) \circ (M f(M^T))$

$= \mathrm{tr}\left((f(M) M^T)^T (M f(M^T))\right)$ by converting it into a trace

$= \mathrm{tr}\left(M f(M)^T M f(M)^T\right)$ as $(AB)^T = B^T A^T$ and $f(M^T) = f(M)^T$

$= \mathrm{tr}\left(\left(M f(M)^T M f(M)^T\right)^T\right)$ since the main diagonal is the same after a transpose

$= \mathrm{tr}\left(f(M) M^T f(M) M^T\right)$ as $(ABCD)^T = D^T C^T B^T A^T$

$= \mathrm{tr}\left(M^T f(M) M^T f(M)\right)$ as $\mathrm{tr}(AB) = \mathrm{tr}(BA)$

$= \mathrm{tr}\left((f(M^T) M)^T (M^T f(M))\right)$ as $(f(M^T) M)^T = M^T f(M)$

$=$ the sum of elements of $(f(M^T) M) \circ (M^T f(M))$

With this, we can see that transposing $M$ does not affect the value. Performing the calculation by rows or by columns doesn’t change the result at all.
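
To tie it back to the problem, here’s a numpy check with $f$ as the bitwise NOT of a 0/1 matrix (written as 1 - M):

import numpy as np

def pair_sum(M):
    # Sum of the elements of (f(M) M^T) ∘ (M f(M^T)), with f(M) = 1 - M.
    f = lambda A: 1 - A  # bitwise NOT of a 0/1 matrix
    return np.sum((f(M) @ M.T) * (M @ f(M.T)))

M = np.random.randint(2, size=(7, 26))  # rows: suffixes, columns: first letters
print(pair_sum(M), pair_sum(M.T))  # always equal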

Conclusion

By now, you might have noticed that this is true for any real matrix $M$, not just one containing 1s and 0s. This result also holds for any function $f(M)$ that has the property that $f(M)^T = f(M^T)$. The two places where $f(M)$ is used can even be two completely different functions (each satisfying that property).

Check it out for yourself!

import numpy as np

def calc(mat):
    # Two different elementwise functions, square and cube; both
    # satisfy f(M)^T = f(M^T).
    a = np.square(mat) @ mat.T
    b = mat @ np.power(mat.T, 3)
    # tr(a^T b) is the sum of the elements of the Hadamard product a ∘ b.
    return np.trace(a.T @ b)

M = np.random.randint(99, size=(16, 10), dtype=np.int64)  # int64 avoids overflow
row_sum = calc(M)
col_sum = calc(M.T)

print(M)
print(row_sum, col_sum)
print(row_sum == col_sum)

From LeetCode, to bits, to linear algebra - math and CS are the same in more ways than I thought.