Skip to content

Commit

Permalink
fix: moving key curve validation and checks to Domain for better back…
Browse files Browse the repository at this point in the history
…wards compatibility
  • Loading branch information
elribonazo committed Feb 4, 2025
1 parent c18aaa1 commit 1fccb9c
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 29 deletions.
13 changes: 7 additions & 6 deletions src/apollo/Apollo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
SeedWords,
StorableKey,
KeyRestoration,
isCurve,
} from "../domain";

import { Ed25519PrivateKey } from "./utils/Ed25519PrivateKey";
Expand Down Expand Up @@ -224,15 +225,15 @@ export default class Apollo implements ApolloInterface, KeyRestoration {
const keyData = parameters[KeyProperties.rawKey];

if (keyType === KeyTypes.EC) {
if (curve === Curve.ED25519) {
if (isCurve(curve, Curve.ED25519)) {
if (keyData) {
return new Ed25519PublicKey(keyData);
}

throw new ApolloError.MissingKeyParameters(KeyProperties.rawKey);
}

if (curve === Curve.SECP256K1) {
if (isCurve(curve, Curve.SECP256K1)) {
if (keyData) {
return new Secp256k1PublicKey(keyData);
} else {
Expand All @@ -249,7 +250,7 @@ export default class Apollo implements ApolloInterface, KeyRestoration {
}

if (keyType === KeyTypes.Curve25519) {
if (curve === Curve.X25519) {
if (isCurve(curve, Curve.X25519)) {
if (keyData) {
return new X25519PublicKey(keyData);
}
Expand Down Expand Up @@ -336,7 +337,7 @@ export default class Apollo implements ApolloInterface, KeyRestoration {
const keyData = parameters[KeyProperties.rawKey];

if (keyType === KeyTypes.EC) {
if (curve === Curve.ED25519) {
if (isCurve(curve, Curve.ED25519)) {
if (keyData) {
return new Ed25519PrivateKey(keyData);
}
Expand Down Expand Up @@ -366,7 +367,7 @@ export default class Apollo implements ApolloInterface, KeyRestoration {
return keyPair.privateKey;
}

if (curve === Curve.SECP256K1) {
if (isCurve(curve, Curve.SECP256K1)) {
if (keyData) {
return new Secp256k1PrivateKey(keyData);
}
Expand Down Expand Up @@ -412,7 +413,7 @@ export default class Apollo implements ApolloInterface, KeyRestoration {
}

if (keyType === KeyTypes.Curve25519) {
if (curve === Curve.X25519) {
if (isCurve(curve, Curve.X25519)) {
if (keyData) {
return new X25519PrivateKey(keyData);
}
Expand Down
12 changes: 6 additions & 6 deletions src/castor/did/prismDID/PrismDIDPublicKey.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Curve, getProtosUsage, getUsage, PublicKey, Usage } from "../../../doma
import { ApolloError, CastorError } from "../../../domain/models/Errors";
import * as Protos from "../../protos/node_models";

import { Apollo, KeyProperties, KeyTypes } from "../../../domain";
import { Apollo, isCurve, KeyProperties, KeyTypes } from "../../../domain";

export class PrismDIDPublicKey {

Expand Down Expand Up @@ -50,14 +50,14 @@ export class PrismDIDPublicKey {
) {
const curve = this.getProtoCurve(proto);
if (proto.has_compressed_ec_key_data) {
if (curve === Curve.ED25519) {
if (isCurve(curve, Curve.ED25519)) {
return apollo.createPublicKey({
[KeyProperties.type]: KeyTypes.EC,
[KeyProperties.curve]: Curve.ED25519,
[KeyProperties.rawKey]: proto.compressed_ec_key_data.data
})
}
if (curve === Curve.X25519) {
if (isCurve(curve, Curve.X25519)) {
return apollo.createPublicKey({
[KeyProperties.type]: KeyTypes.Curve25519,
[KeyProperties.curve]: Curve.X25519,
Expand All @@ -75,13 +75,13 @@ export class PrismDIDPublicKey {
const id = proto.id;
const usage = getUsage(proto.usage);
const curve = this.getProtoCurve(proto);
if (curve === Curve.SECP256K1.toLocaleLowerCase()) {
if (isCurve(curve, Curve.SECP256K1)) {
return new PrismDIDPublicKey(
id,
usage,
this.fromSecp256k1Proto(apollo, proto)
);
} else if (curve === Curve.ED25519 || curve === Curve.X25519) {
} else if (isCurve(curve, Curve.ED25519) || isCurve(curve, Curve.X25519)) {
return new PrismDIDPublicKey(
id,
usage,
Expand All @@ -95,7 +95,7 @@ export class PrismDIDPublicKey {
toProto(): Protos.io.iohk.atala.prism.protos.PublicKey {
const curve = this.keyData.curve;
const usage = getProtosUsage(this.usage);
if (curve === Curve.SECP256K1) {
if (isCurve(curve, Curve.SECP256K1)) {
const encoded = this.keyData.getEncoded()
const xBytes = encoded.slice(1, 1 + ECConfig.PRIVATE_KEY_BYTE_SIZE);
const yBytes = encoded.slice(
Expand Down
7 changes: 4 additions & 3 deletions src/castor/resolver/LongFormPrismDIDResolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
PublicKey,
Curve,
getUsage,
isCurve,
} from "../../domain/models";

import * as DIDParser from "../parser/DIDParser";
Expand Down Expand Up @@ -92,7 +93,7 @@ export class LongFormPrismDIDResolver implements DIDResolver {
(key: Protos.io.iohk.atala.prism.protos.PublicKey) => {
const curve = this.getProtoCurve(key).toLocaleLowerCase()
let pk: PublicKey;
if (curve === Curve.SECP256K1.toLocaleLowerCase()) {
if (isCurve(curve, Curve.SECP256K1)) {
pk = key.has_compressed_ec_key_data
? Secp256k1PublicKey.secp256k1FromBytes(
key.compressed_ec_key_data.data
Expand All @@ -101,14 +102,14 @@ export class LongFormPrismDIDResolver implements DIDResolver {
key.ec_key_data.x,
key.ec_key_data.y
);
} else if (curve === Curve.ED25519.toLocaleLowerCase()) {
} else if (isCurve(curve, Curve.ED25519)) {
if (!key.has_compressed_ec_key_data) {
throw new Error("Expected compressed compressed key")
}
pk = Ed25519PublicKey.from.Buffer(
Buffer.from(key.compressed_ec_key_data.data)
)
} else if (curve === Curve.X25519.toLocaleLowerCase()) {
} else if (isCurve(curve, Curve.X25519)) {
if (!key.has_compressed_ec_key_data) {
throw new Error("Expected compressed compressed key")
}
Expand Down
12 changes: 11 additions & 1 deletion src/domain/models/keyManagement/Curve.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
export enum Curve {
X25519 = "X25519",
ED25519 = "Ed25519",
SECP256K1 = "Secp256k1",
SECP256K1 = "secp256k1",
}

export function isCurve(curve: string, curveEnum: Curve): boolean {
if (curve === curveEnum) {
return true;
}
if (curve.toLocaleLowerCase() === curveEnum.toLocaleLowerCase()) {
return true;
}
return false;
}
3 changes: 2 additions & 1 deletion src/domain/models/keyManagement/Key.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ export enum Usage {
UNKNOWN_KEY = "unknownKey",
}
export function curveToAlg(curve: string) {
// For backwards compatibility
if (curve === Curve.SECP256K1 || curve === "secp256k1") {
return JWT_ALG.ES256K;
}
Expand Down Expand Up @@ -259,6 +260,6 @@ export abstract class Key {

isCurve<T>(curve: string): this is T {
const keyCurve = this.keySpecification.get(KeyProperties.curve);
return keyCurve === curve;
return keyCurve === curve || keyCurve?.toLocaleLowerCase() === curve.toLowerCase();
}
}
6 changes: 3 additions & 3 deletions src/domain/models/keyManagement/exportable/JWK.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { base64url } from "multiformats/bases/base64";
import { notEmptyString } from "../../../../utils";
import { KeyProperties } from "../../KeyProperties";
import { Curve } from "../Curve";
import { Curve, isCurve } from "../Curve";
import { PrivateKey } from "../PrivateKey";
import { PublicKey } from "../PublicKey";

Expand Down Expand Up @@ -110,8 +110,8 @@ export namespace JWK {
*/
export const fromKey = (key: PublicKey | PrivateKey, base: Base = {}): JWK => {
const prototype = Object.getPrototypeOf(key);
const privateFn = key.curve === Curve.SECP256K1 ? privateKeyToEC : privateKeyToOKP;
const publicFn = key.curve === Curve.SECP256K1 ? publicKeyToEC : publicKeyToOKP;
const privateFn = isCurve(key.curve, Curve.SECP256K1) ? privateKeyToEC : privateKeyToOKP;
const publicFn = isCurve(key.curve, Curve.SECP256K1) ? publicKeyToEC : publicKeyToOKP;

if (prototype instanceof PublicKey) {
return Object.assign({}, base, publicFn(key as PublicKey));
Expand Down
3 changes: 2 additions & 1 deletion src/edge-agent/didFunctions/CreateJwt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { base58btc } from "multiformats/bases/base58";
import * as Domain from "../../domain";
import { expect } from "../../utils";
import { Task } from "../../utils/tasks";
import { isCurve } from "../../domain";

/**
* Asyncronously sign with a DID
Expand All @@ -23,7 +24,7 @@ export class CreateJWT extends Task<string, Args> {
async run(ctx: Task.Context) {
const keys = await ctx.Pluto.getDIDPrivateKeysByDID(this.args.did);
const secpKey = expect(
keys.find(x => x.curve === Domain.Curve.SECP256K1),
keys.find(x => isCurve(x.curve, Domain.Curve.SECP256K1)),
"key not found"
);

Expand Down
5 changes: 3 additions & 2 deletions src/edge-agent/didcomm/CreatePresentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { SDJWTCredential } from "../../pollux/models/SDJWTVerifiableCredential";
import { Presentation, RequestPresentation } from "../protocols/proofPresentation";
import { DIDCommContext } from "./Context";
import { Task } from "../../utils/tasks";
import { isCurve } from "../../domain";

/**
* Asyncronously create a verifiablePresentation from a valid stored verifiableCredential
Expand Down Expand Up @@ -189,7 +190,7 @@ export class CreatePresentation extends Task<Presentation, Args> {
const subjectDID = Domain.DID.from(disclosed.sub);

const prismPrivateKeys = await ctx.Pluto.getDIDPrivateKeysByDID(subjectDID);
const prismPrivateKey = prismPrivateKeys.find((key) => key.curve === Domain.Curve.ED25519);
const prismPrivateKey = prismPrivateKeys.find((key) => isCurve(key.curve, Domain.Curve.ED25519));

if (prismPrivateKey === undefined) {
throw new Domain.AgentError.CannotFindDIDPrivateKey();
Expand Down Expand Up @@ -218,7 +219,7 @@ export class CreatePresentation extends Task<Presentation, Args> {
}
const subjectDID = Domain.DID.from(credential.subject);
const prismPrivateKeys = await ctx.Pluto.getDIDPrivateKeysByDID(subjectDID);
const prismPrivateKey = prismPrivateKeys.find((key) => key.curve === Domain.Curve.SECP256K1);
const prismPrivateKey = prismPrivateKeys.find((key) => isCurve(key.curve, Domain.Curve.SECP256K1));
if (prismPrivateKey === undefined) {
throw new Domain.AgentError.CannotFindDIDPrivateKey();
}
Expand Down
3 changes: 2 additions & 1 deletion src/mercury/didcomm/SecretsResolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { Secret, SecretsResolver } from "didcomm-wasm";
import * as Domain from "../../domain";
import * as DIDURLParser from "../../castor/parser/DIDUrlParser";
import { PeerDID } from "../../peer-did/PeerDID";
import { isCurve } from "../../domain";

export class DIDCommSecretsResolver implements SecretsResolver {
constructor(
Expand Down Expand Up @@ -56,7 +57,7 @@ export class DIDCommSecretsResolver implements SecretsResolver {
publicKeyJWK: Domain.PublicKeyJWK
): Secret {
const privateKeyBuffer = peerDid.privateKeys.find(
(key) => key.keyCurve.curve === Domain.Curve.X25519
(key) => isCurve(key.keyCurve.curve, Domain.Curve.X25519)
);
if (!privateKeyBuffer) {
throw new Error(`Invalid PrivateKey Curve ${Domain.Curve.X25519}`);
Expand Down
10 changes: 5 additions & 5 deletions tests/castor/PrismDID.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ describe("PrismDID", () => {
);
const masterPkProto = masterPk.toProto();

expect(() => PrismDIDPublicKey.fromProto(apollo, masterPkProto)).to.throw(`16: Invalid key curve: ${unsupportedCurve}. Valid options are: X25519, Ed25519, Secp256k1`);
expect(() => PrismDIDPublicKey.fromProto(apollo, masterPkProto)).to.throw(`16: Invalid key curve: ${unsupportedCurve}. Valid options are: X25519, Ed25519, secp256k1`);
});
});

Expand Down Expand Up @@ -143,7 +143,7 @@ describe("PrismDID", () => {
expect(cp0vm0?.id).to.eq(`${didStr}#master-0`);
expect(cp0vm0?.publicKeyJwk).to.be.undefined;
expect(cp0vm0?.publicKeyMultibase).to.eq("zSXxpYB6edvxvWxRTo3kMUoTTQVHpbNnXo2Z1AjLA78iqLdK2kVo5xw9rGg8uoEgmhxYahNur3RvV7HnaktWBqkXt");
expect(cp0vm0?.type).to.eq("Secp256k1");
expect(cp0vm0?.type).to.eq(Curve.SECP256K1);

const cp1 = sut.coreProperties.at(1) as Authentication;
expect(cp1).to.be.instanceOf(Authentication);
Expand All @@ -154,7 +154,7 @@ describe("PrismDID", () => {
expect(cp1vm0?.id).to.eq(`${didStr}#authentication-0`);
expect(cp1vm0?.publicKeyJwk).to.be.undefined;
expect(cp1vm0?.publicKeyMultibase).to.eq("zSXxpYB6edvxvWxRTo3kMUoTTQVHpbNnXo2Z1AjLA78iqLdK2kVo5xw9rGg8uoEgmhxYahNur3RvV7HnaktWBqkXt");
expect(cp1vm0?.type).to.eq("Secp256k1");
expect(cp1vm0?.type).to.eq(Curve.SECP256K1);

const cp2 = sut.coreProperties.at(2) as Services;
expect(cp2).to.be.instanceOf(Services);
Expand All @@ -171,15 +171,15 @@ describe("PrismDID", () => {
expect(cp3v0?.id).to.eq(`${didStr}#master-0`);
expect(cp3v0?.publicKeyJwk).to.be.undefined;
expect(cp3v0?.publicKeyMultibase).to.eq("zSXxpYB6edvxvWxRTo3kMUoTTQVHpbNnXo2Z1AjLA78iqLdK2kVo5xw9rGg8uoEgmhxYahNur3RvV7HnaktWBqkXt");
expect(cp3v0?.type).to.eq("Secp256k1");
expect(cp3v0?.type).to.eq(Curve.SECP256K1);

const cp3v1 = cp3.values.at(1);
expect(cp3v1).to.be.instanceOf(VerificationMethod);
expect(cp3v1?.controller).to.eq(didStr);
expect(cp3v1?.id).to.eq(`${didStr}#authentication-0`);
expect(cp3v1?.publicKeyJwk).to.be.undefined;
expect(cp3v1?.publicKeyMultibase).to.eq("zSXxpYB6edvxvWxRTo3kMUoTTQVHpbNnXo2Z1AjLA78iqLdK2kVo5xw9rGg8uoEgmhxYahNur3RvV7HnaktWBqkXt");
expect(cp3v1?.type).to.eq("Secp256k1");
expect(cp3v1?.type).to.eq(Curve.SECP256K1);
});

const masterKeyId = getUsageId(Usage.MASTER_KEY);
Expand Down

0 comments on commit 1fccb9c

Please sign in to comment.