-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathmain.rs
More file actions
112 lines (99 loc) · 3.72 KB
/
main.rs
File metadata and controls
112 lines (99 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use discorec::RecommenderBuilder;
use pgvector::Vector;
use postgres::{Client, NoTls};
use std::collections::HashMap;
use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::string::FromUtf8Error;
/// Trains a 20-factor matrix-factorization recommender on MovieLens 100k,
/// stores the user and movie factor vectors in Postgres (pgvector), and prints
/// item-based and user-based recommendations computed with vector-distance
/// queries.
///
/// Requires `MOVIELENS_100K_PATH` (directory containing `u.item` and `u.data`)
/// and `USER` (used as the Postgres role) to be set, and a `pgvector_example`
/// database on localhost.
fn main() -> Result<(), Box<dyn Error>> {
    // https://grouplens.org/datasets/movielens/100k/
    let movielens_path = std::env::var("MOVIELENS_100K_PATH").or(Err("Set MOVIELENS_100K_PATH"))?;

    let mut client = Client::configure()
        .host("localhost")
        .dbname("pgvector_example")
        .user(std::env::var("USER").or(Err("Set USER"))?.as_str())
        .connect(NoTls)?;

    client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
    client.execute("DROP TABLE IF EXISTS users", &[])?;
    client.execute("DROP TABLE IF EXISTS movies", &[])?;
    client.execute(
        "CREATE TABLE users (id integer PRIMARY KEY, factors vector(20))",
        &[],
    )?;
    client.execute(
        "CREATE TABLE movies (name text PRIMARY KEY, factors vector(20))",
        &[],
    )?;

    let data = load_movielens(Path::new(&movielens_path))?;
    let recommender = RecommenderBuilder::new().factors(20).fit_explicit(&data);

    // Insert every factor row inside a single transaction using statements
    // prepared once: one commit instead of one per row, and any failure rolls
    // back all inserted rows instead of leaving the tables half-filled.
    let mut transaction = client.transaction()?;

    let insert_user = transaction.prepare("INSERT INTO users (id, factors) VALUES ($1, $2)")?;
    for user_id in recommender.user_ids() {
        let factors = Vector::from(
            recommender
                .user_factors(user_id)
                .expect("ids returned by user_ids() always have factors")
                .to_vec(),
        );
        transaction.execute(&insert_user, &[&user_id, &factors])?;
    }

    let insert_movie =
        transaction.prepare("INSERT INTO movies (name, factors) VALUES ($1, $2)")?;
    for item_id in recommender.item_ids() {
        let factors = Vector::from(
            recommender
                .item_factors(item_id)
                .expect("ids returned by item_ids() always have factors")
                .to_vec(),
        );
        transaction.execute(&insert_movie, &[&item_id, &factors])?;
    }

    transaction.commit()?;

    // Movies closest to the given movie by cosine distance (`<=>`).
    let movie = "Star Wars (1977)";
    println!("Item-based recommendations for {}", movie);
    for row in client.query("SELECT name FROM movies WHERE name != $1 ORDER BY factors <=> (SELECT factors FROM movies WHERE name = $1) LIMIT 5", &[&movie])? {
        let name: &str = row.get(0);
        println!("- {}", name);
    }

    // Movies with the largest inner product with the user's factors
    // (`<#>` is the negative inner product, so ascending order = best first).
    let user_id = 123;
    println!("\nUser-based recommendations for user {}", user_id);
    for row in client.query("SELECT name FROM movies ORDER BY factors <#> (SELECT factors FROM users WHERE id = $1) LIMIT 5", &[&user_id])? {
        let name: &str = row.get(0);
        println!("- {}", name);
    }

    Ok(())
}
/// Loads MovieLens 100k ratings as `(user_id, movie_title, rating)` triples.
///
/// Reads `u.item` under `path` (pipe-separated, ISO-8859-1 encoded) to build a
/// movie-id → title map. Then it reads `u.data` (tab-separated: user id, item
/// id, rating, ...) and replaces each item id with its title. Blank lines in
/// either file are skipped.
///
/// # Errors
///
/// Returns an error if either file cannot be opened or read, or if a line is
/// missing a field. It also fails when a user id or rating does not parse, or
/// when a rating refers to an item id that is absent from `u.item`.
// The plain tuple return type is intentional for this small example; a named
// alias would add little, so the complexity lint is silenced.
#[allow(clippy::type_complexity)]
fn load_movielens(path: &Path) -> Result<Vec<(i32, String, f32)>, Box<dyn Error>> {
    // read movies
    let mut movies = HashMap::with_capacity(2000);
    let movies_file = File::open(path.join("u.item"))?;
    let rdr = BufReader::new(movies_file);
    // Split on raw bytes: the file is Latin-1, so `lines()` (UTF-8) would fail.
    for line in rdr.split(b'\n') {
        let line = line?;
        if line.is_empty() {
            continue;
        }
        let line = convert_to_utf8(&line)?;
        let mut row = line.split('|');
        let id = row.next().ok_or("u.item: missing movie id")?.to_string();
        let name = row
            .next()
            .ok_or_else(|| format!("u.item: missing title for movie id {}", id))?
            .to_string();
        movies.insert(id, name);
    }

    // read ratings and create dataset
    let mut data = Vec::with_capacity(100000);
    let ratings_file = File::open(path.join("u.data"))?;
    let rdr = BufReader::new(ratings_file);
    for line in rdr.lines() {
        let line = line?;
        if line.is_empty() {
            continue;
        }
        let mut row = line.split('\t');
        let user_id: i32 = row.next().ok_or("u.data: missing user id")?.parse()?;
        let movie_id = row.next().ok_or("u.data: missing item id")?;
        let item_id = movies
            .get(movie_id)
            .cloned()
            .ok_or_else(|| format!("u.data: unknown item id {}", movie_id))?;
        let rating: f32 = row.next().ok_or("u.data: missing rating")?.parse()?;
        data.push((user_id, item_id, rating));
    }
    Ok(data)
}
/// Decodes ISO-8859-1 (Latin-1) bytes into a UTF-8 `String`.
///
/// Every Latin-1 byte maps to the Unicode code point with the same value
/// (U+0000..=U+00FF), so `char::from(u8)` is an exact decoder. Hard-coding a
/// `0xC3` lead byte is wrong here. It only covers bytes `0xC0..=0xFF`, while
/// bytes `0x80..=0xBF` (e.g. `©`, `°`) need a `0xC2` lead byte.
///
/// # Errors
///
/// Never fails. The `Result` return type is kept so existing callers using `?`
/// continue to compile.
fn convert_to_utf8(s: &[u8]) -> Result<String, FromUtf8Error> {
    Ok(s.iter().copied().map(char::from).collect())
}