diff --git a/apriori.py b/apriori.py index 8ae4682..b7c4c78 100644 --- a/apriori.py +++ b/apriori.py @@ -107,9 +107,10 @@ def getSupport(item): remain = item.difference(element) if len(remain) > 0: confidence = getSupport(item)/getSupport(element) + lift = getSupport(item)/(getSupport(element)*getSupport(remain)) if confidence >= minConfidence: toRetRules.append(((tuple(element), tuple(remain)), - confidence)) + confidence,lift)) return toRetItems, toRetRules @@ -118,9 +119,9 @@ def printResults(items, rules): for item, support in sorted(items, key=lambda (item, support): support): print "item: %s , %.3f" % (str(item), support) print "\n------------------------ RULES:" - for rule, confidence in sorted(rules, key=lambda (rule, confidence): confidence): + for rule, confidence, lift in sorted(rules, key=lambda (rule, confidence, lift): confidence): pre, post = rule - print "Rule: %s ==> %s , %.3f" % (str(pre), str(post), confidence) + print "Rule: %s ==> %s , %.3f, %.3f" % (str(pre), str(post), confidence, lift) def dataFromFile(fname): diff --git a/test_apriori.py b/test_apriori.py index 9eeb0b1..9f72bc6 100644 --- a/test_apriori.py +++ b/test_apriori.py @@ -171,8 +171,8 @@ def test_print_results_should_have_results_in_defined_format(self): (('beer', 'rice'), 0.5) ] rules = [ - ((('beer',), ('rice',)), 0.6666666666666666), - ((('rice',), ('beer',)), 1.0) + ((('beer',), ('rice',)), 0.6666666666666666, 1.3333333333333333), + ((('rice',), ('beer',)), 1.0, 1.3333333333333333) ] printResults(items, rules) @@ -180,8 +180,8 @@ def test_print_results_should_have_results_in_defined_format(self): expected += "0.500\nitem: ('rice',) , 0.500\nitem: ('beer', " expected += "'rice') , 0.500\nitem: ('beer',) , 0.750\n\n" expected += "------------------------ RULES:\nRule: ('beer',) " - expected += "==> ('rice',) , 0.667\nRule: ('rice',) ==> " - expected += "('beer',) , 1.000\n" + expected += "==> ('rice',) , 0.667, 1.333\nRule: ('rice',) ==> " + expected += "('beer',) , 1.000, 1.333\n" self.assertEqual(fake_output.getvalue(), expected) def test_run_apriori_should_get_items_and_rules(self): @@ -211,8 +211,8 @@ def test_run_apriori_should_get_items_and_rules(self): self.assertEqual(items, expected) expected = [ - ((('beer',), ('rice',)), 0.6666666666666666), - ((('rice',), ('beer',)), 1.0) + ((('beer',), ('rice',)), 0.6666666666666666, 1.3333333333333333), + ((('rice',), ('beer',)), 1.0, 1.3333333333333333) ] self.assertEqual(rules, expected)