Skip to content

Commit

Permalink
chore(internal): add utils methods for parsing SSE (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] authored and stainless-ci-bot committed Mar 4, 2025
1 parent 17b99c3 commit 0bfd6f6
Show file tree
Hide file tree
Showing 11 changed files with 506 additions and 105 deletions.
86 changes: 33 additions & 53 deletions lib/modern_treasury/base_client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class << self
# @raise [ArgumentError]
#
def validate!(req)
keys = [:method, :path, :query, :headers, :body, :unwrap, :page, :model, :options]
keys = [:method, :path, :query, :headers, :body, :unwrap, :page, :stream, :model, :options]
case req
in Hash
req.each_key do |k|
Expand Down Expand Up @@ -201,6 +201,8 @@ def initialize(
#
# @option req [Class, nil] :page
#
# @option req [Class, nil] :stream
#
# @option req [ModernTreasury::Converter, Class, nil] :model
#
# @param opts [Hash{Symbol=>Object}] .
Expand Down Expand Up @@ -319,7 +321,7 @@ def initialize(
# @param send_retry_header [Boolean]
#
# @raise [ModernTreasury::APIError]
# @return [Array(Net::HTTPResponse, Enumerable)]
# @return [Array(Integer, Net::HTTPResponse, Enumerable)]
#
private def send_request(request, redirect_count:, retry_count:, send_retry_header:)
url, headers, max_retries, timeout = request.fetch_values(:url, :headers, :max_retries, :timeout)
Expand All @@ -342,7 +344,7 @@ def initialize(

case status
in ..299
[response, stream]
[status, response, stream]
in 300..399 if redirect_count >= self.class::MAX_REDIRECTS
message = "Failed to complete the request within #{self.class::MAX_REDIRECTS} redirects."

Expand All @@ -360,13 +362,15 @@ def initialize(
)
in ModernTreasury::APIConnectionError if retry_count >= max_retries
raise status
in (400..) if retry_count >= max_retries || (response && !self.class.should_retry?(
status,
headers: response
))
in (400..) if retry_count >= max_retries || !self.class.should_retry?(status, headers: response)
decoded = ModernTreasury::Util.decode_content(response, stream: stream, suppress_error: true)

stream.each { srv_fault ? break : next }
if srv_fault
ModernTreasury::Util.close_fused!(stream)
else
stream.each { next }
end

raise ModernTreasury::APIStatusError.for(
url: url,
status: status,
Expand All @@ -377,7 +381,11 @@ def initialize(
in (400..) | ModernTreasury::APIConnectionError
delay = retry_delay(response, retry_count: retry_count)

stream&.each { srv_fault ? break : next }
if srv_fault
ModernTreasury::Util.close_fused!(stream)
else
stream&.each { next }
end
sleep(delay)

send_request(
Expand All @@ -389,48 +397,6 @@ def initialize(
end
end

# @private
#
# @param req [Hash{Symbol=>Object}] .
#
# @option req [Symbol] :method
#
# @option req [String, Array<String>] :path
#
# @option req [Hash{String=>Array<String>, String, nil}, nil] :query
#
# @option req [Hash{String=>String, Integer, Array<String, Integer, nil>, nil}, nil] :headers
#
# @option req [Object, nil] :body
#
# @option req [Symbol, nil] :unwrap
#
# @option req [Class, nil] :page
#
# @option req [ModernTreasury::Converter, Class, nil] :model
#
# @option req [ModernTreasury::RequestOptions, Hash{Symbol=>Object}, nil] :options
#
# @param headers [Hash{String=>String}, Net::HTTPHeader]
#
# @param stream [Enumerable]
#
# @return [Object]
#
private def parse_response(req, headers:, stream:)
decoded = ModernTreasury::Util.decode_content(headers, stream: stream)
unwrapped = ModernTreasury::Util.dig(decoded, req[:unwrap])

case [req[:page], req.fetch(:model, ModernTreasury::Unknown)]
in [Class => page, _]
page.new(client: self, req: req, headers: headers, unwrapped: unwrapped)
in [nil, Class | ModernTreasury::Converter => model]
ModernTreasury::Converter.coerce(model, unwrapped)
in [nil, nil]
unwrapped
end
end

# Execute the request specified by `req`. This is the method that all resource
# methods call into.
#
Expand All @@ -450,6 +416,8 @@ def initialize(
#
# @option req [Class, nil] :page
#
# @option req [Class, nil] :stream
#
# @option req [ModernTreasury::Converter, Class, nil] :model
#
# @option req [ModernTreasury::RequestOptions, Hash{Symbol=>Object}, nil] :options
Expand All @@ -459,19 +427,31 @@ def initialize(
#
def request(req)
self.class.validate!(req)
model = req.fetch(:model) { ModernTreasury::Unknown }
opts = req[:options].to_h
ModernTreasury::RequestOptions.validate!(opts)
request = build_request(req.except(:options), opts)
url = request.fetch(:url)

# Don't send the current retry count in the headers if the caller modified the header defaults.
send_retry_header = request.fetch(:headers)["x-stainless-retry-count"] == "0"
response, stream = send_request(
status, response, stream = send_request(
request,
redirect_count: 0,
retry_count: 0,
send_retry_header: send_retry_header
)
parse_response(req, headers: response, stream: stream)

decoded = ModernTreasury::Util.decode_content(response, stream: stream)
case req
in { stream: Class => st }
st.new(model: model, url: url, status: status, response: response, messages: decoded)
in { page: Class => page }
page.new(client: self, req: req, headers: response, unwrapped: decoded)
else
unwrapped = ModernTreasury::Util.dig(decoded, req[:unwrap])
ModernTreasury::Converter.coerce(model, unwrapped)
end
end

# @return [String]
Expand Down
5 changes: 3 additions & 2 deletions lib/modern_treasury/errors.rb
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ class APIStatusError < ModernTreasury::APIError
# @param body [Object, nil]
# @param request [nil]
# @param response [nil]
# @param message [String, nil]
#
# @return [ModernTreasury::APIStatusError]
#
def self.for(url:, status:, body:, request:, response:)
kwargs = {url: url, status: status, body: body, request: request, response: response}
def self.for(url:, status:, body:, request:, response:, message: nil)
kwargs = {url: url, status: status, body: body, request: request, response: response, message: message}

case status
in 400
Expand Down
23 changes: 14 additions & 9 deletions lib/modern_treasury/pooled_net_requester.rb
Original file line number Diff line number Diff line change
Expand Up @@ -131,33 +131,38 @@ def execute(request)
req = self.class.build_request(request)

eof = false
finished = false
enum = Enumerator.new do |y|
with_pool(url) do |conn|
next if finished

self.class.calibrate_socket_timeout(conn, deadline)
conn.start unless conn.started?

self.class.calibrate_socket_timeout(conn, deadline)
conn.request(req) do |rsp|
y << [conn, rsp]
break if finished

rsp.read_body do |bytes|
y << bytes
break if finished

self.class.calibrate_socket_timeout(conn, deadline)
end
eof = true
end
end
end

# need to protect the `Enumerator` against `#.rewind`
fused = false
conn, response = enum.next
body = Enumerator.new do |y|
next if fused

fused = true
loop { y << enum.next }
ensure
conn.finish if !eof && conn.started?
body = ModernTreasury::Util.fused_enum(enum) do
finished = true
tap do
enum.next
rescue StopIteration

Check warning on line 163 in lib/modern_treasury/pooled_net_requester.rb

View workflow job for this annotation

GitHub Actions / lint

Lint/SuppressedException: Do not suppress exceptions.
end
conn.finish if !eof && conn&.started?
end
[response, (response.body = body)]
end
Expand Down
118 changes: 118 additions & 0 deletions lib/modern_treasury/util.rb
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ def encode_content(headers, body)
#
def decode_content(headers, stream:, suppress_error: false)
case headers["content-type"]
in %r{^text/event-stream}
lines = enum_lines(stream)
parse_sse(lines)
in %r{^application/json}
json = stream.to_a.join
begin
Expand All @@ -512,6 +515,121 @@ def decode_content(headers, stream:, suppress_error: false)
end
end
end

class << self
# @private
#
# https://doc.rust-lang.org/std/iter/trait.FusedIterator.html
#
# @param enum [Enumerable]
# @param close [Proc]
#
# @return [Enumerable]
#
def fused_enum(enum, &close)
fused = false
iter = Enumerator.new do |y|
next if fused

fused = true
loop { y << enum.next }
ensure
close&.call
close = nil
end

iter.define_singleton_method(:rewind) do
fused = true
self
end
iter
end

# @private
#
# @param enum [Enumerable, nil]
#
def close_fused!(enum)
return unless enum.is_a?(Enumerator)

# rubocop:disable Lint/UnreachableLoop
enum.rewind.each { break }
# rubocop:enable Lint/UnreachableLoop
end

# @private
#
# @param enum [Enumerable, nil]
# @param blk [Proc]
#
def chain_fused(enum, &blk)
iter = Enumerator.new { blk.call(_1) }
fused_enum(iter) { close_fused!(enum) }
end
end

class << self
# @private
#
# @param enum [Enumerable]
#
# @return [Enumerable]
#
def enum_lines(enum)
chain_fused(enum) do |y|
buffer = String.new
enum.each do |row|
buffer << row
while (idx = buffer.index("\n"))
y << buffer.slice!(..idx)
end
end
y << buffer unless buffer.empty?
end
end

# @private
#
# https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
#
# @param lines [Enumerable]
#
# @return [Hash{Symbol=>Object}]
#
def parse_sse(lines)
chain_fused(lines) do |y|

Check warning on line 600 in lib/modern_treasury/util.rb

View workflow job for this annotation

GitHub Actions / lint

Metrics/BlockLength: Block has too many lines. [27/25]
blank = {event: nil, data: nil, id: nil, retry: nil}
current = {}

lines.each do |line|
case line.strip
in ""
next if current.empty?
y << {**blank, **current}
current = {}
in /^:/
next
in /^([^:]+):\s?(.*)$/
_, field, value = Regexp.last_match.to_a
case field
in "event"
current.merge!(event: value)
in "data"
(current[:data] ||= String.new) << value << "\n"
in "id" unless value.include?("\0")
current.merge!(id: value)
in "retry" if /^\d+$/ =~ value
current.merge!(retry: Integer(value))
else
end
else
end
end

y << {**blank, **current} unless current.empty?
end
end
end
end

# rubocop:enable Metrics/ModuleLength
Expand Down
Loading

0 comments on commit 0bfd6f6

Please sign in to comment.