Skip to content

Commit

Permalink
Add the groups parameter to MLXNN.Conv1d (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman authored Oct 17, 2024
1 parent 9a0d256 commit c09a4f4
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions Source/MLXNN/Convolution.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ open class Conv1d: Module, UnaryLayer {
public let weight: MLXArray
public let bias: MLXArray?
public let padding: Int
public let groups: Int
public let stride: Int

/// Applies a 1-dimensional convolution over the multi-channel input sequence.
Expand All @@ -31,26 +32,29 @@ open class Conv1d: Module, UnaryLayer {
/// - kernelSize: size of the convolution filters
/// - stride: stride when applying the filter
/// - padding: many positions to 0-pad the input with
/// - groups: the number of groups for the convolution
/// - bias: if `true` add a learnable bias to the output
public init(
inputChannels: Int,
outputChannels: Int,
kernelSize: Int,
stride: Int = 1,
padding: Int = 0,
groups: Int = 1,
bias: Bool = true
) {
let scale = sqrt(1 / Float(inputChannels * kernelSize))

self.weight = uniform(
low: -scale, high: scale, [outputChannels, kernelSize, inputChannels])
low: -scale, high: scale, [outputChannels, kernelSize, inputChannels / groups])
self.bias = bias ? MLXArray.zeros([outputChannels]) : nil
self.padding = padding
self.groups = groups
self.stride = stride
}

open func callAsFunction(_ x: MLXArray) -> MLXArray {
var y = conv1d(x, weight, stride: stride, padding: padding)
var y = conv1d(x, weight, stride: stride, padding: padding, groups: groups)
if let bias {
y = y + bias
}
Expand Down

0 comments on commit c09a4f4

Please sign in to comment.