-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpgvector.lua
More file actions
138 lines (116 loc) · 3.04 KB
/
pgvector.lua
File metadata and controls
138 lines (116 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
local pgvector = {}

-- vector

--- Metatable shared by every `vector` value this module creates.
-- Its `pgmoon_serialize` hook returns `0` followed by the text literal built
-- by `pgvector.serialize`.
-- NOTE(review): presumably pgmoon calls this hook when the table is used as a
-- query parameter, and the leading 0 is the parameter type OID ("unspecified")
-- — confirm against pgmoon docs.
local vector_mt = {}

function vector_mt.pgmoon_serialize(vec)
  local literal = pgvector.serialize(vec)
  return 0, literal
end
--- Build a `vector` value from a Lua sequence of numbers.
-- Only the 1..n run visited by ipairs is shallow-copied. Elements are not
-- validated here; `pgvector.serialize` checks them when the value is encoded.
-- @tparam {number,...} v source sequence (left unmodified)
-- @treturn table a fresh sequence carrying `vector_mt`
function pgvector.new(v)
  local copy = {}
  for i, x in ipairs(v) do
    copy[i] = x
  end
  return setmetatable(copy, vector_mt)
end
--- Encode a sequence of numbers as a pgvector text literal, e.g. "[1,2,3]".
-- Raises (via assert) if any element reached by ipairs is not a number.
-- @tparam {number,...} v
-- @treturn string
function pgvector.serialize(v)
  for _, element in ipairs(v) do
    assert(type(element) == "number")
  end
  local body = table.concat(v, ",")
  return "[" .. body .. "]"
end
--- Decode a pgvector text literal such as "[1,2,3]" into a `vector` value.
-- Drops the first and last characters (the brackets), splits the remainder
-- on commas and converts each piece with tonumber.
-- @tparam string v
-- @treturn table sequence of numbers carrying `vector_mt`
function pgvector.deserialize(v)
  local inner = string.sub(v, 2, -2)
  local out = {}
  for piece in string.gmatch(inner, "[^,]+") do
    out[#out + 1] = tonumber(piece)
  end
  -- Equivalent to pgvector.new(out), minus the defensive copy.
  return setmetatable(out, vector_mt)
end
-- halfvec
--- Encode a sequence of numbers as a halfvec text literal, e.g. "[1,2,3]".
-- Uses the same bracketed, comma-separated text form as `pgvector.serialize`.
-- Raises (via assert) if any element reached by ipairs is not a number.
-- @tparam {number,...} v
-- @treturn string
local function halfvec_serialize(v)
  for _, element in ipairs(v) do
    assert(type(element) == "number")
  end
  local joined = table.concat(v, ",")
  return string.format("[%s]", joined)
end
--- Metatable shared by every `halfvec` value.
-- Its `pgmoon_serialize` hook returns `0` followed by the halfvec literal.
-- Declared BEFORE halfvec_deserialize on purpose: a Lua local is only in scope
-- after its declaration. If it is declared later, the reference inside
-- halfvec_deserialize resolves to an (undefined, nil) global, and decoded
-- values silently come back without a metatable.
local halfvec_mt = {
  pgmoon_serialize = function(v)
    return 0, halfvec_serialize(v)
  end,
}

--- Decode a halfvec text literal such as "[1,2,3]".
-- Drops the surrounding brackets, splits on commas and converts each piece
-- with tonumber.
-- @tparam string v
-- @treturn table sequence of numbers carrying `halfvec_mt`
local function halfvec_deserialize(v)
  local vec = {}
  for x in string.gmatch(string.sub(v, 2, -2), "[^,]+") do
    table.insert(vec, tonumber(x))
  end
  -- Equivalent to pgvector.halfvec(vec), minus the defensive copy.
  return setmetatable(vec, halfvec_mt)
end
--- Build a `halfvec` value from a Lua sequence of numbers.
-- Only the 1..n run visited by ipairs is shallow-copied. Elements are not
-- validated here; halfvec_serialize checks them when the value is encoded.
-- @tparam {number,...} v source sequence (left unmodified)
-- @treturn table a fresh sequence carrying `halfvec_mt`
function pgvector.halfvec(v)
  local copy = {}
  for i, x in ipairs(v) do
    copy[i] = x
  end
  return setmetatable(copy, halfvec_mt)
end
-- sparsevec
--- Encode a sparse vector as a sparsevec text literal, e.g. "{1:1,3:2}/5".
-- Entries are emitted in ascending index order. `pairs` visits keys in an
-- unspecified order, so without sorting the same vector could produce
-- different literals from run to run.
-- Raises if an index or value cannot be converted with tonumber.
-- @tparam table vec table with `elements` (index -> value) and `dim`
-- @treturn string
local function sparsevec_serialize(vec)
  local entries = {}
  for i, v in pairs(vec["elements"]) do
    entries[#entries + 1] = { index = tonumber(i), value = tonumber(v) }
  end
  table.sort(entries, function(a, b)
    return a.index < b.index
  end)
  local elements = {}
  for n, e in ipairs(entries) do
    elements[n] = e.index .. ":" .. e.value
  end
  return "{" .. table.concat(elements, ",") .. "}/" .. tonumber(vec["dim"])
end
--- Metatable shared by every `sparsevec` value.
-- Its `pgmoon_serialize` hook returns `0` followed by the sparsevec literal.
-- Declared BEFORE sparsevec_deserialize on purpose: a Lua local is only in
-- scope after its declaration. If it is declared later, the reference inside
-- sparsevec_deserialize resolves to an (undefined, nil) global, and decoded
-- values silently come back without a metatable.
local sparsevec_mt = {
  pgmoon_serialize = function(v)
    return 0, sparsevec_serialize(v)
  end,
}

--- Decode a sparsevec text literal such as "{1:1,3:2}/5".
-- The part before "/" is the brace-wrapped, comma-separated list of
-- index:value pairs. The part after "/" is the dimension.
-- @tparam string v
-- @treturn table `{ elements = {[index] = value, ...}, dim = n }` carrying
--   `sparsevec_mt`
local function sparsevec_deserialize(v)
  local m = string.gmatch(v, "[^/]+")
  local elements = {}
  for e in string.gmatch(string.sub(m(), 2, -2), "[^,]+") do
    local mx = string.gmatch(e, "[^:]+")
    local index = tonumber(mx())
    local value = tonumber(mx())
    elements[index] = value
  end
  local vec = {
    elements = elements,
    dim = tonumber(m()),
  }
  return setmetatable(vec, sparsevec_mt)
end
--- Build a `sparsevec` value.
-- Every key and value of `elements` must be a number, and `dim` must be a
-- number (checked with assert). The `elements` table is stored by reference,
-- not copied.
-- @tparam table elements map of index -> value
-- @tparam number dim dimension of the vector
-- @treturn table `{ elements = elements, dim = dim }` carrying `sparsevec_mt`
function pgvector.sparsevec(elements, dim)
  for index, value in pairs(elements) do
    assert(type(index) == "number")
    assert(type(value) == "number")
  end
  assert(type(dim) == "number")
  return setmetatable({ elements = elements, dim = dim }, sparsevec_mt)
end
-- register
--- Register text deserializers for the pgvector types on a pgmoon connection.
-- Looks up the OIDs of `vector`, `halfvec` and `sparsevec` with to_regtype.
-- `vector` is required; `halfvec` and `sparsevec` are registered only when
-- the database reports an OID for them.
-- Raises if the lookup query fails or if the `vector` type is missing.
-- @param pg a connected pgmoon object
function pgvector.setup_vector(pg)
  local res, err = pg:query(
    "SELECT to_regtype('vector')::oid AS vector_oid, to_regtype('halfvec')::oid AS halfvec_oid, to_regtype('sparsevec')::oid AS sparsevec_oid"
  )
  -- pgmoon's query reports failure as `nil, err` (per pgmoon docs). Surface
  -- that message instead of crashing on `nil[1]` with an opaque
  -- "attempt to index a nil value".
  if not res then
    error("pgvector: failed to look up vector types: " .. tostring(err), 2)
  end
  local row = res[1]
  assert(row["vector_oid"], "vector type not found in the database")
  pg:set_type_deserializer(row["vector_oid"], "vector", function(_, v)
    return pgvector.deserialize(v)
  end)
  if row["halfvec_oid"] then
    pg:set_type_deserializer(row["halfvec_oid"], "halfvec", function(_, v)
      return halfvec_deserialize(v)
    end)
  end
  if row["sparsevec_oid"] then
    pg:set_type_deserializer(row["sparsevec_oid"], "sparsevec", function(_, v)
      return sparsevec_deserialize(v)
    end)
  end
end
-- Public API: new, serialize, deserialize, halfvec, sparsevec, setup_vector.
return pgvector