Merged
Commits
87 commits
3dcd569
edges sampling API (scala)
SemyonSinchenko Oct 16, 2025
06f6af6
add seed to z-estimation
SemyonSinchenko Oct 16, 2025
e4ede62
wip
SemyonSinchenko Oct 27, 2025
08ee0eb
Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
SemyonSinchenko Oct 30, 2025
d050a7a
WIP
SemyonSinchenko Nov 20, 2025
4a0a14f
scalfix
SemyonSinchenko Nov 21, 2025
579a5c3
docstrings to RandomWalkBase and RandomWalkWithRestart
SemyonSinchenko Dec 1, 2025
16dd46c
Fix RandomWalk implementation bugs and add example
SemyonSinchenko Dec 2, 2025
965dea8
Add Word2VecHashingTrick implementation for graph embeddings
SemyonSinchenko Dec 3, 2025
3260cc9
Implement reservoir sampling for neighbor selection in random walks
SemyonSinchenko Dec 4, 2025
aafcbc6
fix scalastyle?
SemyonSinchenko Dec 5, 2025
f27462d
fix reservoir
SemyonSinchenko Dec 5, 2025
28b8ca5
add hash2vec
SemyonSinchenko Dec 8, 2025
b8255ac
fixes
SemyonSinchenko Dec 9, 2025
2113376
docstrings + scalfix
SemyonSinchenko Dec 9, 2025
6244ec2
remove sampling as not needed
SemyonSinchenko Dec 9, 2025
a9ce665
fixes in build and code
SemyonSinchenko Dec 10, 2025
0d5da92
Big update
SemyonSinchenko Dec 30, 2025
84306b6
Tests and updates
SemyonSinchenko Dec 30, 2025
5084d88
workaround scala 2.13 deprecation of Searching.search
SemyonSinchenko Dec 30, 2025
0a9dc90
Fixes
SemyonSinchenko Dec 30, 2025
f7a9aeb
Fix access
SemyonSinchenko Dec 30, 2025
acc5a2c
fallback to java Serialization
SemyonSinchenko Dec 30, 2025
8b20f00
Fix some bugs
SemyonSinchenko Dec 30, 2025
538a72d
Sampling Convolution tests and docstrings
SemyonSinchenko Dec 31, 2025
3548bb1
docstrings for RW Emebddings
SemyonSinchenko Dec 31, 2025
3df9282
Python API
SemyonSinchenko Jan 2, 2026
af97cac
hash2vec tests
SemyonSinchenko Jan 2, 2026
5d724b5
Explicit types
SemyonSinchenko Jan 2, 2026
526330c
hash2vec and random walks with restart tests
SemyonSinchenko Jan 2, 2026
061fab3
Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
SemyonSinchenko Jan 2, 2026
7e77bc9
performance
SemyonSinchenko Jan 7, 2026
02c22f8
ignore unused nowarn
SemyonSinchenko Jan 7, 2026
97d529a
performance + cached walks support + continous mode
SemyonSinchenko Jan 7, 2026
cc7f908
Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
SemyonSinchenko Jan 7, 2026
48e981a
fix rw and update the branch
SemyonSinchenko Jan 7, 2026
de63a45
fix
SemyonSinchenko Jan 7, 2026
e1cf2a4
protobuf & connect
SemyonSinchenko Jan 7, 2026
d2780a4
public API for embeddings and small refactoring of methods
SemyonSinchenko Jan 7, 2026
a294c03
initial Py API for embeddings
SemyonSinchenko Jan 9, 2026
60ca7e5
fix
SemyonSinchenko Jan 9, 2026
6583c20
tests + docs
SemyonSinchenko Jan 19, 2026
54ee3d7
fix connect tests
SemyonSinchenko Jan 19, 2026
867ae47
Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
SemyonSinchenko Feb 11, 2026
c7ef2d9
decrease GC pressure
SemyonSinchenko Feb 12, 2026
ff107a2
further optimizations
SemyonSinchenko Feb 12, 2026
7f01b9c
refactor: optimize Hash2Vec string hashing and partitioning logic
SemyonSinchenko Feb 12, 2026
b7d95f8
refactor: replace generic hash function with type-specific implementa…
SemyonSinchenko Feb 12, 2026
5def505
refactor: simplify hash function logic and improve performance in Has…
SemyonSinchenko Feb 12, 2026
5f27b92
refactor: inline generic processPartitionGeneric into specialized Str…
SemyonSinchenko Feb 12, 2026
ed07856
test: add tests for PagedMatrixDouble helper covering page extension,…
SemyonSinchenko Feb 12, 2026
5cc9f1a
refactor: replace unsafe hash functions with MurmurHash3 and optimize…
SemyonSinchenko Feb 12, 2026
895a319
refactor: update processStringPartition to use PagedMatrixDouble for …
SemyonSinchenko Feb 12, 2026
66809bd
docs: add internal documentation for PagedMatrixDouble explaining mem…
SemyonSinchenko Feb 12, 2026
548a17a
chore: remove commented helper section from Hash2Vec
SemyonSinchenko Feb 12, 2026
70d96cb
refactor: replace case class with class for PagedMatrixDouble and rem…
SemyonSinchenko Feb 12, 2026
842cb7e
fix: correct LongMap type parameter in Hash2Vec vocabIndex initializa…
SemyonSinchenko Feb 12, 2026
c2c6ff4
refactor: reduce PAGE_BITS from 16 to 12 and update related constants…
SemyonSinchenko Feb 12, 2026
54c32b3
feat: add max vectors per partition limit and batched processing for …
SemyonSinchenko Feb 12, 2026
e689b86
refactor: process long partitions in batches respecting max vectors l…
SemyonSinchenko Feb 12, 2026
6d92b2c
test: add Hash2Vec tests for co-occurrence patterns and cosine simila…
SemyonSinchenko Feb 12, 2026
ae60afd
chore: clean up code formatting and improve test readability in Hash2Vec
SemyonSinchenko Feb 12, 2026
c0f3ca3
fix: correct typo in error message from 'gor' to 'got' in Hash2Vec ex…
SemyonSinchenko Feb 12, 2026
b0dc568
docs: add docstrings for Hash2Vec setters setDoNormalization and setM…
SemyonSinchenko Feb 12, 2026
95d73c3
fix: correct typo in NOTICE and KMinSampling, update Hash2Vec default…
SemyonSinchenko Feb 12, 2026
de2d6a3
fix: skip seeds for previous batches to maintain consistency when sta…
SemyonSinchenko Feb 18, 2026
3a0a302
fix: add overwrite mode when writing batch results to allow re-runnin…
SemyonSinchenko Feb 18, 2026
af7f8f4
feat: add cleanUp method to remove temporary files for a walk ID usin…
SemyonSinchenko Feb 18, 2026
579a2b5
docs: improve documentation for cleanUp method in RandomWalkBase trait
SemyonSinchenko Feb 18, 2026
5e2a462
test: add cleanUp call in RandomWalkWithRestart test
SemyonSinchenko Feb 18, 2026
d3cb1a2
test: move walks execution inside try block in RandomWalkWithRestartS…
SemyonSinchenko Feb 18, 2026
f5bb5ca
refactor: set walkID default to UUID and remove redundant runID variable
SemyonSinchenko Feb 18, 2026
4b25649
docs: update comment to clarify walkID retrieval method behavior
SemyonSinchenko Feb 18, 2026
0c9262d
refactor: rename walkID to runID for clarity in random walk operations
SemyonSinchenko Feb 18, 2026
c3b8d86
refactor: remove runId parameter from cleanUp method in RandomWalkBase
SemyonSinchenko Feb 18, 2026
3309269
refactor: move cleanUp method to companion object with parameters and…
SemyonSinchenko Feb 18, 2026
a9ff184
refactor: improve documentation formatting and remove redundant log s…
SemyonSinchenko Feb 18, 2026
9eb4dea
test: verify temporary files are deleted after RandomWalkWithRestart …
SemyonSinchenko Feb 18, 2026
5b349c8
refactor: use numBatches variable and fix run path in RandomWalkWithR…
SemyonSinchenko Feb 18, 2026
5c12113
test: add test for RandomWalkWithRestart resuming from middle iteration
SemyonSinchenko Feb 18, 2026
483c3be
style: Format RandomWalkWithRestartSuite with consistent spacing and …
SemyonSinchenko Feb 18, 2026
5efc5dd
refactor: use getSeq instead of getAs[Seq[String]] in RandomWalkWithR…
SemyonSinchenko Feb 18, 2026
001a2bf
fix: correct typo in error message and parameter name for gaussian sigma
SemyonSinchenko Feb 26, 2026
64aee65
feat: add clean-up option for temporary random walk files in embeddin…
SemyonSinchenko Feb 26, 2026
e45381d
docs: improve formatting and consistency in graph-ml documentation ta…
SemyonSinchenko Feb 26, 2026
d4f9f50
feat: add clean_up_after_run field to RandomWalkEmbeddings proto defi…
SemyonSinchenko Feb 26, 2026
79be29d
chore: add clean_up_after_run parameter to _RandomWalksEmbeddingsPara…
SemyonSinchenko Feb 26, 2026
5 changes: 5 additions & 0 deletions .gitignore
@@ -79,3 +79,8 @@ spark-*

# Zed
.zed

# Emacs
.dir-locals.el
*~
.aider*
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
@@ -21,15 +21,14 @@ repos:

- id: scalafmt
name: scalafmt
entry: build/sbt scalafmtCheckAll
entry: build/sbt scalafmtAll
language: system
types: [scala]
pass_filenames: false

- id: scalafix
name: scalafix
entry: build/sbt "scalafixAll --check"
entry: build/sbt scalafixAll
language: system
types: [scala]
pass_filenames: false

7 changes: 7 additions & 0 deletions NOTICE
@@ -8,3 +8,10 @@ Copyright 2014-2025 The Apache Software Foundation.

This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).

Parts of this project's code are heavily inspired by or copied from the Apache Spark ML project, which is licensed under the Apache License, Version 2.0. The Apache Spark project has the following NOTICE:
Apache Spark
Copyright 2014 and onwards The Apache Software Foundation.

This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).
18 changes: 12 additions & 6 deletions build.sbt
@@ -96,10 +96,12 @@ lazy val commonSetting = Seq(
"--add-opens=java.base/java.lang=ALL-UNNAMED",
"--add-opens=java.base/java.nio=ALL-UNNAMED",
"--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
"--add-opens=java.base/java.util=ALL-UNNAMED"),
"--add-opens=java.base/java.util=ALL-UNNAMED",
"--add-opens=java.base/sun.security.action=ALL-UNNAMED",
"--add-opens=java.base/java.io=ALL-UNNAMED"),

// Scalac options
tpolecatScalacOptions ++= Set(
Compile / tpolecatScalacOptions ++= Set(
ScalacOptions.lint,
ScalacOptions.deprecation,
ScalacOptions.warnDeadCode,
@@ -111,7 +113,10 @@ lazy val commonSetting = Seq(
ScalacOptions.warnUnusedNoWarn,
ScalacOptions.source3,
ScalacOptions.fatalWarnings),
tpolecatExcludeOptions ++= Set(ScalacOptions.warnNonUnitStatement),
Compile / tpolecatExcludeOptions ++= Set(
ScalacOptions.warnNonUnitStatement,
ScalacOptions.privateWarnUnusedNoWarn,
ScalacOptions.warnUnusedNoWarn),
Test / tpolecatExcludeOptions ++= Set(
ScalacOptions.warnValueDiscard,
ScalacOptions.warnUnusedLocals,
@@ -122,8 +127,7 @@
ScalacOptions.warnNumericWiden,
ScalacOptions.privateWarnNumericWiden,
ScalacOptions.warnUnusedNoWarn,
ScalacOptions.privateWarnUnusedNoWarn,
))
ScalacOptions.privateWarnUnusedNoWarn))

lazy val graphx = (project in file("graphx"))
.settings(
@@ -136,7 +140,9 @@ lazy val graphx = (project in file("graphx"))
// for scala 2.13 we should mark "unused" class tags by @nowarn,
// for scala 2.12 we shouldn't
// the only way at the moment is to not check unused @nowarn for GraphX
tpolecatExcludeOptions ++= Set(ScalacOptions.warnUnusedNoWarn, ScalacOptions.privateWarnUnusedNoWarn),
tpolecatExcludeOptions ++= Set(
ScalacOptions.warnUnusedNoWarn,
ScalacOptions.privateWarnUnusedNoWarn),

// Global settings
Global / concurrentRestrictions := Seq(Tags.limitAll(1)),
38 changes: 37 additions & 1 deletion connect/src/main/protobuf/graphframes.proto
@@ -35,8 +35,9 @@ message GraphFramesAPI {
SVDPlusPlus svd_plus_plus = 18;
TriangleCount triangle_count = 19;
Triplets triplets = 20;
MaximalIndependentSet mis = 22;
KCore kcore = 21;
MaximalIndependentSet mis = 22;
RandomWalkEmbeddings rw_embeddings = 23;
}
}

@@ -208,3 +209,38 @@ message KCore {
int32 checkpoint_interval = 2;
optional StorageLevel storage_level = 3;
}

message RandomWalkEmbeddings {
bool use_edge_direction = 1;
string rw_model = 2;
int32 rw_max_nbrs = 3;
int32 rw_num_walks_per_node = 4;
int32 rw_batch_size = 5;
int32 rw_num_batches = 6;
int64 rw_seed = 7;
double rw_restart_probability = 8;
string rw_temporary_prefix = 9;
string rw_cached_walks = 10;
string sequence_model = 11;
int32 hash2vec_context_size = 12;
int32 hash2vec_num_partitions = 13;
int32 hash2vec_embeddings_dim = 14;
string hash2vec_decay_function = 15;
double hash2vec_gaussian_sigma = 16;
int32 hash2vec_hashing_seed = 17;
int32 hash2vec_sign_seed = 18;
bool hash2vec_do_l2_norm = 19;
bool hash2vec_safe_l2 = 20;
int32 word2vec_max_iter = 21;
int32 word2vec_embeddings_dim = 22;
int32 word2vec_window_size = 23;
int32 word2vec_num_partitions = 24;
int32 word2vec_min_count = 25;
int32 word2vec_max_sentence_length = 26;
int64 word2vec_seed = 27;
double word2vec_step_size = 28;
bool aggregate_neighbors = 29;
int32 aggregate_neighbors_max_nbrs = 30;
int64 aggregate_neighbors_seed = 31;
bool clean_up_after_run = 32;
}
@@ -12,6 +12,7 @@ import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame
import org.graphframes.GraphFramesUnreachableException
import org.graphframes.connect.proto
import org.graphframes.embeddings.RandomWalkEmbeddings

import scala.jdk.CollectionConverters.*

@@ -454,6 +455,44 @@ object GraphFramesConnectUtils {

kCoreBuilder.run()
}
case proto.GraphFramesAPI.MethodCase.RW_EMBEDDINGS => {
val message = apiMessage.getRwEmbeddings()

RandomWalkEmbeddings.pythonAPI(
graph = graphFrame,
useEdgeDirection = message.getUseEdgeDirection(),
rwModel = message.getRwModel(),
rwMaxNbrs = message.getRwMaxNbrs(),
rwNumWalksPerNode = message.getRwNumWalksPerNode(),
rwBatchSize = message.getRwBatchSize(),
rwNumBatches = message.getRwNumBatches(),
rwSeed = message.getRwSeed(),
rwRestartProbability = message.getRwRestartProbability(),
rwTemporaryPrefix = message.getRwTemporaryPrefix(),
rwCachedWalks = message.getRwCachedWalks(),
sequenceModel = message.getSequenceModel(),
hash2vecContextSize = message.getHash2VecContextSize(),
hash2vecNumPartitions = message.getHash2VecNumPartitions(),
hash2vecEmbeddingsDim = message.getHash2VecEmbeddingsDim(),
hash2vecDecayFunction = message.getHash2VecDecayFunction(),
hash2vecGaussianSigma = message.getHash2VecGaussianSigma(),
hash2vecHashingSeed = message.getHash2VecHashingSeed(),
hash2vecSignSeed = message.getHash2VecSignSeed(),
hash2vecDoL2Norm = message.getHash2VecDoL2Norm(),
hash2vecSafeL2 = message.getHash2VecSafeL2(),
word2vecMaxIter = message.getWord2VecMaxIter(),
word2vecEmbeddingsDim = message.getWord2VecEmbeddingsDim(),
word2vecWindowSize = message.getWord2VecWindowSize(),
word2vecNumPartitions = message.getWord2VecNumPartitions(),
word2vecMinCount = message.getWord2VecMinCount(),
word2vecMaxSentenceLength = message.getWord2VecMaxSentenceLength(),
word2vecSeed = message.getWord2VecSeed(),
word2vecStepSize = message.getWord2VecStepSize(),
aggregateNeighbors = message.getAggregateNeighbors(),
aggregateNeighborsMaxNbrs = message.getAggregateNeighborsMaxNbrs(),
aggregateNeighborsSeed = message.getAggregateNeighborsSeed(),
cleanUpAfterRun = message.getCleanUpAfterRun())
}
case _ => throw new GraphFramesUnreachableException() // Unreachable
}
}
@@ -0,0 +1,165 @@
package org.apache.spark.sql.graphframes.expressions

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.types.*
import org.apache.spark.sql.types.DataType
import org.graphframes.GraphFramesUnsupportedVertexTypeException

import scala.annotation.nowarn
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

case class KMinAccum[T](values: Array[T], weights: Array[Long], var cnt: Int) extends Serializable

case class KMinSampling[T: ClassTag](size: Int)(implicit
@nowarn tag: TypeTag[T],
ord: Ordering[T])
extends Aggregator[Row, KMinAccum[T], Seq[T]]
with Serializable {

override def zero: KMinAccum[T] = KMinAccum(Array.ofDim[T](size), Array.ofDim[Long](size), 0)

override def reduce(b: KMinAccum[T], a: Row): KMinAccum[T] = {
val newWeight = a.getLong(1)
val newValue = a.getAs[T](0)
// fast-path: buffer is already full of "strong" elements
// the case of "influencer" vertex
if (b.cnt == size) {
val lastWeight = b.weights.last
if ((lastWeight < newWeight) || ((lastWeight == newWeight) && (ord.compare(
newValue,
b.values.last) >= 0))) {
return b
}
}

// slow-path: custom binary search for (Weight, Value)
// We want to find the first index where (b.w, b.v) > (newWeight, newValue)
var low = 0
var high = b.cnt - 1
var idx = b.cnt // Default insertion point is at the end

while (low <= high) {
val mid = (low + high) / 2
val midWeight = b.weights(mid)

// Compare (midWeight, midValue) vs (newWeight, newValue)
val res =
if (midWeight < newWeight) -1
else if (midWeight > newWeight) 1
else ord.compare(b.values(mid), newValue)

if (res <= 0) {
// mid is smaller or equal: we must insert after mid
low = mid + 1
} else {
// mid is larger: potential insertion point here
idx = mid
high = mid - 1
}
}

if (idx < size) {
val newCount = math.min(b.cnt + 1, size)
if (idx < newCount - 1) {
// shift to the right if needed
System.arraycopy(b.weights, idx, b.weights, idx + 1, newCount - idx - 1)
System.arraycopy(b.values, idx, b.values, idx + 1, newCount - idx - 1)
}

b.weights(idx) = newWeight
b.values(idx) = newValue
b.cnt = newCount
}

b
}
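
The `reduce` step above keeps a fixed-size buffer sorted by `(weight, value)`: a fast-path rejection when the buffer is already full of smaller pairs, then a binary search for the insertion point followed by an array shift. A minimal single-machine sketch of the same logic in Python (function name `kmin_reduce` is illustrative, not part of the library):

```python
import bisect

def kmin_reduce(buf, k, weight, value):
    """Insert (weight, value) into a sorted buffer that keeps only the
    k smallest pairs, mirroring the fast-path + binary-search logic above."""
    # fast path: buffer is full and the new pair is >= the current maximum
    if len(buf) == k and (weight, value) >= buf[-1]:
        return buf
    # bisect_right matches the Scala search: equal pairs are inserted after
    idx = bisect.bisect_right(buf, (weight, value))
    if idx < k:
        buf.insert(idx, (weight, value))
        if len(buf) > k:
            buf.pop()  # drop the largest pair to respect the size bound
    return buf
```

Like the Scala version, ties on weight fall back to ordering by value, so the result is deterministic for a given input multiset.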

override def merge(b1: KMinAccum[T], b2: KMinAccum[T]): KMinAccum[T] = {

if (b1.cnt == 0) {
return b2
}

if (b2.cnt == 0) {
return b1
}

val resultSize = math.min(b1.cnt + b2.cnt, size)
val newValues = Array.ofDim[T](resultSize)
val newWeights = Array.ofDim[Long](resultSize)

var i = 0
var j = 0
var r = 0

while (r < resultSize) {
val useLeft = if (i >= b1.cnt) {
false
} else if (j >= b2.cnt) {
true
} else {
val wLeft = b1.weights(i)
val wRight = b2.weights(j)

if (wLeft < wRight) {
true
} else if (wLeft > wRight) {
false
} else {
ord.compare(b1.values(i), b2.values(j)) <= 0
}
}

if (useLeft) {
newWeights(r) = b1.weights(i)
newValues(r) = b1.values(i)
i += 1
} else {
newWeights(r) = b2.weights(j)
newValues(r) = b2.values(j)
j += 1
}

r += 1
}

KMinAccum(newValues, newWeights, resultSize)
}
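
`merge` is a standard two-pointer merge of two already-sorted buffers, truncated at `size`. Since both inputs are sorted, the equivalent Python sketch can lean on `heapq.merge` (again, `kmin_merge` is an illustrative name):

```python
import heapq

def kmin_merge(b1, b2, k):
    """Merge two sorted (weight, value) buffers, keeping the k smallest
    pairs, analogous to the two-pointer merge in the aggregator above."""
    return list(heapq.merge(b1, b2))[:k]
```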

override def finish(reduction: KMinAccum[T]): Seq[T] =
reduction.values.slice(0, reduction.cnt).toSeq
// TODO: replace by Kryo after 4.0.2 is released, see SPARK-52819
override def bufferEncoder: Encoder[KMinAccum[T]] = Encoders.product
override def outputEncoder: Encoder[Seq[T]] = ExpressionEncoder[Seq[T]]()
}

object KMinSampling extends Serializable {
def getEncoder(spark: SparkSession, dataType: DataType, colNames: Seq[String]): Encoder[Row] = {
// This is a roundabout approach, but it is the only one available via the public API
spark
.createDataFrame(
java.util.List.of[Row](),
StructType(
StructField(colNames(0), dataType) :: StructField(colNames(1), LongType) :: Nil))
.encoder
}

def fromSparkType(dataType: DataType, size: Int, encoder: Encoder[Row]): UserDefinedFunction = {
dataType match {
case StringType => udaf(KMinSampling[java.lang.String](size), encoder)
case ShortType => udaf(KMinSampling[java.lang.Short](size), encoder)
case ByteType => udaf(KMinSampling[java.lang.Byte](size), encoder)
case IntegerType => udaf(KMinSampling[java.lang.Integer](size), encoder)
case LongType => udaf(KMinSampling[java.lang.Long](size), encoder)
case _ => throw new GraphFramesUnsupportedVertexTypeException("unsupported vertex type")
}
}
}
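
The point of keeping the k minimum `(weight, value)` pairs is the k-min sampling trick: assign every element an independent random weight and retain the k smallest, which yields a uniform sample of size k without replacement. In the PR this runs as a Spark UDAF per vertex to sample neighbors; a hedged single-machine illustration (helper name `kmin_sample` is hypothetical):

```python
import random

def kmin_sample(items, k, seed=None):
    """Uniform sample of k items without replacement via the k-min trick:
    give every item a random 63-bit weight and keep the k smallest.
    A single-machine illustration of what the UDAF computes per vertex."""
    rng = random.Random(seed)
    weighted = sorted((rng.getrandbits(63), item) for item in items)
    return [item for _, item in weighted[:k]]
```

Because the weights are exchangeable, every k-subset is equally likely; fixing the seed makes the sample reproducible, which matches the seeded sampling commits in this PR's history.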