Skip to content

Commit

Permalink
Add max_token_id option
Browse files Browse the repository at this point in the history
  • Loading branch information
ksew1 committed Jan 17, 2025
1 parent 9d8aeb4 commit fe6c835
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
40 changes: 33 additions & 7 deletions lib/scholar/feature_extraction/count_vectorizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ defmodule Scholar.FeatureExtraction.CountVectorizer do
"""
import Nx.Defn

# Compile-time option schema for fit_transform/2, validated with NimbleOptions.
# Declared as a plain binding first so NimbleOptions.new!/1 can compile it once
# into the @opts_schema module attribute below.
opts_schema = [
max_token_id: [
type: :pos_integer,
required: true,
doc: ~S"""
Maximum token id in the input tensor.
"""
]
]

# Pre-compiled schema; validation failures raise NimbleOptions.ValidationError.
@opts_schema NimbleOptions.new!(opts_schema)

@doc """
Generates a count matrix where each row corresponds to a document in the input corpus, and each column corresponds to a unique token in the vocabulary of the corpus.
Expand All @@ -13,32 +25,46 @@ defmodule Scholar.FeatureExtraction.CountVectorizer do
The same number represents the same token in the vocabulary. Tokens should start from 0 and be consecutive. Negative values are ignored, making them suitable for padding.
## Options
#{NimbleOptions.docs(@opts_schema)}
## Examples
iex> t = Nx.tensor([[0, 1, 2], [1, 3, 4]])
iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t)
iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t, max_token_id: Scholar.FeatureExtraction.CountVectorizer.max_token_id(t))
Nx.tensor([
[1, 1, 1, 0, 0],
[0, 1, 0, 1, 1]
])
With padding:
iex> t = Nx.tensor([[0, 1, -1], [1, 3, 4]])
iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t)
iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t, max_token_id: Scholar.FeatureExtraction.CountVectorizer.max_token_id(t))
Nx.tensor([
[1, 1, 0, 0, 0],
[0, 1, 0, 1, 1]
])
"""
deftransform fit_transform(tensor) do
max_index = tensor |> Nx.reduce_max() |> Nx.add(1) |> Nx.to_number()
opts = [max_index: max_index]
deftransform fit_transform(tensor, opts \\ []) do
  # Validate options eagerly (raises on bad input) before entering defn code.
  validated_opts = NimbleOptions.validate!(opts, @opts_schema)
  fit_transform_n(tensor, validated_opts)
end

fit_transform_n(tensor, opts)
@doc """
Computes the `max_token_id` option from the given tensor.

Returns the largest value found in the tensor as a plain integer; negative
entries (padding) never exceed valid token ids, so they do not affect the result.

## Examples

    iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
    iex> Scholar.FeatureExtraction.CountVectorizer.max_token_id(t)
    2
"""
def max_token_id(tensor) do
  Nx.to_number(Nx.reduce_max(tensor))
end

defnp fit_transform_n(tensor, opts) do
check_for_rank(tensor)
counts = Nx.broadcast(0, {Nx.axis_size(tensor, 0), opts[:max_index]})
counts = Nx.broadcast(0, {Nx.axis_size(tensor, 0), opts[:max_token_id] + 1})

{_, counts} =
while {{i = 0, tensor}, counts}, Nx.less(i, Nx.axis_size(tensor, 0)) do
Expand Down
28 changes: 24 additions & 4 deletions test/scholar/feature_extraction/count_vectorizer.ex
Original file line number Diff line number Diff line change
@@ -1,32 +1,52 @@
# Test module renamed with the conventional `Test` suffix: the commit's
# rename from BinarizerTest dropped it, leaving a module name that shadows
# an implementation-style namespace instead of marking a test case.
defmodule Scholar.Preprocessing.CountVectorizerTest do
use Scholar.Case, async: true
alias Scholar.FeatureExtraction.CountVectorizer
doctest CountVectorizer

describe "fit_transform" do
test "fit_transform test" do
  # Fixed misspelled binding `tesnsor` -> `tensor`.
  tensor = Nx.tensor([[2, 3, 0], [1, 4, 4]])

  counts =
    CountVectorizer.fit_transform(tensor,
      max_token_id: CountVectorizer.max_token_id(tensor)
    )

  # One row per document, one column per token id 0..4; token 4 occurs twice.
  expected_counts = Nx.tensor([[1, 0, 1, 1, 0], [0, 1, 0, 0, 2]])

  assert counts == expected_counts
end

test "fit_transform test - tensor with padding" do
  tensor = Nx.tensor([[2, 3, 0], [1, 4, -1]])

  counts =
    CountVectorizer.fit_transform(tensor, max_token_id: CountVectorizer.max_token_id(tensor))

  # The -1 entry is padding and must not be counted in any column.
  expected_counts = Nx.tensor([[1, 0, 1, 1, 0], [0, 1, 0, 0, 1]])

  assert counts == expected_counts
end
end

describe "max_token_id" do
  test "max_token_id test" do
    tensor = Nx.tensor([[2, 3, 0], [1, 4, 4]])
    assert CountVectorizer.max_token_id(tensor) == 4
  end

  # Fixed test-name typo: "tes" -> "test".
  test "max_token_id test - tensor with padding" do
    tensor = Nx.tensor([[2, 3, 0], [1, 4, -1]])
    assert CountVectorizer.max_token_id(tensor) == 4
  end
end

describe "errors" do
  test "wrong input rank" do
    # Rank-1 input must raise before any counting happens; a valid
    # max_token_id is supplied so only the rank check can fail.
    assert_raise ArgumentError,
                 "expected tensor to have shape {num_documents, num_tokens}, got tensor with shape: {3}",
                 fn ->
                   CountVectorizer.fit_transform(Nx.tensor([1, 2, 3]), max_token_id: 3)
                 end
  end
end
Expand Down

0 comments on commit fe6c835

Please sign in to comment.