diff --git a/machine_learning/apriori_algorithm.py b/machine_learning/apriori_algorithm.py index 09a89ac2..5c3e2bab 100644 --- a/machine_learning/apriori_algorithm.py +++ b/machine_learning/apriori_algorithm.py @@ -11,6 +11,7 @@ WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining """ +from collections import Counter from itertools import combinations @@ -44,11 +45,16 @@ def prune(itemset: list, candidates: list, length: int) -> list: >>> prune(itemset, candidates, 3) [] """ + itemset_counter = Counter(tuple(item) for item in itemset) pruned = [] for candidate in candidates: is_subsequence = True for item in candidate: - if item not in itemset or itemset.count(item) < length - 1: + item_tuple = tuple(item) + if ( + item_tuple not in itemset_counter + or itemset_counter[item_tuple] < length - 1 + ): is_subsequence = False break if is_subsequence: