Source code for kbbq.bloom

"""
Methods to create and manipulate khmer bloom filters.

The base khmer bloom filter object is a Nodegraph.
Nodegraph.get(kmer) returns 0 or 1 depending on whether the kmer is present
Nodegraph.count(kmer) sets the value to 1 (present)
Nodegraph.add(khmer) sets the value to 1 (present) and returns true if it's new.
Nodegraph.consume(str) does add on each kmer in the string and returns the number of kmers added
Nodegraph.get_kmer_counts(str) does get on each kmer in the string and returns a list of counts
Nodegraph.save(filename) saves to filename
Nodegraph.load(filename) loads from filename
Nodegraph.hash(kmer) returns the hash of the k-mer
Nodegraph.get_kmer_hashes(kmer) hashes each k
khmer.calc_expected_collisions(nodegraph, force = False, max_false_pos = .15) will return false positive rate and issue a warning if it's too high.
when force is true, don't exit if false positive is higher than max.
#may also want to try a Nodetable at some point. It should have the same fns.

ksize = 32
maxmem = khmer.khmer_args.memory_setting('8G')
fake_args = namedtuple(max_memory_usage = maxmem, n_tables = 4)
tablesize = calculate_graphsize(fake_args, 'nodegraph', multiplier = 1.0)
nodegraph = khmer.Nodegraph(ksize, tablesize, fake_args.n_tables) 

"""

import khmer
import khmer.khmer_args
import numpy as np
import scipy.stats
import collections

[docs]def create_empty_nodegraph(ksize = 32, max_mem = '8G'): """ Create an empty nodegraph. From this point, fill it from a file or start filling it with count. """ mem = khmer.khmer_args.memory_setting(max_mem) fake_args = collections.namedtuple('fake_args', ['max_memory_usage', 'n_tables']) fargs = fake_args(max_memory_usage = mem, n_tables = 4) tablesize = khmer.khmer_args.calculate_graphsize(fargs, 'nodegraph', multiplier = 1.0) return khmer.Nodegraph(ksize, tablesize, fargs.n_tables)
[docs]def count_read(read, graph, sampling_rate): """ Put ReadData read in the countgraph. Each k-mer will be added with p = sampling_rate """ # ksize = graph.ksize() # #https://rigtorp.se/2011/01/01/rolling-statistics-numpy.html # kmers = numpy.lib.stride_stricks.as_strided(read.seq, # shape = (len(read.seq) - ksize + 1, ksize), # strides = read.seq.strides * 2) #2D array of shape (nkmers, ksize) #probably will be fastest to get all hashes in C then select the ones i want hashes = np.array(graph.get_kmer_hashes(np.str.join('',read.seq)), dtype = np.ulonglong) sampled = np.random.choice([True, False], size = hashes.shape, replace = True, p = [sampling_rate, 1.0 - sampling_rate]) contains_n = np.any(rolling_window(read.seq, graph.ksize()) == 'N', axis = 1) sampled[contains_n] = False for h in hashes[sampled]: graph.count(h.item())
[docs]def kmers_in_graph(read, graph): """ Query the graph for each kmer in read and return a :class:`numpy.ndarray` of bools. The returned array has length len(read) - ksize + 1 """ ingraph = np.array(graph.get_kmer_counts(np.str.join('',read.seq)), dtype = np.bool) ingraph[np.any(rolling_window(read.seq, graph.ksize()) == 'N', axis = 1)] = False return ingraph
[docs]def overlapping_kmers_in_graph(read, graph): """ Get the number of kmers overlapping each read position that are in the graph. The returned array has length len(read). """ ksize = graph.ksize() kmers = kmers_in_graph(read, graph) num_in_graph = np.zeros(len(read), dtype = np.int) num_in_graph_windowed = rolling_window(num_in_graph, ksize) np.add.at(num_in_graph_windowed, (kmers,), 1) #this will modify num_in_graph return num_in_graph
[docs]def overlapping_kmers_possible(read, ksize): """ Get the possible number of kmers overlapping each read position. For example, position 1 will always have only 1 overlapping kmer, position 2 will have 2, etc. """ if ksize > len(read): raise ValueError(f"ksize {ksize} too small for read with length {len(read)}. \ ksize must be <= the read length.") num_possible = np.zeros(len(read), dtype = np.int) koverlaps = rolling_window(num_possible, ksize) np.add.at(koverlaps, Ellipsis, 1) return num_possible
[docs]def p_kmer_added(sampling_rate, graph): """ The probability a kmer was added to the graph, including the false positive rate. P(A) = 1 - ( 1 - sampling_rate ) ^ n, where n is the assumed highest multiplicity of a weak kmer. When sampling_rate >= .1, n = 2. Otherwise, n = .2 / sampling_rate. This function returns P*(A) = P(A) + B - B*P(A), where B is the false positive rate. The point is to perform a binomial test, with p = p_kmer_added, k = n_kmers_in_graph, N = n_kmers_possible. The null hypothesis is that the kmer is not an error, and this is p under that assumption. Since the kmer is not an error, it was observed 2 or greater times. The alternative hypothesis is that the kmer is an error, so the true probability the kmer was added to the graph was less than the number returned by this function. I think the better model would be a negative binomial. In this case, we have no idea how many times an overlapping kmer appeared in our dataset. We do know that there are r kmers that failed to be added to the hash and that each time an overlapping kmer appeared, there was a *sampling_rate* probability that it was added to the hash. Then the number of overlapping kmers in the dataset can be predicted with a negative binomial distribution with r = (# in hash) and p = 1 - sampling_rate. If this distribution is X, the read coverage at that site is X / (# k-mers that could overlap the site); in most cases, X / k. Thus with this technique, we can calculate the distribution of site-by-site coverage in the dataset. In marginalizing, we sum the probability vector of each X then divide by the number of sites. Once we're done, we must also divide by X again, since each base with coverage x (element of X) was counted x times. In our sequencing model, the number of sequenced reads at any position is Poisson. Because errors are correlated, the number of sequenced erroneous reads is an overdispersed Poisson. This can be accurately modeled with a different negative binomial. Thus from our inference of coverage we should be able to fit a mixture of negative binomials that represents the coverage expected in the dataset, and this binomial will tell us the probability that a read is an error given its coverage. Note that we use the negative binomial for two distinct purposes: one to estimate the coverage given a site at a read, and one to model the total coverage of the dataset. Notably, this approach suggests an optimal choice for the sampling rate since the distribution becomes less predictive for some coverages at extreme points. A strategy for picking the theoretically optimal sampling rate can likely be found. """ fpr = khmer.calc_expected_collisions(graph, force = False, max_false_pos = .15) exp = .2 / sampling_rate if sampling_rate < .1 else 2 p_a = 1 - ( 1 - sampling_rate ) ** exp p_added = p_a + fpr - fpr * p_a return p_added
[docs]def calculate_thresholds(p_added, ksize): """ Calculate the thresholds. If the number of overlapping kmers is less than this number, the base is inferred to be erroneous. """ dists = [scipy.stats.binom(n = i, p = p_added) for i in range(1, ksize + 1, 1)] return np.array([0] + [int(d.ppf(.995)) for d in dists])
[docs]def infer_errors(overlapping, possible, thresholds): """ Perform a binomial test to infer which bases are erroneous. Given a vector of the number of overlapping kmers in the hash and the number of possible kmers in the hash for each position, infer whether each base is erroneous. This is done with essentially a binomial test. We evaluate the probability the sample came from a binomial distribution with p = p_added. We do a right-tailed test with the null hypothesis that all kmers overlapping the site are erroneous. We assume the multiplicity of a weak kmer is less than ~2, so this is reflected in p_added. We only use the right tail because a lower p parameter supports the null, while a high value supports the alternative. """ #n = 0 doesn't make sense return overlapping <= thresholds[possible] #if overlapping < thresholds, it's an error.
[docs]def infer_read_errors(read, graph, thresholds): """ Return an array of errors given a graph and the thresholds. """ overlapping = overlapping_kmers_in_graph(read, graph) possible = overlapping_kmers_possible(read, graph.ksize()) errors = infer_errors(overlapping, possible, thresholds) assert len(errors) == len(read) return errors
[docs]def add_trusted_kmers(read, graph): """ Add trusted kmers to graph. """ ksize = graph.ksize() hashes = np.array(graph.get_kmer_hashes(np.str.join('',read.seq)), dtype = np.ulonglong) errors = rolling_window(read.errors, ksize) trusted_kmers = np.all(~errors, axis = 1) for h in hashes[trusted_kmers]: graph.count(h.item())
[docs]def find_longest_trusted_block(trusted_kmers): """ Given a boolean array describing whether each kmer is trusted, return a pair of indices to the start and end of the block. The indices will be standard for python ranges; inclusive on the left side and exclusive on the right side. """ transitions = np.nonzero(np.diff(trusted_kmers) != 0)[0] + 1 #indices where changes occur segments = np.concatenate([np.array([0]), transitions, np.array([len(trusted_kmers)])]) #indices including beginning and end segment_pairs = rolling_window(segments, 2) #numsegments, 2 segment_lens = np.diff(segments) #lengths of each segment trusted_segments = trusted_kmers[segment_pairs[:,0]] #hack trusted_segment_lens = np.array(segment_lens, copy = True) trusted_segment_lens[~trusted_segments] = 0 longest_trusted = np.argmax(trusted_segment_lens) #argmax will pick the first in case of tie return segment_pairs[longest_trusted,0], segment_pairs[longest_trusted,1]
[docs]def infer_errors_from_trusted_kmers(read, graph): """ Return an array of errors and a bool describing whether multiple corrections were made. """ trusted_kmers = kmers_in_graph(read, graph) errors = np.zeros(len(read), dtype = np.bool) multiple = False if np.all(trusted_kmers) or np.all(~trusted_kmers): #do nothing return errors, multiple else: longest_trusted = find_longest_trusted_block(trusted_kmers) ksize = graph.ksize() #right side #kmer trusted_kmers[longest_trusted[1]] is an error #for kmer k in range(len(seq) + ksize - 1) #base to look at is base k + ksize - 1 # longest_trusted_len = longest_trusted[1] - longest_trusted[0] # if longest_trusted_len < ksize: # k = longest_trusted[1] # else: # k = longest_trusted[1] - 1 # trusted_kmers[k] = False k = longest_trusted[1] while k < len(trusted_kmers): if trusted_kmers[k]: k = k + 1 else: cor_len, base, m = correction_len(read.seq[k:], graph, right = True) multiple = multiple or m if cor_len is not None: #correction found # if read.seq[k + ksize - 1] != base: #not equal to current base errors[k + ksize - 1] = True read.seq[k + ksize - 1] = base k = k + cor_len[0] else: #could not find a fix; try chopping up the read and trying again #need to make sure this doesn't loop forever somehow; in lighter this #happens once and only once #it should be OK since we will eventually run out of trusted kmers # print('k+ksize:',k + ksize - 1) subread = read[(k+ksize-1):] if len(subread) > (len(read) / 2) or (len(subread) > ksize * 2): # print('trying again') errors[k+ksize-1:], m = infer_errors_from_trusted_kmers(subread, graph) multiple = multiple or m break #left side k = longest_trusted[0] - 1 # = -1 if the trusted block is at the start while k >= 0: if trusted_kmers[k]: k = k - 1 else: cor_len, base, m = correction_len(read.seq[:(k+ksize)], graph, right = False) multiple = multiple or m if cor_len is not None: #correction found errors[k] = True read.seq[k] = base k = k - cor_len[0] else: #could not find a fix; try chopping up the read and trying again #need to make sure this doesn't loop forever somehow; in lighter this #happens once and only once #it should be OK since we will eventually run out of trusted kmers subread = read[:(k+ksize)] if len(subread) > (len(read) / 2) or (len(subread) > ksize * 2): errors[:(k+ksize)], m = infer_errors_from_trusted_kmers(subread, graph) multiple = multiple or m break return errors, multiple
[docs]def correction_len(seq, graph, right = True): """ Get the number of corrections given an ndarray of sequence. This value is an array of values between between 1 and ksize, or len(seq) if no correction can be made. The length of the array represents the number of results if there is a tie. The last element returned says whether there were multiple corrections that led to a trusted kmer. This function is in desperate need of a refactor. """ ksize = graph.ksize() kmers = rolling_window(seq.copy(), ksize) #note the memory is preserved across windows, #so changing one base in one kmer will change every kmer! #this is exactly the behavior we want so we can exploit this for efficiency largest_possible_fix = min(ksize, len(kmers)) bases = list("ACGT") counts = np.zeros(len(bases), dtype = np.int) if right: idx = (0,-1) possible_fixes = list(range(largest_possible_fix)) else: idx = (-1,0) possible_fixes = list(range(-1,-largest_possible_fix-1, -1)) for b, base in enumerate(bases): kmers[idx] = base #kmers[0,-1] or kmers[-1,0] for i in possible_fixes: # print(np.str.join('',kmers[i]), i) if not graph.get(np.str.join('',kmers[i])): if right: counts[b] = i else: counts[b] = (-i - 1) break else: #we made it through every possible fix if largest_possible_fix == len(kmers): #we ran out of kmers, try to extend if right: last_kmer = np.str.join('',kmers[-1]) for j in range(ksize - len(kmers)): #ksize - len(kmers) is the number of kmers we didn't see at the end for extra in bases: if graph.get(last_kmer[1:] + extra): last_kmer = last_kmer[1:] + extra break #stop looking at more bases because we found one else: #we didn't find an appropriate base; we're done extending counts[b] = largest_possible_fix + j break else: #we extended and got to see all ksize kmers counts[b] = ksize else: #left side last_kmer = np.str.join('',kmers[0]) for j in range(ksize - len(kmers)): #ksize - len(kmers) is the number of kmers we didn't see at the end for extra in bases: if graph.get(extra + last_kmer[:-1]): last_kmer = extra + last_kmer[:-1] break #stop looking at more bases because we found one else: #we didn't find an appropriate base; we're done extending counts[b] = largest_possible_fix + j break else: #we extended and got to see all ksize kmers counts[b] = ksize else: #if every kmer is corrected and we saw ksize kmers, we move forward k counts[b] = largest_possible_fix if np.all(counts == 0): #end correction if we cannot find any #https://github.com/mourisl/Lighter/blob/df39031f8254f8351852f9f8b51b643475226ea0/ErrorCorrection.cpp#L574 return None, None, False else: m = np.amax(counts) #we may also want to test if there are multiple maxima largest = counts[counts == m] #if there are multiple this will be an array largest[largest > largest_possible_fix] = largest_possible_fix #if we extended we only want to continue k if len(largest) > 1 and largest[0] < largest_possible_fix: #if there's a tie and the tie can't fix all k kmers largest = None return largest, bases[counts.argmax()], len(counts[counts != 0]) > 1
[docs]def fix_one(seq, graph): """ If we don't start with an anchor we just fix one base that produces the largest number of trusted kmers. """ original_seq = seq.copy() ksize = graph.ksize() bases = list("ACGT") best_fix_len = 0 best_fix_base = None best_fix_pos = None for i in range(len(seq)): modified_seq = original_seq.copy() kmers = rolling_window(modified_seq, ksize) for b in bases: modified_seq[i] = b trusted_kmers = np.array(graph.get_kmer_counts(np.str.join('',modified_seq)), dtype = np.bool) start_pos = int(min(max(i - ksize/2 + 1, 0), len(seq) - ksize)) if trusted_kmers[start_pos]: num_in_graph = np.sum(trusted_kmers[start_pos:]) if num_in_graph > best_fix_len: best_fix_base = b best_fix_pos = i best_fix_len = num_in_graph return best_fix_len, best_fix_base, best_fix_pos
[docs]def fix_overcorrection(read, ksize, minqual = 6, window = 20, threshold = 4, adjust = False): """ The threshold is adjusted when there are not multiple options for any correction. It is only adjusted for positions [window:-window] """ corrections = read.errors.copy() corrections_windowed = rolling_window(corrections, window) correction_count = np.array(rolling_window(corrections, window), dtype = np.double) seq = rolling_window(read.seq, window) quals = rolling_window(read.qual, window) correction_count[np.logical_and(seq == 'N', corrections_windowed)] = 0 correction_count[np.logical_and(quals < minqual, corrections_windowed)] = .5 thresh_ary = np.repeat(threshold, len(correction_count)) if adjust: thresh_ary[window:-(window-1)] += 1 overcorrected = np.greater(np.sum(correction_count, axis = 1), thresh_ary) overcorrected_sites = np.zeros(len(corrections), dtype = np.bool) overcorrected_sites_windowed = rolling_window(overcorrected_sites, window) overcorrected_sites_windowed[overcorrected] = corrections_windowed[overcorrected] #overcorrected_sites is modified #now anything within k of an overcorrected site we call overcorrected num_fixed = np.sum(overcorrected_sites) fixed_before = 0 corrections_windowed = rolling_window(corrections, ksize) overcorrected_sites_windowed = rolling_window(overcorrected_sites, ksize) while num_fixed > 0: #mark anything within k of an overcorrected site as overcorrected #i'm guessing most of the time this just gets rid of all corrections any_overcorrected = np.any(overcorrected_sites_windowed, axis = 1) overcorrected_sites_windowed[any_overcorrected,:] = corrections_windowed[any_overcorrected] fixed_now = np.sum(overcorrected_sites) num_fixed = fixed_now - fixed_before fixed_before = fixed_now corrections[overcorrected_sites] = False return corrections
[docs]def rolling_window(a, window): """ Use stride tricks to reshape an array into an array of sliding windows. Different values in the array will point to the same place in memory, so copy the array before altering it if that behavior is undesired. From https://rigtorp.se/2011/01/01/rolling-statistics-numpy.html """ if window > a.shape[-1]: raise ValueError(f"Window size {window} is too large for array with last \ dimension of size {a.shape[-1]}") shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) strides = a.strides + (a.strides[-1],) return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)