Guessing subreddits with the Prediction API
Posted by Nick Johnson | Filed under python, app-engine, prediction-api, google-storage
Edit: Now with a live demo!
I've written before about the new BigQuery and Prediction APIs, and promised to demonstrate them. Let's take a look at the Prediction API first.
The Prediction API, as I've explained, does a restricted form of machine learning, as a web service. Currently, it supports categorizing textual and numeric data into a preset list of categories. The example given in the talk - language detection - is a good one, but I wanted to come up with something new. A few ideas presented themselves:
- Training on movie/book reviews to try and predict the score given based on the text
- Training on product descriptions to try and predict their rating
- Training on Reddit submissions to try and predict the subreddit a new submission belongs in
All three have promise, but the first could suffer from the fact that the Prediction API, as it currently stands, doesn't understand any relationship between categories - it would have no way to know that the '5 star' rating tag is 'closer to' the '4 star' tag than to the '1 star' tag. The second seems very ambitious, and it's not clear there's enough information in a product description to do it. The third one, though, seemed just right.
With that in mind, I started collecting data on reddit submissions. Using reddit's API, I collected every submission between 2010-06-02 13:39:31 UTC and 2010-06-09 23:11:34 UTC, with this simple script:
import logging
import time
import urllib2

import simplejson

REDDIT_URL = 'http://www.reddit.com/r/all/new/.json'


class RedditProcessor(object):
    def __init__(self, outfile):
        self.seen = set()
        self.outfile = outfile
        self.interval = 120

    def run(self):
        self.init()
        while True:
            data = self.get_reddit_data()
            if data:
                self.process_reddit_data(data)
            time.sleep(self.interval)

    def init(self):
        # Load the IDs of entries we've already stored, so restarts
        # don't write duplicates.
        try:
            fh = open(self.outfile, 'r')
            self.seen.update(simplejson.loads(x)['id'] for x in fh)
            fh.close()
        except IOError:
            pass  # No dump file yet; start from scratch.
        logging.info("Read %d IDs", len(self.seen))
        self.writer = open(self.outfile, 'a')

    def get_reddit_data(self):
        try:
            request = urllib2.urlopen(REDDIT_URL)
            return simplejson.loads(request.read())
        except urllib2.URLError:
            logging.exception("Request failed")
            return None

    def process_reddit_data(self, data):
        if 'data' not in data or 'children' not in data['data']:
            logging.warn("Data does not contain expected keys: %r", data)
            return
        try:
            # Append any entries we haven't seen before, one JSON
            # object per line.
            num_written = 0
            for entry in data['data']['children']:
                entry_data = entry['data']
                if entry_data['id'] not in self.seen:
                    simplejson.dump(entry_data, self.writer)
                    self.writer.write('\n')
                    num_written += 1
                    self.seen.add(entry_data['id'])
            logging.info("Wrote %d new entries", num_written)
        except Exception:
            logging.exception("Error processing entries")


def main():
    logging.basicConfig(level=logging.DEBUG)
    processor = RedditProcessor("reddit_dump.json")
    processor.run()


if __name__ == '__main__':
    main()
In total, I collected a little over 75MB of JSON-encoded data, comprising 72,986 submissions. I then determined the 20 subreddits with the most submissions over that time, and generated a training dataset from just the submissions to those subreddits. This subset made up 42,753 submissions, or about 58% of the original. Submissions were randomly split into either the training set (98%) or the validation set (2%):
import csv
import random

import simplejson

# Count submissions per subreddit.
fh = open('reddit_dump.json', 'r')
subreddits = {}
for line in fh:
    data = simplejson.loads(line)
    subreddits[data['subreddit']] = subreddits.get(data['subreddit'], 0) + 1

# Pick the 20 subreddits with the most submissions.
top20 = sorted(subreddits.items(), key=lambda x: x[1], reverse=True)[:20]
print top20
top20_set = set(x[0] for x in top20)

# Split their submissions 98/2 into training and validation sets.
fh = open('reddit_dump.json', 'r')
training = csv.writer(open('reddit_training.csv', 'w'))
validation = csv.writer(open('reddit_validation.csv', 'w'))
for line in fh:
    data = simplejson.loads(line)
    if data['subreddit'] not in top20_set:
        continue
    row = (data['subreddit'].encode('utf-8'),
           data['title'].encode('utf-8'),
           data['domain'].encode('utf-8'))
    if random.random() >= 0.98:
        validation.writerow(row)
    else:
        training.writerow(row)
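Each row of the resulting CSVs puts the label - the subreddit - first, followed by the two text features the model trains on: the title and the domain. A couple of hypothetical rows, purely for illustration:

politics,"Senate debates financial reform bill",nytimes.com
pics,"My cat fell asleep on my keyboard",i.imgur.com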
The top 20 reddits, incidentally, are:
Subreddit | Submissions
---|---
reddit.com | 14578
pics | 4157
AskReddit | 3375
reportthespammers | 3258
politics | 3162
funny | 2176
WTF | 1773
gaming | 1367
worldnews | 938
videos | 849
atheism | 834
Music | 833
technology | 732
trees | 703
comics | 639
nsfw | 611
circlejerk | 600
news | 567
environment | 537
DoesAnybodyElse | 537
Next, I uploaded the new dataset to the Google Storage service, using the 'gsutil' tool.
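Assuming a bucket named reddit-dump - the name the API calls below reference - the upload is a single copy, something like:

$ gsutil cp reddit_training.csv gs://reddit-dump/reddit_training.csv

With the data in Google Storage, I followed the getting started instructions for the Prediction API to start training a model: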
$ curl -X POST -H "Content-Type:application/json" -d "{\"data\":{}}" -H "Authorization: GoogleLogin auth=$XAPI_AUTH" https://www.googleapis.com/prediction/v1/training?data=reddit-dump%24reddit_training.csv
{"data":{"data":"reddit-dump$reddit_training.csv"}}
$ curl -H "Authorization: GoogleLogin auth=$XAPI_AUTH" https://www.googleapis.com/prediction/v1/training/reddit-dump%2freddit_training.csv
{"data":{"data":"reddit-dump/reddit_training.csv","modelinfo":"Training has not completed."}}
Once training has completed, we can see the Prediction API's own estimate of its accuracy:
$ curl -H "Authorization: GoogleLogin auth=$XAPI_AUTH" https://www.googleapis.com/prediction/v1/training/reddit-dump%2freddit_training_2.csv
{"data":{"data":"reddit-dump/reddit_training_2.csv","modelinfo":"estimated accuracy: 0.61"}}
That's not bad, though perhaps not as stellar as we might have hoped. For perspective, if we simply picked the most popular category (reddit.com) every time, we'd only get it right about 34% of the time - reddit.com accounts for 14,578 of the 42,753 submissions. Let's run our own test with the data we set aside for validation, though, and see how it goes. Here's our test script:
import csv
import logging
import simplejson
import sys
import urllib2


def predict(target, auth, text):
    # Ask the trained model to categorize a single item.
    request_data = {
        'data': {
            'input': {
                'text': text,
            }
        }
    }
    request = urllib2.Request(
        'https://www.googleapis.com/prediction/v1/training/%s/predict' % target,
        data=simplejson.dumps(request_data),
        headers={
            "Authorization": "GoogleLogin auth=%s" % auth,
            "Content-Type": "application/json",
        })
    response = urllib2.urlopen(request)
    response_data = simplejson.load(response)
    return response_data['data']['output']['output_label']


def main(args):
    infile, auth, target = args[1:4]
    reader = csv.reader(open(infile))
    count = 0
    correct = 0
    for tag, text, domain in reader:
        count += 1
        # Retry transient HTTP errors a few times before giving up
        # on this row.
        result = None
        retries = 0
        while retries < 10:
            try:
                result = predict(target, auth, [text.strip('\r\n'), domain])
                break
            except urllib2.HTTPError:
                retries += 1
        if result == tag:
            correct += 1
        else:
            print ("Incorrectly predicted %r (%s) as %s (should be %s)"
                   % (text, domain, result, tag))
    print "%d of %d predicted correctly." % (correct, count)


if __name__ == '__main__':
    main(sys.argv)
And the output...
$ python training_check.py reddit_validation.csv $XAPI_AUTH reddit-dump%2freddit_training_2.csv
...
484 of 857 predicted correctly.
56% - not far off the system's own estimate. Let's take a look at some of the prediction failures, though:
Incorrectly predicted 'Small Businesses Still Worried About Reform Bill' (nytimes.com) as politics (should be reddit.com)
Incorrectly predicted 'Seriously Reddit?' (imgur.com) as pics (should be reddit.com)
Incorrectly predicted '3 YEARS of crappy skin spam' (reddit.com) as reportthespammers (should be reddit.com)
Incorrectly predicted ' President George W. Bush A.K.A. The Decider ("I\'m not much of an E-mailing kind of guy") - Joins Facebook.' (news.bbc.co.uk) as politics (should be reddit.com)
Incorrectly predicted 'Meet Leroy Stick, The Man Behind @BPGlobalPR' (gizmodo.com) as funny (should be reddit.com)
Incorrectly predicted "I don't understand why this is so funny...\n[N S F Charles Guiteau]" (imgur.com) as pics (should be reddit.com)
Incorrectly predicted 'Freaky New Dead Space 2 Screenshots' (dasreviews.com) as gaming (should be reddit.com)
Incorrectly predicted 'Once upon a time...' (i.imgur.com) as pics (should be reddit.com)
Incorrectly predicted 'The Flotilla Choir' (youtube.com) as funny (should be reddit.com)
Incorrectly predicted '15 Million Unemployed and 41,000 non census jobs created last month.' (cnbc.com) as politics (should be reddit.com)
Incorrectly predicted 'EU ruling prompts calls for overhaul of gambling laws' (telegraph.co.uk) as worldnews (should be reddit.com)
Incorrectly predicted 'I woke up this morning to find my car had been towed out of my driveway [pic]' (imgur.com) as pics (should be reddit.com)
Incorrectly predicted 'Old White Guy Late Night Talk Show' (buzzfeed.com) as funny (should be reddit.com)
Hm. An awful lot of these look correct, but were actually posted to reddit.com. The reverse is true, too:
Incorrectly predicted 'Coulomb Details Huge Electric Car Charging Infrastructure Plans' (earthtechling.com) as reddit.com (should be technology)
Incorrectly predicted 'AT&T Announces Tethering Support For iPhone 4G / HD; Updates Data Plans' (mygadgetnews.com) as reddit.com (should be technology)
Incorrectly predicted '"And those are the facts" Same faulty arguments from all theists.' (youtube.com) as reddit.com (should be atheism)
Incorrectly predicted 'For VPS users! A great link!' (gallantfx.com) as reddit.com (should be technology)
Incorrectly predicted '\xe2\x80\x9cIsraeli commandos had paintball guns\xe2\x80\x9d \xe2\x80\x93 Israeli Ambassador' (rt.com) as reddit.com (should be worldnews)
Incorrectly predicted 'Cool, Rayguns!' (cnn.com) as reddit.com (should be technology)
Incorrectly predicted 'India vows to sabotage ACTA.' (arstechnica.com) as reddit.com (should be technology)
Incorrectly predicted 'Spy on my dog!' (ustream.tv) as reddit.com (should be funny)
Incorrectly predicted 'California poised to OK supertoxic pesticide' (sfgate.com) as reddit.com (should be environment)
Incorrectly predicted 'Voters of Reykjavik know where their towels are' (news.bbc.co.uk) as reddit.com (should be WTF)
And, of course, there are ordinary errors too:
Incorrectly predicted "WHAT'S UPPPPPPPP?" (i.imgur.com) as pics (should be funny) Incorrectly predicted 'Google Chrome extensions the new facebook apps when it comes to privacy?' (imgur.com) as technology (should be pics) Incorrectly predicted 'DAE actually understand the words to the songs "Smells Like Teen Spirit" or "Song 2" ?' (self.DoesAnybodyElse) as politics (should be DoesAnybodyElse) Incorrectly predicted 'BP May Sell Prudhoe Bay Stake as Spill Costs Mount ' (bloomberg.com) as politics (should be news) Incorrectly predicted 'Bananas are pervert ' (i.imgur.com) as pics (should be funny) Incorrectly predicted 'New BP Logo...anyone else?' (imgur.com) as pics (should be environment) Incorrectly predicted 'Petition for Net Neutrality - 74 Democrats sold you out to AT&T, Verizon and Comcast ' (self.technology) as politics (should be technology) Incorrectly predicted 'Neighborhood Watch' (i.imgur.com) as pics (should be funny) Incorrectly predicted 'U MAY BE A GEEK IF...' (streetwalkermag.ca) as funny (should be technology) Incorrectly predicted 'Ron Paul says criticism of Obama on oil spill is wrong ' (content.usatoday.com) as news (should be politics) Incorrectly predicted 'Watching CNN and the cap appears to be in place over the well' (self.worldnews) as gaming (should be worldnews)
Based on this, it looks to me like we can talk about three distinct types of categorization errors:
- reddit.com vs. other subreddits. This occurs because the reddit.com subreddit gets an amazingly diverse set of submissions, and doesn't really have its own 'type' of post the way most of the other subreddits do.
- Ambiguous categorization. Many pics are also funny, and a lot of news is also politics. In some cases, you could even argue that the submitter got it wrong, and the learning API got it right!
- Ordinary, legitimate errors. The last one above is a good example of that.
There's not much we can do about category 3, and it's debatable whether anything needs to be done about category 2, but let's see what we can do about category 1 by excluding the reddit.com subreddit from consideration and concentrating on everything else.
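The filtering itself is only a few lines; a minimal sketch, reusing the CSV files generated earlier to produce the 'nodotcom' files used below, might look like this:

import csv

# Copy the training and validation sets, dropping rows labelled reddit.com.
for infile, outfile in [('reddit_training.csv', 'reddit_training_nodotcom.csv'),
                        ('reddit_validation.csv', 'reddit_validation_nodotcom.csv')]:
    writer = csv.writer(open(outfile, 'w'))
    for row in csv.reader(open(infile)):
        if row[0] != 'reddit.com':
            writer.writerow(row)

After uploading the filtered training set, we train and test as before: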
$ curl -X POST -H "Content-Type:application/json" -H "Authorization: GoogleLogin auth=$XAPI_AUTH" -d "{\"data\":{}}" https://www.googleapis.com/prediction/v1/training?data=reddit-dump%2freddit_training_nodotcom.csv
...
$ curl -H "Authorization: GoogleLogin auth=$XAPI_AUTH" https://www.googleapis.com/prediction/v1/training/reddit-dump%2freddit_training_nodotcom.csv
{"data":{"data":"reddit-dump/reddit_training_nodotcom.csv","modelinfo":"estimated accuracy: 0.63"}}
$ python training_check.py reddit_validation_nodotcom.csv $XAPI_AUTH reddit-dump%2freddit_training_nodotcom.csv
...
375 of 581 predicted correctly.
The API has modestly increased its estimate to 63% accuracy; our test supports that with a more significant increase to 64% (from 56%). Of the remaining miscategorizations, most fall into category 2 - here are the last 10 miscategorized results, unedited - judge for yourself:
Incorrectly predicted "As rumours swell that the government staged 7/7, victims' relatives call for a proper inquiry" (dailymail.co.uk) as funny (should be news) Incorrectly predicted 'Wow.' (pici.se) as WTF (should be nsfw) Incorrectly predicted 'Gabe: What if Coruscant had a newspaper? What would their stupid political cartoons look like?' (penny-arcade.com) as funny (should be comics) Incorrectly predicted "Abramoff released from Md. prison to halfway house. Can we have a pool to see how long before he's back at work with the Republicans?" (google.com) as worldnews (should be politics) Incorrectly predicted "Lady Gaga 'Alejandro' music video released - Very, very sexy. Potentially [NSFW]." (newsfeed.time.com) as funny (should be WTF) Incorrectly predicted 'Official Decree - Take Notice.' (farm2.static.flickr.com) as science (should be funny) Incorrectly predicted 'Philosophy of Ghost In the Shell' (video.google.com) as WTF (should be funny) Incorrectly predicted 'Taco Bell Petitions National Reserve To Circulate More $2 Bills' (knucklesunited.com) as politics (should be WTF) Incorrectly predicted 'Seriously, how is Japan not fat from eating this?!' (i.imgur.com) as pics (should be WTF) Incorrectly predicted 'Cuba declines UN mission related to torture' (wireupdate.com) as politics (should be worldnews)
In conclusion, even on a fairly subjective categorization task like this, with a fairly large number of categories (20) - and hence many more ways to be wrong - the Prediction API performs fairly well, with most of its errors being ones humans would plausibly make too. Determining the actual error rate - by going through the errors and judging which ones are legitimate - would be an interesting exercise.
If there's interest, I may put up an App Engine app that allows you to query for Reddit predictions yourself - but right now, this blog post is already 2 days late, so it'll have to wait. Edit: Live demo now, er, live! Try it out here.